From 60a824bc37d174cbfd5328897f3d9d9229954e05 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 30 Jun 2026 16:14:02 -0400 Subject: [PATCH 1/5] feat: ADBC-backed Spark DataSource for DataFusion table providers Add a Spark DataSourceV2 connector that reads from a DataFusion TableProvider through a standard ADBC driver. Spark talks to the upstream arrow-adbc Java driver manager (adbc-core + adbc-driver-jni), which loads a native DataFusion ADBC cdylib and returns arrow-java ArrowReaders consumed zero-copy as ArrowColumnVectors on the cluster-provided Arrow. Spark connector (spark/): - AdbcDatafusionTableProvider: registers the `adbc-datafusion` format; schema probed once on the driver via AdbcConnection.getTableSchema. - Scan with projection/filter/limit pushed into the scan as a Substrait plan, with a SQL fallback path. - Multi-partition reads via executePartitioned/readPartition; the target_partitions option tunes scan parallelism. - Per-executor AdbcConnectionPool caches the AdbcDatabase per driver + options key; each task opens its own connection off it (connections are not shared across tasks: the arrow-adbc FFI exporter aliases &mut to one connection, which would be undefined behavior). Because the database is cached per executor and the driver's plan cache is database-scoped, the N per-task connections of one scan deserialize the physical plan once, not once per task. AdbcSourceTest asserts this via the driver's plan-deserialize counter. Example driver (examples/adbc-datafusion-driver/): a DataFusion ADBC cdylib exercising the full stack, with a PySpark end-to-end script and partitioning tests. Pinned to a driver rev for reproducible builds. Docs: user-guide page "DataFusion as a Spark data source (ADBC)" in the documentation site, covering what the connector provides, how to read from Spark, and how DataFusion's parallelism-bound partitioning maps onto Spark's byte-bound task model. Co-Authored-By: Claude Opus 4.8 --- .gitignore | 2 + .../source/user-guide/adbc-spark-connector.md | 240 +++++++++++++ docs/source/user-guide/index.md | 4 +- examples/adbc-datafusion-driver/Cargo.toml | 55 +++ examples/adbc-datafusion-driver/PARTITIONS.md | 208 ++++++++++++ examples/adbc-datafusion-driver/README.md | 113 +++++++ .../adbc-datafusion-driver/pyspark_e2e.py | 104 ++++++ examples/adbc-datafusion-driver/src/lib.rs | 83 +++++ .../adbc-datafusion-driver/src/provider.rs | 112 ++++++ .../tests/partitions.rs | 68 ++++ pom.xml | 1 + spark/pom.xml | 182 ++++++++++ .../spark/AdbcColumnarPartitionReader.java | 163 +++++++++ .../datafusion/spark/AdbcConnectionPool.java | 317 +++++++++++++++++ .../spark/AdbcDatafusionTableProvider.java | 76 +++++ .../datafusion/spark/AdbcInputPartition.java | 60 ++++ .../apache/datafusion/spark/AdbcOptions.java | 156 +++++++++ .../spark/AdbcPartitionReaderFactory.java | 52 +++ .../datafusion/spark/AdbcScanBuilder.java | 110 ++++++ .../apache/datafusion/spark/AdbcScanImpl.java | 228 +++++++++++++ .../apache/datafusion/spark/AdbcTable.java | 63 ++++ .../datafusion/spark/SchemaConverter.java | 81 +++++ .../org/apache/datafusion/spark/SqlQuery.java | 123 +++++++ .../datafusion/spark/SubstraitPlan.java | 319 ++++++++++++++++++ ...pache.spark.sql.sources.DataSourceRegister | 1 + .../spark/AdbcConnectionPoolTest.java | 255 ++++++++++++++ .../datafusion/spark/AdbcOptionsTest.java | 83 +++++ .../datafusion/spark/AdbcSourceTest.java | 206 +++++++++++ .../datafusion/spark/SubstraitPlanTest.java | 102 ++++++ 29 files changed, 3566 insertions(+), 1 deletion(-) create mode 100644 docs/source/user-guide/adbc-spark-connector.md create mode 100644 examples/adbc-datafusion-driver/Cargo.toml create mode 100644 examples/adbc-datafusion-driver/PARTITIONS.md create mode 100644 examples/adbc-datafusion-driver/README.md create mode 100644 examples/adbc-datafusion-driver/pyspark_e2e.py create mode 100644 examples/adbc-datafusion-driver/src/lib.rs create mode 100644 examples/adbc-datafusion-driver/src/provider.rs create mode 100644 examples/adbc-datafusion-driver/tests/partitions.rs create mode 100644 spark/pom.xml create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java create mode 100644 spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java create mode 100644 spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister create mode 100644 spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java create mode 100644 spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java create mode 100644 spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java create mode 100644 spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java diff --git a/.gitignore b/.gitignore index 25c9216..c6a4806 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ rust-target/ *.iml .DS_Store tpch-data/ +# Lockfile for the example ADBC driver crate (a demo cdylib, not a published lib). +examples/adbc-datafusion-driver/Cargo.lock .claude docs/superpowers docs/build/ diff --git a/docs/source/user-guide/adbc-spark-connector.md b/docs/source/user-guide/adbc-spark-connector.md new file mode 100644 index 0000000..601fd69 --- /dev/null +++ b/docs/source/user-guide/adbc-spark-connector.md @@ -0,0 +1,240 @@ + + +# DataFusion as a Spark data source (ADBC) + +This repository ships a Spark connector that lets a Spark job read from a +DataFusion `TableProvider` as a native Spark `DataSourceV2`. Spark sees an +ordinary external table; the data is produced by DataFusion in native code and +handed back as Apache Arrow batches. + +## What this repository provides + +Two pieces work together: + +- **The `adbc-datafusion` Spark connector** (`spark/`) — a Spark + `DataSourceV2` registered under the format name `adbc-datafusion`. It builds + the scan, pushes projection / filters / limit down, maps DataFusion's scan + partitions onto Spark input partitions, and imports each partition's Arrow + data into Spark **zero-copy** as an `ArrowColumnVector`. +- **An example DataFusion ADBC driver** (`examples/adbc-datafusion-driver/`) — + a small native library (cdylib) that exposes a DataFusion `TableProvider` + through the [ADBC](https://arrow.apache.org/adbc/) C API. It is a worked + example you can copy: your own driver registers whatever provider you want and + is loaded the same way. + +The connector is built on the standard Apache Arrow ADBC Java driver manager +(`adbc-core` + `adbc-driver-jni`), so it talks to the native driver through a +stable, public interface. Because that boundary is plain ADBC, the same +connector can front **any** ADBC-speaking source; the DataFusion driver shown +here is just one example, and you can point it at a different ADBC driver +without changing the connector. + +```text +Spark DataSourceV2 (adbc-datafusion) + → arrow-adbc Java driver manager (adbc-core + adbc-driver-jni) + → native ADBC driver (the DataFusion cdylib) + → your DataFusion TableProvider +``` + +## Reading from Spark + +Point the `driver` option at the built driver library and the `table` option at +a table the provider exposes: + +```scala +val df = spark.read + .format("adbc-datafusion") + .option("driver", "/abs/path/to/libadbc_datafusion_example_driver.so") + .option("entrypoint", "AdbcDatafusionExampleInit") + .option("table", "example") + .load() + +df.filter("id > 1").select("name").show() +``` + +The connector probes the schema once (on the driver) when the `DataFrame` is +created, then plans the scan. `df.filter(...)` and `df.select(...)` are pushed +into the DataFusion scan where possible (see below); the rest Spark evaluates +itself. + +## Two partitioning models that must be reconciled + +The heart of the connector is mapping DataFusion's idea of a partitioned scan +onto Spark's idea of tasks. The two engines size scans on **different +principles**, and a provider author who ignores the difference will get a job +that runs but performs badly. + +### DataFusion: parallelism-bound + +DataFusion partitions a scan to **saturate the cores of one machine**. Its +built-in `ListingTable`, for example, packs input files into roughly +`target_partitions` groups — where `target_partitions` defaults to the local +core count — using a minimum-size floor and optional single-file splitting only +to *reach* that count. The partition count is therefore approximately + +```text +N ≈ target_partitions ≈ cores +``` + +independent of how much data there is. More data means **bigger** partitions, +not more of them. + +### Spark: one task per partition, byte-bound + +Spark turns **each input partition into one task**, and a native Spark file +source sizes partitions by **bytes**: it targets roughly one partition per +`spark.sql.files.maxPartitionBytes` (default 128 MB), cutting files into splits +and bin-packing them to that size. (Spark shrinks the target below 128 MB when +the data is small, so that there are at least enough partitions to keep every +core busy.) + +So large data means **many** partitions — typically far more than the cluster +has cores — which run in waves. Each task processes a bounded slice (≈128 MB), +which is what keeps memory bounded and lets Spark reschedule stragglers. + +### How the connector joins them + +For a scan that only pushes projection / filter / limit (no DataFusion-side +join or aggregation), the physical plan's output partitioning **is** the +provider's `scan()` partitioning, and the mapping to Spark is **1:1**: + +```text +provider.scan() output partitions + → DataFusion physical plan partition count N + → N ADBC partition descriptors + → N Spark input partitions = N Spark tasks (scan stage) +``` + +Spark never splits or merges a partition afterward — `df.rdd().getNumPartitions() +== N`, and the cluster core count only bounds how many of the `N` run at once. + +**The consequence you must plan for:** suppose a provider keeps DataFusion's +default of `N = target_partitions ≈ cores` partitions. On a large dataset that +produces only a handful of partitions, each holding about `total_bytes / cores` +of data — so Spark runs a few very large tasks. Because a task is the unit of +work, those oversized tasks bring memory pressure, stragglers that hold up the +whole stage, and no way for Spark to rebalance. To feed Spark well, a provider +should size partitions by **bytes** — the way a Spark file source does — rather +than by the `target_partitions` count. + +## Sizing the scan + +Aim for the Spark sweet spot: + +- **Floor:** `N` ≥ total executor cores, or cores sit idle. `~2–4× cores` gives + slack for skew and stragglers. +- **Per-partition size:** big enough to amortize per-task overhead — the + canonical target is **128–256 MB** of data per partition (matching Spark's + `spark.sql.files.maxPartitionBytes` default of 128 MB). +- **Ceiling — and ADBC lowers it.** Beyond Spark's generic per-task cost, this + path adds two per-partition costs: each of the `N` ADBC descriptors carries a + copy of the **whole serialized physical plan**, and each task **deserializes** + that plan when it reads its partition. (The driver caches a deserialized plan + per connection, but that does not help here: each Spark task uses its own + connection and reads a single partition, so the plan is deserialized once per + task and never reused.) Tens of thousands of tiny partitions is far costlier + here than for a native file source. Keep `N` in the hundreds; prefer bigger + partitions over a huge count. + +Rule of thumb, balanced by **bytes** rather than split count: + +```text +N ≈ clamp(total_bytes / target_bytes, floor=cores, ceiling≈hundreds) +``` + +## Writing a provider that feeds Spark well + +Your provider's `scan()` receives the session, so it can read the connector's +parallelism hint and then decide based on what it knows about the data: + +```text +T = state.config().target_partitions() // connector hint ≈ desired parallelism +target_bytes = ~128–256 MB // bias larger for ADBC to keep N down +min_bytes = floor so partitions never get tiny // ~ tens of MB + +if total_bytes known AND splittable (files / row groups / key ranges): + N = clamp(ceil(total_bytes / target_bytes), 1, CEILING) // byte-bound + bin-pack splits into N groups BALANCED BY BYTES // not by count + do not split below min_bytes; merge small splits + +elif split_count known but not bytes (shards / external partitions): + N = min(split_count, CEILING) // natural partitioning + if split_count >> T: coalesce shards into ~T balanced groups + +else (size and splits unknown; opaque / streaming): + N = T // fall back to the hint + lean slightly high — stragglers hurt more than overhead +``` + +- **Balance by bytes, not by split count.** A task is atomic; Spark cannot + rebalance a fat partition. One 10 GB partition among 100 small ones stalls the + whole stage. Bin-pack so partition bytes are even. +- **`target_partitions` is a hint, cap, and fallback** — not the data-bound + count. Use it when bytes are unknown or as a parallelism ceiling; when you know + the data size, let **bytes** drive `N`. + +### Planning happens once, on the driver + +Your provider's `scan()` runs **once**, on the driver, while the multi-partition +descriptors are built. The driver plans the query, fixes the partitioning, and +**serializes the whole physical plan** into each descriptor. Each executor task +then deserializes that plan and executes its partition index. So partition +`index i` always means the same slice: there is no second planning pass that +could disagree with the first, and no need to make `N` stable across a +driver/executor boundary. + +The one requirement this places on a provider is that its plan be +**serializable**: built-in DataFusion nodes round-trip through the default codec, +but a custom `ExecutionPlan` node needs a `PhysicalExtensionCodec` registered +with the driver so it can be encoded on the driver and decoded on the executor. + +## Connector options + +| Option | Required | Meaning | +| --- | --- | --- | +| `driver` | yes | Path to (or manifest name of) the native ADBC driver library. | +| `entrypoint` | depends on driver | Name of the driver's C init symbol (e.g. `AdbcDatafusionExampleInit`). | +| `table` | yes | Table the provider exposes; its schema is probed on the driver. | +| `target_partitions` | no | Parallelism hint. Defaults to `SparkContext.defaultParallelism` (total executor cores). The connector issues `SET datafusion.execution.target_partitions = N` on the planning session, so it influences the partition count when the driver builds the physical plan; that plan (with its partitioning fixed) is what gets serialized into the descriptors. Pass `k × cores` to raise parallelism. | +| `manifest.path` | no | Extra search path for ADBC driver manifests when `driver` is a name rather than an absolute path. | +| *(anything else)* | no | Forwarded verbatim as native ADBC database options, so provider-specific knobs pass straight through. | + +Two things are intentionally **not** connector options: + +- **`maxPartitionBytes`** — byte-sizing lives in the provider's `scan()`; the + connector does not know the data size and cannot derive it. Pass a sizing knob + as a provider-specific option if your provider supports one. +- **Reported partitioning / storage-partitioned joins + (`SupportsReportPartitioning`)** — not available over ADBC today: the + descriptor carries no distribution metadata, and DataFusion's hash + partitioning does not match Spark's key-grouped / bucketed model. + +## Summary + +- The connector maps the provider's `scan()` partition count **1:1** onto Spark + tasks. +- DataFusion sizes scans to **cores** (parallelism-bound); Spark sizes them to + **bytes** (one task per ~128 MB). To feed Spark well, size your provider's + partitions by bytes — typically 128–256 MB each, floored at executor cores and + capped in the hundreds because every descriptor carries the full serialized + plan and each task deserializes it. +- Planning happens once on the driver; the physical plan is serialized into each + descriptor and executors run it as-is, so a custom `ExecutionPlan` node just + needs a registered `PhysicalExtensionCodec` to be serializable. diff --git a/docs/source/user-guide/index.md b/docs/source/user-guide/index.md index a8d42be..68351b0 100644 --- a/docs/source/user-guide/index.md +++ b/docs/source/user-guide/index.md @@ -26,7 +26,8 @@ DataFrame queries execute in native Rust; results return to the JVM as Data Interface. This guide covers installation, the `SessionContext` and `DataFrame` APIs, -and Parquet ingestion. +Parquet ingestion, table providers, and using DataFusion as a Spark data +source over ADBC. ```{toctree} :maxdepth: 1 @@ -39,6 +40,7 @@ parquet proto-plans scalar-udf table-provider +adbc-spark-connector api-reference ``` diff --git a/examples/adbc-datafusion-driver/Cargo.toml b/examples/adbc-datafusion-driver/Cargo.toml new file mode 100644 index 0000000..58f9d81 --- /dev/null +++ b/examples/adbc-datafusion-driver/Cargo.toml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "adbc-datafusion-example-driver" +version = "0.1.0" +edition = "2021" +publish = false +description = "End-to-end example: an ADBC driver that exposes a custom DataFusion TableProvider" + +# Standalone crate that stands in for a real, separate driver repo. The empty +# [workspace] table makes it its own workspace root so it is NOT pulled into the +# parent datafusion-java workspace (whose members are listed explicitly). +[workspace] + +[dependencies] +# The ready-made ADBC-over-DataFusion driver. We depend on it as a LIBRARY and +# only customize the SessionContext via its new_with_context_init hook -- no +# fork, no datafusion-ffi module loading. Pulled from its own repo (git) since +# that is where it lives. +adbc-driver-datafusion = { git = "https://github.com/adbc-drivers/datafusion" } +adbc_core = "0.23" +adbc_ffi = "0.23" +datafusion = "53" +async-trait = "0.1" + +# Build both an rlib (so the provider/driver can be unit-tested or embedded) and +# a cdylib (the artifact the arrow-adbc Java driver manager loads at runtime). +[lib] +crate-type = ["cdylib", "lib"] + +# Temporary: redirect adbc-driver-datafusion to the PR branch that implements +# execute_partitions / read_partition plus the shared (database-scoped) plan +# cache (adbc-drivers/datafusion#32), so we can run a full multi-partition +# system test before it merges. The crate is not on crates.io, so this patches +# the git source (not [patch.crates-io]). Pinned to a rev (not the branch) so +# builds are reproducible and the plan-cache assertion in AdbcSourceTest is +# stable. Remove once the PR is merged and the canonical dependency carries the +# change. +[patch."https://github.com/adbc-drivers/datafusion"] +adbc-driver-datafusion = { git = "https://github.com/timsaucer/adbc-driver-datafusion", rev = "9b51df2b13d20d81c563ef3ee818db1c02082fac" } diff --git a/examples/adbc-datafusion-driver/PARTITIONS.md b/examples/adbc-datafusion-driver/PARTITIONS.md new file mode 100644 index 0000000..1a2fe47 --- /dev/null +++ b/examples/adbc-datafusion-driver/PARTITIONS.md @@ -0,0 +1,208 @@ + + +# Multi-partition execution: upstream driver patch + +The Spark connector in this repo can spread a scan across executors using ADBC's +partitioned-execution API (`AdbcStatement.executePartitioned()` → +`AdbcConnection.readPartition(descriptor)`). The Java JNI driver manager binds +both calls, but the **DataFusion driver does not implement them yet**: + +```rust +// adbc-driver-datafusion/src/lib.rs (today) +fn execute_partitions(&mut self) -> Result { + Err(ErrorHelper::not_implemented().message("execute_partitions").to_adbc()) +} +fn read_partition(&self, _partition: impl AsRef<[u8]>) -> Result<...> { + Err(ErrorHelper::not_implemented().message("read_partition").to_adbc()) +} +``` + +So the connector calls `executePartitioned()`, catches `NOT_IMPLEMENTED`, and +falls back to a single `executeQuery()` partition. To get real parallelism, the +driver must produce one ADBC partition per DataFusion output partition. This is +an upstream change to [`adbc-driver-datafusion`](https://github.com/adbc-drivers/datafusion); +the patch below is PR-ready against the structures in its `src/lib.rs`. Until it +lands, point this example's git dependency at a branch carrying it. + +## Descriptor: self-contained, by construction + +A partition descriptor is shipped to another process and handed back via +`read_partition` with no live handle. It must therefore reconstruct the work on +its own. We encode three things: + +``` +[u32 LE target_partitions][u32 LE partition_index][u8 kind][query bytes...] +kind 0 = Substrait plan (prost-encoded substrait.proto.Plan) +kind 1 = SQL text (utf-8) +``` + +The `query bytes` rebuild the logical plan in the executor's `SessionContext` +(which has the same providers registered via `ContextInit`), and +`partition_index` selects the slice. + +## The correctness subtlety: deterministic partitioning + +`read_partition` re-plans from the query, so partition `i` must mean the **same** +slice it meant when `execute_partitions` counted the partitions. DataFusion +planning is deterministic given the same plan **and the same +`target_partitions`** — but that defaults to the machine's CPU count, which +differs between the driver host and an executor. Left unpinned, the executor +could re-plan into a different partition count and read the wrong (or an +out-of-range) slice. + +Fix: capture `target_partitions` at `execute_partitions` time, encode it in every +descriptor, and re-plan in `read_partition` with that value pinned. Both sides +then build the identical physical plan. + +## Patch + +```rust +// --- new: descriptor codec ------------------------------------------------- +use prost::Message; // already a dependency + +enum DescQuery { Substrait(Vec), Sql(String) } + +fn encode_query(q: &QueryState) -> adbc_core::error::Result<(u8, Vec)> { + match q { + QueryState::Substrait(plan) => Ok((0, plan.encode_to_vec())), + QueryState::Sql(sql) => Ok((1, sql.clone().into_bytes())), + QueryState::Prepared(_) => Err(ErrorHelper::not_implemented() + .message("partitioned execution of prepared statements") + .to_adbc()), + } +} + +fn encode_descriptor(target_partitions: u32, index: u32, kind: u8, query: &[u8]) -> Vec { + let mut out = Vec::with_capacity(9 + query.len()); + out.extend_from_slice(&target_partitions.to_le_bytes()); + out.extend_from_slice(&index.to_le_bytes()); + out.push(kind); + out.extend_from_slice(query); + out +} + +fn decode_descriptor(bytes: &[u8]) -> adbc_core::error::Result<(usize, u32, DescQuery)> { + if bytes.len() < 9 { + return Err(ErrorHelper::invalid_argument().message("short partition descriptor").to_adbc()); + } + let target_partitions = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize; + let index = u32::from_le_bytes(bytes[4..8].try_into().unwrap()); + let query = match bytes[8] { + 0 => DescQuery::Substrait(bytes[9..].to_vec()), + 1 => DescQuery::Sql(String::from_utf8(bytes[9..].to_vec()) + .map_err(|e| ErrorHelper::invalid_argument().message(e.to_string()).to_adbc())?), + other => return Err(ErrorHelper::invalid_argument() + .format(format_args!("unknown descriptor kind {other}")).to_adbc()), + }; + Ok((target_partitions, index, query)) +} + +// --- DataFusionReader: construct from a single-partition stream ------------ +impl DataFusionReader { + pub fn from_stream( + runtime: Arc, + stream: datafusion::execution::SendableRecordBatchStream, + schema: SchemaRef, + ) -> Self { + Self { runtime, stream, schema } + } +} + +// --- DataFusionStatement::execute_partitions ------------------------------- +fn execute_partitions(&mut self) -> adbc_core::error::Result { + let query = self.query.as_ref().ok_or_else(|| { + ErrorHelper::invalid_state().message("no query or Substrait plan has been set").to_adbc() + })?; + self.runtime.block_on(async { + let df = query.execute(&self.ctx).await?; // registers object store, plans logically + let schema = df.schema().as_arrow().clone(); + let physical = df.create_physical_plan().await.map_err(ErrorHelper::from_datafusion)?; + let n = physical.output_partitioning().partition_count() as u32; + let target_partitions = self.ctx.copied_config().target_partitions() as u32; + let (kind, query_bytes) = encode_query(query)?; + let partitions = (0..n) + .map(|i| encode_descriptor(target_partitions, i, kind, &query_bytes)) + .collect::>(); + Ok(adbc_core::PartitionedResult { partitions, schema, rows_affected: -1 }) + }) +} + +// --- DataFusionConnection::read_partition ---------------------------------- +fn read_partition( + &self, + partition: impl AsRef<[u8]>, +) -> adbc_core::error::Result> { + let (target_partitions, index, query) = decode_descriptor(partition.as_ref())?; + self.runtime.block_on(async { + // Pin target_partitions so the physical plan matches execute_partitions'. + let state = datafusion::execution::session_state::SessionStateBuilder::new_from_existing( + self.ctx.state(), + ) + .with_config(self.ctx.copied_config().with_target_partitions(target_partitions)) + .build(); + + let plan = match query { + DescQuery::Substrait(bytes) => { + let proto = Plan::decode(bytes.as_slice()) + .map_err(|e| ErrorHelper::invalid_argument().message(e.to_string()).to_adbc())?; + from_substrait_plan(&state, &proto).await.map_err(ErrorHelper::from_datafusion)? + } + DescQuery::Sql(sql) => { + state.create_logical_plan(&sql).await.map_err(ErrorHelper::from_datafusion)? + } + }; + register_object_store_for_plan(&self.ctx, &plan).await.map_err(ErrorHelper::from_datafusion)?; + let schema = plan.schema().as_arrow().clone(); + let physical = state.create_physical_plan(&plan).await.map_err(ErrorHelper::from_datafusion)?; + let stream = physical + .execute(index as usize, state.task_ctx()) + .map_err(ErrorHelper::from_datafusion)?; + Ok(Box::new(DataFusionReader::from_stream(self.runtime.clone(), stream, schema.into())) + as Box) + }) +} +``` + +## How the connector consumes it (already implemented here) + +`AdbcScanImpl.planInputPartitions` (driver side): + +1. build the pushed Substrait plan, `setSubstraitPlan`, call `executePartitioned()`; +2. on success, emit one `AdbcInputPartition` per `PartitionDescriptor` + (`descriptor == true`, payload = the descriptor bytes); +3. on `NOT_IMPLEMENTED` / `NOT_FOUND`, emit a single plan-carrying partition + (`descriptor == false`). + +`AdbcColumnarPartitionReader` (executor side): for a descriptor partition it calls +`connection.readPartition(payload)`; otherwise `setSubstraitPlan` + `executeQuery`. +Both yield an `ArrowReader` wrapped zero-copy into `ArrowColumnVector`s. + +So once this patch lands upstream, the connector lights up multi-partition with +no further change on the Java side. + +## Caveats / follow-ups + +- `Prepared` statements aren't covered (no portable plan serialization without + `datafusion-proto`); they return `NOT_IMPLEMENTED`, and the connector falls + back to single-partition. The connector only sends Substrait, so this is fine. +- `register_object_store_for_plan` is applied in `read_partition` too, so + object-store-backed providers work on the executor. +- Verify `output_partitioning().partition_count()` against the target source; for + some providers a `RepartitionExec` may be needed to get N > 1. diff --git a/examples/adbc-datafusion-driver/README.md b/examples/adbc-datafusion-driver/README.md new file mode 100644 index 0000000..1712b99 --- /dev/null +++ b/examples/adbc-datafusion-driver/README.md @@ -0,0 +1,113 @@ + + +# Example: an ADBC driver for a custom DataFusion provider + +End-to-end demonstration of the ADBC route ("Part A"): expose a custom +DataFusion `TableProvider` as a standard ADBC driver, then read it from the +`adbc-datafusion` Spark connector in this repo. + +This crate stands in for a **separate driver repo**. It is its own Cargo +workspace (note the empty `[workspace]` in `Cargo.toml`) and is *not* part of +the parent `datafusion-java` workspace. + +## What it does + +- `src/provider.rs` -- `ExampleTableProvider`, an ordinary `TableProvider` + exposing one fixed table `example(id BIGINT, name STRING)`. Replace this with + your real provider. +- `src/lib.rs` -- `ExampleDriver`, a thin wrapper over + [`adbc-driver-datafusion`](https://github.com/adbc-drivers/datafusion)'s + `DataFusionDriver`. Its `Default` builds the inner driver with a + `new_with_context_init` hook that registers the provider into every session. + `adbc_ffi::export_driver!` emits the C `AdbcDatafusionExampleInit` entrypoint. + +No fork of the driver, no `datafusion-ffi` module loading -- the provider is +compiled in. + +## Build + +```bash +cargo build --release +# produces target/release/libadbc_datafusion_example_driver.{so,dylib,dll} +``` + +## Use from the Spark connector + +The `adbc-datafusion` Spark DataSource in this repo loads the cdylib through the +arrow-adbc Java driver manager. Point the `driver` option at the built library +and the `table` option at the provider's table: + +```scala +val df = spark.read + .format("adbc-datafusion") + .option("driver", "/abs/path/to/libadbc_datafusion_example_driver.so") + .option("table", "example") + .load() + +df.show() +// +---+-----+ +// | id| name| +// +---+-----+ +// | 1|alice| +// | 2| bob| +// | 3|carol| +// +---+-----+ +``` + +`option("target_partitions", N)` tunes scan parallelism: the connector issues +`SET datafusion.execution.target_partitions = N` on the planning session, and the +driver pins it into each partition descriptor so executors re-plan identically. +It defaults to the cluster parallelism (`SparkContext.defaultParallelism`). +Repartition-aware providers (e.g. file scans) use it to choose the partition +count; fixed-partition providers (like this in-memory example) keep their +intrinsic count. + +Any other `option(...)` keys (besides `driver`, `table`, `target_partitions`) are +forwarded verbatim as ADBC database options, so a real provider can be configured +through them. + +## Use from any ADBC client + +The same cdylib works with any ADBC driver manager, e.g. Python: + +```python +import adbc_driver_manager.dbapi as dbapi + +with dbapi.connect( + driver="/abs/path/to/libadbc_datafusion_example_driver.so", + entrypoint="AdbcDatafusionExampleInit", +) as conn, conn.cursor() as cur: + cur.execute("SELECT * FROM example") + print(cur.fetch_arrow_table()) +``` + +## Notes + +- `datafusion` / `arrow` versions are pinned to match `adbc-driver-datafusion` + (datafusion 53, arrow 58) so the shared provider type is ABI-compatible across + the driver and any in-repo Rust that registers the same provider. +- Pushdown: ADBC clients send SQL or a Substrait plan; DataFusion's optimizer + applies projection/filter/limit against the provider. The Spark connector + pushes a Substrait plan with projection/filter/limit folded in. +- Multi-partition: with the `execute_partitions` / `read_partition` driver + support (adbc-drivers/datafusion#32) and arrow-adbc >= 0.24's JNI bridge, the + Spark connector reads one partition per provider partition; against older + arrow-adbc it degrades to a single partition. See + [PARTITIONS.md](PARTITIONS.md) for the driver-side code. diff --git a/examples/adbc-datafusion-driver/pyspark_e2e.py b/examples/adbc-datafusion-driver/pyspark_e2e.py new file mode 100644 index 0000000..6cbf80f --- /dev/null +++ b/examples/adbc-datafusion-driver/pyspark_e2e.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""End-to-end PySpark test of the adbc-datafusion connector + example driver. + +PySpark drives the same JVM Spark, so the connector is used exactly as from +Scala/Java: spark.read.format("adbc-datafusion"). No Python-side connector code. + +Prereqs: + - arrow-adbc `main` (0.24.0-SNAPSHOT) built into ~/.m2 (Substrait + partitions), + - the connector jar packaged (mvn -pl spark -Padbc-snapshot package), + - its runtime deps copied to a dir (mvn dependency:copy-dependencies), + - arrow-c-data (matching Spark's bundled Arrow, e.g. 18.1.0) added to that dir + -- vanilla Spark bundles arrow-vector/arrow-memory but NOT arrow-c-data, + - the example cdylib built (cargo build --release). + +Configure via env vars (defaults assume this repo's build layout): + CONNECTOR_JAR path to datafusion-spark-*.jar + ADBC_JARS_DIR dir of runtime dep jars (adbc-*, substrait core, datafusion-java, ...) + ADBC_DRIVER_LIB path to libadbc_datafusion_example_driver.{dylib,so} +""" + +import glob +import os +import sys + +from pyspark.sql import SparkSession + +REPO = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +LIBEXT = "dylib" if sys.platform == "darwin" else "so" + +connector_jar = os.environ.get( + "CONNECTOR_JAR", + glob.glob(os.path.join(REPO, "spark/target/datafusion-spark-*.jar"))[0], +) +jars_dir = os.environ.get("ADBC_JARS_DIR", "/tmp/adbc-libs") +driver_lib = os.environ.get( + "ADBC_DRIVER_LIB", + os.path.join(REPO, f"rust-target/release/libadbc_datafusion_example_driver.{LIBEXT}"), +) + +jars = [connector_jar] + glob.glob(os.path.join(jars_dir, "*.jar")) +assert os.path.exists(driver_lib), f"driver cdylib not found: {driver_lib}" + +# Java 17 opens Spark/Arrow need (Spark 4.0 sets most; be explicit for the data path). +opens = " ".join( + f"--add-opens=java.base/{p}=ALL-UNNAMED" + for p in ("java.nio", "sun.nio.ch", "java.lang", "java.util") +) + +spark = ( + SparkSession.builder.master("local[2]") + .appName("adbc-datafusion-pyspark-e2e") + .config("spark.jars", ",".join(jars)) + .config("spark.driver.extraJavaOptions", opens) + .config("spark.executor.extraJavaOptions", opens) + .getOrCreate() +) + +try: + df = ( + spark.read.format("adbc-datafusion") + .option("driver", driver_lib) + .option("entrypoint", "AdbcDatafusionExampleInit") + .option("table", "example") + .load() + ) + + print("=== schema ===") + df.printSchema() + + num_partitions = df.rdd.getNumPartitions() + print("numPartitions:", num_partitions) + + ids = sorted(r["id"] for r in df.collect()) + names = sorted(r["name"] for r in df.select("name").collect()) + filtered = sorted(r["id"] for r in df.filter("id > 1").collect()) + print("full scan ids:", ids) + print("projection names:", names) + print("filter id>1:", filtered) + + assert ids == [1, 2, 3], ids + assert names == ["alice", "bob", "carol"], names + assert filtered == [2, 3], filtered + assert num_partitions >= 2, f"expected multi-partition, got {num_partitions}" + + print("\nPYSPARK E2E OK (multi-partition + projection + filter pushdown)") +finally: + spark.stop() diff --git a/examples/adbc-datafusion-driver/src/lib.rs b/examples/adbc-datafusion-driver/src/lib.rs new file mode 100644 index 0000000..f57556d --- /dev/null +++ b/examples/adbc-datafusion-driver/src/lib.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An end-to-end ADBC driver for a specific, compiled-in DataFusion provider. +//! +//! This is "Part A" of the ADBC route: rather than fork [`adbc-driver-datafusion`] +//! or load providers over `datafusion-ffi`, we depend on it as a library and use +//! its [`DataFusionDriver::new_with_context_init`] hook to register our own +//! [`ExampleTableProvider`] into every database's `SessionContext`. The provider +//! is statically linked; ADBC clients (the arrow-adbc Java driver manager, the +//! Spark connector, Python `adbc_driver_manager`, ...) then query it by SQL or +//! Substrait. +//! +//! [`export_driver!`] emits the C `AdbcDriverInit` entrypoint a driver manager +//! dlopen's. Build the cdylib (`cargo build --release`) and point an ADBC client +//! at the resulting shared library. + +mod provider; + +use std::sync::Arc; + +use adbc_core::error::Result; +use adbc_core::options::{OptionDatabase, OptionValue}; +use adbc_core::Driver; +use adbc_driver_datafusion::{ContextInit, DataFusionDatabase, DataFusionDriver}; +use datafusion::prelude::SessionContext; + +pub use provider::ExampleTableProvider; + +/// An ADBC driver that registers [`ExampleTableProvider`] into each session. +/// +/// A thin wrapper over [`DataFusionDriver`]: it exists only so that `Default` +/// (which [`export_driver!`] calls) constructs the inner driver with our +/// `ContextInit` hook. All `Driver` behavior is delegated. +pub struct ExampleDriver(DataFusionDriver); + +impl Default for ExampleDriver { + fn default() -> Self { + let init: ContextInit = Arc::new(|ctx: &mut SessionContext, _opts| { + // Register our provider. A real driver would read connection + // options out of `_opts` (DatabaseOpts) to decide what to expose. + ctx.register_table( + ExampleTableProvider::TABLE_NAME, + Arc::new(ExampleTableProvider::new()), + )?; + Ok(()) + }); + ExampleDriver(DataFusionDriver::new_with_context_init(None, init)) + } +} + +impl Driver for ExampleDriver { + type DatabaseType = DataFusionDatabase; + + fn new_database(&mut self) -> Result { + self.0.new_database() + } + + fn new_database_with_opts( + &mut self, + opts: impl IntoIterator, + ) -> Result { + self.0.new_database_with_opts(opts) + } +} + +// Export the C ADBC entrypoint. The arrow-adbc driver manager resolves this +// symbol after dlopen'ing the cdylib. +adbc_ffi::export_driver!(AdbcDatafusionExampleInit, ExampleDriver); diff --git a/examples/adbc-datafusion-driver/src/provider.rs b/examples/adbc-datafusion-driver/src/provider.rs new file mode 100644 index 0000000..175b59a --- /dev/null +++ b/examples/adbc-datafusion-driver/src/provider.rs @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The "custom" `TableProvider` this example exposes. +//! +//! In a real deployment this is *your* provider crate -- reading your store, +//! your format, your catalog. Here it is a tiny fixed table so the example is +//! self-contained; the only thing that matters for the integration is that it +//! is an ordinary [`datafusion::catalog::TableProvider`]. Execution is delegated +//! to an in-memory [`MemTable`]; swap [`ExampleTableProvider::scan`] for your +//! real source and the rest of the pipeline (ADBC driver, Spark connector) is +//! unchanged. + +use std::any::Any; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::array::{Int64Array, StringArray}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::datasource::MemTable; +use datafusion::error::Result; +use datafusion::logical_expr::TableType; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::Expr; + +/// A custom provider exposing a single fixed table named [`Self::TABLE_NAME`]. +#[derive(Debug)] +pub struct ExampleTableProvider { + inner: MemTable, +} + +impl ExampleTableProvider { + /// The name the provider is registered under; the ADBC client scans this. + pub const TABLE_NAME: &'static str = "example"; + + pub fn new() -> Self { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + // Three partitions (one row each) so the table has a non-trivial output + // partitioning -- this is what makes ADBC execute_partitions / + // read_partition return more than one partition, exercising distributed + // reads. + let rows = [(1_i64, "alice"), (2, "bob"), (3, "carol")]; + let partitions: Vec> = rows + .iter() + .map(|(id, name)| { + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![*id])), + Arc::new(StringArray::from(vec![*name])), + ], + ) + .expect("example record batch"); + vec![batch] + }) + .collect(); + let inner = MemTable::try_new(schema, partitions).expect("example in-memory table"); + Self { inner } + } +} + +impl Default for ExampleTableProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TableProvider for ExampleTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.inner.schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + // A real provider builds its own ExecutionPlan here, honoring the + // pushed projection / filters / limit. We delegate to MemTable. + self.inner.scan(state, projection, filters, limit).await + } +} diff --git a/examples/adbc-datafusion-driver/tests/partitions.rs b/examples/adbc-datafusion-driver/tests/partitions.rs new file mode 100644 index 0000000..747e515 --- /dev/null +++ b/examples/adbc-datafusion-driver/tests/partitions.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Full system test of ADBC partitioned execution against the example driver. +//! +//! Exercises the upstream `execute_partitions` / `read_partition` (PR +//! adbc-drivers/datafusion#32, wired in via the [patch] in Cargo.toml) end to +//! end with our custom provider: plan a query, get one descriptor per output +//! partition, then read each partition on its own and assert the union is the +//! whole table -- the same contract the Spark connector relies on. + +use adbc_core::{Connection, Database, Driver, Statement}; +use adbc_datafusion_example_driver::{ExampleDriver, ExampleTableProvider}; +use datafusion::arrow::array::{Array, Int64Array}; + +#[test] +fn execute_partitions_then_read_each_partition() { + let mut driver = ExampleDriver::default(); + let db = driver.new_database().expect("new_database"); + let mut conn = db.new_connection().expect("new_connection"); + + // Plan the scan; the provider has three partitions. + let mut stmt = conn.new_statement().expect("new_statement"); + stmt.set_sql_query(format!("SELECT id, name FROM {}", ExampleTableProvider::TABLE_NAME)) + .expect("set_sql_query"); + let result = stmt.execute_partitions().expect("execute_partitions"); + + assert!( + result.partitions.len() >= 2, + "expected multiple partitions, got {}", + result.partitions.len() + ); + + // Read each partition independently (as a separate executor would) and + // collect the ids; the union must be exactly the whole table, once each. + let mut ids = Vec::new(); + for descriptor in &result.partitions { + let reader = conn.read_partition(descriptor).expect("read_partition"); + for batch in reader { + let batch = batch.expect("batch"); + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("id column is Int64"); + for i in 0..column.len() { + ids.push(column.value(i)); + } + } + } + + ids.sort(); + assert_eq!(ids, vec![1, 2, 3], "every row read exactly once across partitions"); +} diff --git a/pom.xml b/pom.xml index 7ceec07..a48be6c 100644 --- a/pom.xml +++ b/pom.xml @@ -33,6 +33,7 @@ under the License. core examples + spark diff --git a/spark/pom.xml b/spark/pom.xml new file mode 100644 index 0000000..cf6e3dc --- /dev/null +++ b/spark/pom.xml @@ -0,0 +1,182 @@ + + + + 4.0.0 + + + org.apache.datafusion + datafusion-java-parent + 0.2.0-SNAPSHOT + + + datafusion-spark + DataFusion Spark DataSource + A Spark DataSourceV2 backed by a DataFusion TableProvider, over either the plain-C scan ABI or an ADBC driver. + + + 4.0.0 + 2.13 + + 18.1.0 + + + 0.23.0 + + 0.57.0 + + + + + + org.apache.datafusion + datafusion-java + ${project.version} + + + org.apache.arrow + * + + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + + + org.apache.arrow + arrow-c-data + ${spark.arrow.version} + provided + + + + + org.apache.arrow.adbc + adbc-core + ${adbc.version} + + + org.apache.arrow + * + + + + + org.apache.arrow.adbc + adbc-driver-jni + ${adbc.version} + + + org.apache.arrow + * + + + + + + + io.substrait + core + ${substrait.version} + + + + + org.junit.jupiter + junit-jupiter + test + + + + org.mockito + mockito-core + 5.14.2 + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.lang.invoke=ALL-UNNAMED + --add-opens=java.base/java.io=ALL-UNNAMED + --add-opens=java.base/java.net=ALL-UNNAMED + --add-opens=java.base/java.nio=ALL-UNNAMED + --add-opens=java.base/java.util=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED + --add-opens=java.base/sun.security.action=ALL-UNNAMED + + + + + + + + + + adbc-snapshot + + 0.24.0-SNAPSHOT + + + + diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java new file mode 100644 index 0000000..530c47c --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * Reads one ADBC scan partition as Spark {@link ColumnarBatch}es, zero-copy. + * + *

This executor borrows a cached database from {@link AdbcConnectionPool} (cached per executor + * so the DataFusion cdylib is loaded and its providers registered once per executor, not once per + * task) and opens its own per-task connection off it, then obtains an {@link ArrowReader} one of + * two ways depending on how the scan was planned (see {@link AdbcInputPartition}): + * + *

    + *
  • partition descriptor -> {@code AdbcConnection.readPartition(descriptor)} (multi-partition); + *
  • Substrait plan -> {@code setSubstraitPlan} + {@code executeQuery} (single-partition + * fallback). + *
+ * + *

The imported Arrow vectors are wrapped directly in Spark {@link ArrowColumnVector}s -- no + * per-cell copy -- which requires the executor JVM to have a single arrow-java (the cluster's Spark + * Arrow), shared by the ADBC driver manager and Spark. + * + *

Lifecycle: the Arrow vectors are owned by the {@link ArrowReader}. We do not close the {@link + * ColumnarBatch} (which would double-free the vectors); {@link #close()} closes the per-task ADBC + * handles and releases the {@link AdbcConnectionPool.Lease}. The cached database and its root + * allocator are owned by the pool and outlive this reader. + */ +final class AdbcColumnarPartitionReader implements PartitionReader { + + private final AdbcConnectionPool.Lease lease; + // Non-null only on the executeQuery (single-partition fallback) path. + private final AdbcStatement statement; + private final AdbcStatement.QueryResult queryResult; + private final ArrowReader reader; + private final VectorSchemaRoot root; + private final ColumnarBatch batch; + + AdbcColumnarPartitionReader(AdbcInputPartition partition) { + AdbcConnectionPool.Lease ls = null; + AdbcStatement stmt = null; + AdbcStatement.QueryResult result = null; + try { + ls = AdbcConnectionPool.acquire(partition.options); + AdbcConnection conn = ls.connection(); + + ArrowReader r; + switch (partition.kind) { + case DESCRIPTOR -> { + // Multi-partition path: read one opaque partition descriptor. + r = conn.readPartition(directBuffer(partition.payload)); + } + case SUBSTRAIT -> { + stmt = conn.createStatement(); + stmt.setSubstraitPlan(directBuffer(partition.payload)); + result = stmt.executeQuery(); + r = result.getReader(); + } + case SQL -> { + stmt = conn.createStatement(); + stmt.setSqlQuery(new String(partition.payload, java.nio.charset.StandardCharsets.UTF_8)); + result = stmt.executeQuery(); + r = result.getReader(); + } + default -> throw new IllegalStateException("unknown partition kind " + partition.kind); + } + + this.lease = ls; + this.statement = stmt; + this.queryResult = result; + this.reader = r; + this.root = reader.getVectorSchemaRoot(); + this.batch = new ColumnarBatch(wrap(root)); + } catch (Exception e) { + try { + // Close only task-owned handles plus the lease; the lease leaves the cached + // database/connection alone. + AutoCloseables.close(result, stmt, ls); + } catch (Exception suppressed) { + e.addSuppressed(suppressed); + } + throw new RuntimeException("failed to open ADBC scan partition", e); + } + } + + /** Wrap each Arrow vector of the (reused) root as a Spark column vector, once. */ + private static ColumnVector[] wrap(VectorSchemaRoot root) { + ColumnVector[] columns = new ColumnVector[root.getFieldVectors().size()]; + int i = 0; + for (FieldVector vector : root.getFieldVectors()) { + columns[i++] = new ArrowColumnVector(vector); + } + return columns; + } + + private static ByteBuffer directBuffer(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length); + buffer.put(bytes).flip(); + return buffer; + } + + @Override + public boolean next() throws IOException { + // The root's vectors are reloaded in place each batch; skip empty batches. + while (reader.loadNextBatch()) { + int rows = root.getRowCount(); + if (rows > 0) { + batch.setNumRows(rows); + return true; + } + } + return false; + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + try { + // On the executeQuery path the QueryResult owns the reader, so close it instead + // of the reader; on the readPartition path close the reader directly. The lease + // releases the per-task child allocator and the per-task connection; the cached + // database and root allocator are owned by the pool. + AutoCloseable readerHandle = queryResult != null ? queryResult : reader; + AutoCloseables.close(readerHandle, statement, lease); + } catch (Exception e) { + throw new IOException("failed to close ADBC scan partition", e); + } + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java new file mode 100644 index 0000000..4574cfe --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; + +/** + * Per-executor-JVM cache of the expensive native ADBC objects, so the task slots on one executor do + * not each repeat driver load, provider registration, and session setup against the same cdylib. + * + *

A Spark executor runs many tasks (one per slot) in one long-lived JVM. Each task's {@link + * AdbcColumnarPartitionReader} needs an {@link AdbcConnection}; opening one per task re-runs the + * native session construction and -- the costly part -- the driver's {@code ContextInit} provider + * registration, which the DataFusion ADBC driver runs once per database. This holder keeps + * one {@link CachedDatabase} per {@link Key} (the inputs that determine the native database; see + * {@link AdbcOptions#cacheKey()}) for the life of the JVM, and each task opens its own short-lived + * connection off that shared database. + * + *

Connections are not shared across tasks. The arrow-adbc Rust FFI exporter + * reaches the inner connection on every C call via a {@code &mut} to the shared object with no + * lock, so two task threads calling into one connection concurrently would alias {@code &mut} -- + * undefined behavior, regardless of the driver being internally {@code Sync}. Caching the database + * (where {@code ContextInit} runs) already removes the per-task cost, so the connection stays + * per-task. + * + *

Cached handles outlive every task and are closed only by a JVM shutdown hook (executors are + * long-lived). A per-task {@link Lease} hands out a task-owned connection plus a child allocator + * and, on {@link Lease#close()}, releases only those task-owned handles -- never the cached + * database or its root allocator. + */ +final class AdbcConnectionPool { + + private static final ConcurrentHashMap CACHE = new ConcurrentHashMap<>(); + private static final AtomicBoolean HOOK_INSTALLED = new AtomicBoolean(false); + + // Observability counters: how many native databases were built and how many per-task connections + // were opened over the JVM's lifetime. The point of the pool is that the first stays at + // one-per-executor (provider registration runs once); the second scales with tasks. Read by the + // driver-gated E2E benchmark. + private static final AtomicInteger NATIVE_DATABASES = new AtomicInteger(); + private static final AtomicInteger TASK_CONNECTIONS = new AtomicInteger(); + + // Indirection so unit/concurrency tests can supply cheap doubles in place of native handles. + private static volatile DatabaseBuilder builder = AdbcConnectionPool::buildNative; + + private AdbcConnectionPool() {} + + /** + * Identity of a native ADBC database: the driver path/name plus the database-option map, compared + * order-insensitively (map order does not change the native database). Derived from {@link + * AdbcOptions} on the executor with no new wire fields -- both inputs already ride along inside + * the serializable {@link AdbcOptions}. + */ + static final class Key { + private final String driver; + private final List> sortedOptions; + + Key(String driver, Map databaseOptions) { + this.driver = Objects.requireNonNull(driver, "driver"); + List> entries = new ArrayList<>(databaseOptions.size()); + for (Map.Entry e : databaseOptions.entrySet()) { + entries.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue())); + } + entries.sort(Map.Entry.comparingByKey()); + this.sortedOptions = List.copyOf(entries); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Key)) { + return false; + } + Key k = (Key) o; + return driver.equals(k.driver) && sortedOptions.equals(k.sortedOptions); + } + + @Override + public int hashCode() { + return Objects.hash(driver, sortedOptions); + } + + @Override + public String toString() { + return "AdbcConnectionPool.Key[driver=" + driver + ", options=" + sortedOptions + "]"; + } + } + + /** One cached native database plus its root allocator. */ + static final class CachedDatabase implements AutoCloseable { + final BufferAllocator rootAllocator; + final AdbcDatabase database; + final AtomicInteger activeLeases = new AtomicInteger(); + volatile long lastAccessNanos; + + CachedDatabase(BufferAllocator rootAllocator, AdbcDatabase database) { + this.rootAllocator = rootAllocator; + this.database = database; + } + + @Override + public void close() throws Exception { + // Close order: database -> allocator (the allocator must outlive the handles opened against + // it). + AutoCloseables.close(database, rootAllocator); + } + } + + /** Builds the native objects for a key; replaceable in tests. */ + @FunctionalInterface + interface DatabaseBuilder { + CachedDatabase build(Key key, AdbcOptions options) throws AdbcException; + } + + /** + * A task's borrowed handle on a {@link CachedDatabase}. Carries the task-owned connection the + * reader uses and a per-task child allocator. {@link #close()} releases only what the task owns; + * the cached database and root allocator are not touched. + */ + static final class Lease implements AutoCloseable { + private final CachedDatabase shared; + private final BufferAllocator taskAllocator; + private final AdbcConnection connection; + + Lease(CachedDatabase shared, BufferAllocator taskAllocator, AdbcConnection connection) { + this.shared = shared; + this.taskAllocator = taskAllocator; + this.connection = connection; + } + + AdbcConnection connection() { + return connection; + } + + BufferAllocator allocator() { + return taskAllocator; + } + + @Override + public void close() throws Exception { + try { + // Close the task-owned connection and allocator. The cached database and root allocator are + // never closed here -- only by the JVM shutdown hook. + AutoCloseables.close(connection, taskAllocator); + } finally { + shared.activeLeases.decrementAndGet(); + } + } + } + + /** Fetch-or-create the cached database for these options and return a per-task lease. */ + static Lease acquire(AdbcOptions options) throws AdbcException { + installShutdownHookOnce(); + Key key = options.cacheKey(); + CachedDatabase shared = getOrCreate(key, options); + shared.activeLeases.incrementAndGet(); + shared.lastAccessNanos = System.nanoTime(); + + BufferAllocator taskAllocator = null; + try { + // A child of the shared root: per-task accounting/isolation; closed when the lease closes. + taskAllocator = shared.rootAllocator.newChildAllocator("adbc-task", 0, Long.MAX_VALUE); + AdbcConnection conn = shared.database.connect(); + TASK_CONNECTIONS.incrementAndGet(); + return new Lease(shared, taskAllocator, conn); + } catch (Exception e) { + shared.activeLeases.decrementAndGet(); + if (taskAllocator != null) { + try { + taskAllocator.close(); + } catch (Exception suppressed) { + e.addSuppressed(suppressed); + } + } + if (e instanceof AdbcException) { + throw (AdbcException) e; + } + throw new RuntimeException("failed to acquire ADBC connection lease", e); + } + } + + private static CachedDatabase getOrCreate(Key key, AdbcOptions options) throws AdbcException { + try { + // computeIfAbsent runs the mapping function at most once per absent key; concurrent callers + // on the same key block until it completes, so exactly one database is built and ContextInit + // runs once. AdbcException is checked and cannot escape the mapping Function, so wrap it and + // unwrap below; on failure nothing is inserted, so the next caller retries cleanly. + return CACHE.computeIfAbsent( + key, + k -> { + try { + return builder.build(k, options); + } catch (AdbcException e) { + throw new UncheckedAdbcException(e); + } + }); + } catch (UncheckedAdbcException e) { + throw e.cause(); + } + } + + private static CachedDatabase buildNative(Key key, AdbcOptions options) throws AdbcException { + BufferAllocator root = new RootAllocator(); + AdbcDatabase db = null; + try { + db = new JniDriver(root).open(options.driverParameters()); + NATIVE_DATABASES.incrementAndGet(); + return new CachedDatabase(root, db); + } catch (Exception e) { + try { + AutoCloseables.close(db, root); + } catch (Exception suppressed) { + e.addSuppressed(suppressed); + } + if (e instanceof AdbcException) { + throw (AdbcException) e; + } + throw new RuntimeException("failed to open ADBC database", e); + } + } + + private static void installShutdownHookOnce() { + if (HOOK_INSTALLED.compareAndSet(false, true)) { + Runtime.getRuntime() + .addShutdownHook(new Thread(AdbcConnectionPool::closeAll, "adbc-pool-shutdown")); + } + } + + /** Remove and close every cached database. Run by the shutdown hook; best-effort. */ + static void closeAll() { + for (Key key : new ArrayList<>(CACHE.keySet())) { + CachedDatabase d = CACHE.remove(key); + if (d != null) { + try { + d.close(); + } catch (Exception e) { + // Best effort: the JVM is going away. Nothing useful to do with the exception here. + } + } + } + } + + /** Unchecked carrier so a checked {@link AdbcException} can cross the computeIfAbsent lambda. */ + private static final class UncheckedAdbcException extends RuntimeException { + private static final long serialVersionUID = 1L; + + UncheckedAdbcException(AdbcException cause) { + super(cause); + } + + AdbcException cause() { + return (AdbcException) getCause(); + } + } + + // ---- Test hooks (package-private) ---- + + static void setBuilderForTesting(DatabaseBuilder b) { + builder = b; + } + + static int cacheSizeForTesting() { + return CACHE.size(); + } + + /** Native databases built over the JVM lifetime; one per executor when the pool works. */ + static int databasesBuiltForTesting() { + return NATIVE_DATABASES.get(); + } + + /** Per-task connections opened over the JVM lifetime. */ + static int taskConnectionsOpenedForTesting() { + return TASK_CONNECTIONS.get(); + } + + /** Close and clear all cached state and restore production defaults. For test isolation. */ + static void resetForTesting() { + closeAll(); + builder = AdbcConnectionPool::buildNative; + NATIVE_DATABASES.set(0); + TASK_CONNECTIONS.set(0); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java new file mode 100644 index 0000000..bec7db8 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.Map; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Entry point for the {@code adbc-datafusion} Spark data source. + * + *

The native boundary is a standard ADBC driver: the arrow-adbc Java driver manager loading a + * DataFusion ADBC cdylib. {@code spark.read.format("adbc-datafusion").option("driver", + * ...).option("table", ...).load()} resolves here; the schema is probed once, on the driver, via + * {@link AdbcConnection#getTableSchema}. + */ +public final class AdbcDatafusionTableProvider implements TableProvider, DataSourceRegister { + + @Override + public String shortName() { + return "adbc-datafusion"; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + AdbcOptions opts = AdbcOptions.fromOptions(options); + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase db = new JniDriver(allocator).open(opts.driverParameters()); + AdbcConnection conn = db.connect()) { + Schema arrow = conn.getTableSchema(null, null, opts.table()); + return SchemaConverter.toSparkSchema(arrow); + } catch (Exception e) { + throw new RuntimeException("failed to probe ADBC schema for table " + opts.table(), e); + } + } + + @Override + public Table getTable( + StructType schema, Transform[] partitioning, Map properties) { + AdbcOptions opts = AdbcOptions.fromOptions(new CaseInsensitiveStringMap(properties)); + return new AdbcTable(schema, opts); + } + + @Override + public boolean supportsExternalMetadata() { + return false; + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java new file mode 100644 index 0000000..ed880dd --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import org.apache.spark.sql.connector.read.InputPartition; + +/** + * A serializable slice of an ADBC scan shipped to an executor. Carries only the connector options + * and an opaque payload -- never a native handle or connection, which are meaningless in another + * process. Each executor reopens its own ADBC database/connection from {@code options}. + * + *

The {@code kind} says how the executor turns {@code payload} into an {@code ArrowReader}: + * + *

    + *
  • {@link Kind#DESCRIPTOR}: {@code payload} is an ADBC partition descriptor from {@code + * executePartitioned()}; the executor runs {@code AdbcConnection.readPartition(payload)} (the + * multi-partition path, one partition per descriptor). + *
  • {@link Kind#SUBSTRAIT}: {@code payload} is a serialized Substrait plan; the executor runs + * {@code setSubstraitPlan} + {@code executeQuery} (single partition). + *
  • {@link Kind#SQL}: {@code payload} is UTF-8 SQL; the executor runs {@code setSqlQuery} + + * {@code executeQuery} (single partition; used when the driver/JNI lacks Substrait support). + *
+ */ +final class AdbcInputPartition implements InputPartition { + + private static final long serialVersionUID = 1L; + + enum Kind { + DESCRIPTOR, + SUBSTRAIT, + SQL + } + + final AdbcOptions options; + final byte[] payload; + final Kind kind; + + AdbcInputPartition(AdbcOptions options, byte[] payload, Kind kind) { + this.options = options; + this.payload = payload; + this.kind = kind; + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java new file mode 100644 index 0000000..8251b03 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.io.Serializable; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.OptionalInt; + +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Connector options for the {@code adbc-datafusion} data source, decoded from the Spark options + * map. + * + *

Serializable so it can ride along inside an {@link AdbcInputPartition} to each executor, which + * reopens its own ADBC database/connection. Three keys are connector-control; everything else is + * forwarded verbatim as native ADBC database options (so provider-specific options pass straight + * through to the registered {@code TableProvider}). + * + *

    + *
  • {@code driver} (required) -- path to, or name of, the native ADBC driver shared library + * (our DataFusion ADBC cdylib). Resolved by the C driver manager. + *
  • {@code manifest.path} (optional) -- extra search path for ADBC driver manifests, when the + * driver is given by manifest name rather than an absolute path. + *
  • {@code table} (required) -- the table name the registered provider exposes; becomes the + * Substrait {@code NamedScan} target. + *
+ */ +final class AdbcOptions implements Serializable { + + private static final long serialVersionUID = 1L; + + static final String DRIVER = "driver"; + static final String TABLE = "table"; + static final String TARGET_PARTITIONS = "target_partitions"; + + private final String driver; + private final String table; + // DataFusion target_partitions, or null to default to the cluster parallelism. + // Stored as a (Serializable) Integer rather than OptionalInt -- this object rides + // to executors inside AdbcInputPartition, and OptionalInt is not Serializable. + private final Integer targetPartitions; + // Provider-specific / passthrough ADBC database options. + private final LinkedHashMap databaseOptions; + + private AdbcOptions( + String driver, + String table, + Integer targetPartitions, + LinkedHashMap databaseOptions) { + this.driver = driver; + this.table = table; + this.targetPartitions = targetPartitions; + this.databaseOptions = databaseOptions; + } + + static AdbcOptions fromOptions(CaseInsensitiveStringMap options) { + String driver = require(options, DRIVER); + String table = require(options, TABLE); + Integer targetPartitions = + parsePositiveInt(options, TARGET_PARTITIONS).isPresent() + ? parsePositiveInt(options, TARGET_PARTITIONS).getAsInt() + : null; + + LinkedHashMap passthrough = new LinkedHashMap<>(); + for (Map.Entry entry : options.entrySet()) { + String key = entry.getKey(); + if (key.equalsIgnoreCase(DRIVER) + || key.equalsIgnoreCase(TABLE) + || key.equalsIgnoreCase(TARGET_PARTITIONS)) { + continue; + } + passthrough.put(key, entry.getValue()); + } + return new AdbcOptions(driver, table, targetPartitions, passthrough); + } + + String table() { + return table; + } + + /** Explicit DataFusion target_partitions, or empty to use the cluster parallelism. */ + OptionalInt targetPartitions() { + return targetPartitions == null ? OptionalInt.empty() : OptionalInt.of(targetPartitions); + } + + /** + * Build the parameter map for {@link JniDriver#open}. All values must be strings; the {@code + * jni.driver} key (set via {@link JniDriver#PARAM_DRIVER}) is consumed by the driver manager to + * locate the shared library, the rest become native database options. + */ + Map driverParameters() { + Map params = new LinkedHashMap<>(); + JniDriver.PARAM_DRIVER.set(params, driver); + params.putAll(databaseOptions); + return params; + } + + /** + * Identity of the native ADBC database these options open. Two option sets that produce the same + * {@link #driverParameters()} share a key, so {@link AdbcConnectionPool} caches one native + * database/connection per executor JVM for them. Derived from exactly the inputs {@code + * driverParameters()} consumes -- the {@code driver} path/name plus the passthrough database + * options -- and must stay adjacent to it so the two cannot drift. {@code table} and {@code + * target_partitions} are connector-control (not forwarded to the native database) and so are + * deliberately excluded. + */ + AdbcConnectionPool.Key cacheKey() { + return new AdbcConnectionPool.Key(driver, databaseOptions); + } + + private static String require(CaseInsensitiveStringMap options, String key) { + String value = options.getOrDefault(key, null); + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException( + "the 'adbc-datafusion' data source requires the '" + key + "' option"); + } + return value; + } + + private static OptionalInt parsePositiveInt(CaseInsensitiveStringMap options, String key) { + String value = options.getOrDefault(key, null); + if (value == null || value.isEmpty()) { + return OptionalInt.empty(); + } + int parsed; + try { + parsed = Integer.parseInt(value.trim()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("'" + key + "' must be an integer, got: " + value, e); + } + if (parsed < 1) { + throw new IllegalArgumentException("'" + key + "' must be >= 1, got: " + parsed); + } + return OptionalInt.of(parsed); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java new file mode 100644 index 0000000..eb1bf67 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * Creates a columnar ADBC reader per partition. Serialized to executors, so it holds no state. + * + *

Reads are columnar: {@link #supportColumnarReads} returns true, so Spark consumes Arrow + * buffers directly via {@link AdbcColumnarPartitionReader}. The row reader is unsupported. + */ +final class AdbcPartitionReaderFactory implements PartitionReaderFactory { + + private static final long serialVersionUID = 1L; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader createColumnarReader(InputPartition partition) { + return new AdbcColumnarPartitionReader((AdbcInputPartition) partition); + } + + @Override + public PartitionReader createReader(InputPartition partition) { + throw new UnsupportedOperationException("adbc-datafusion source reads are columnar"); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java new file mode 100644 index 0000000..b60e6ac --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.OptionalLong; +import java.util.Set; + +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.connector.read.SupportsPushDownLimit; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; + +/** + * Builds an {@link AdbcScanImpl}, pushing projection / filter / limit down as Substrait. + * + *

Pushdown decision: a filter is pushed only if {@link SubstraitPlan#canPush} maps it + * to a standard-catalog Substrait predicate over a known column; the rest are returned to Spark. + * Projection becomes a Substrait emit. Limit is pushed only when no filters were left to Spark -- + * otherwise Spark must still filter after the scan, so a scan-side limit would be wrong. + */ +final class AdbcScanBuilder + implements ScanBuilder, + SupportsPushDownRequiredColumns, + SupportsPushDownFilters, + SupportsPushDownLimit { + + private final StructType fullSchema; + private final AdbcOptions options; + + private StructType requiredSchema; + private List pushedFilters = new ArrayList<>(); + private boolean hasResidualFilters = false; + private OptionalLong limit = OptionalLong.empty(); + + AdbcScanBuilder(StructType schema, AdbcOptions options) { + this.fullSchema = schema; + this.requiredSchema = schema; + this.options = options; + } + + @Override + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + Set columns = Set.of(fullSchema.fieldNames()); + List pushable = new ArrayList<>(); + List residual = new ArrayList<>(); + for (Filter f : filters) { + if (SubstraitPlan.canPush(f, columns)) { + pushable.add(f); + } else { + residual.add(f); + } + } + this.pushedFilters = pushable; + this.hasResidualFilters = !residual.isEmpty(); + return residual.toArray(new Filter[0]); + } + + @Override + public Filter[] pushedFilters() { + return pushedFilters.toArray(new Filter[0]); + } + + @Override + public boolean pushLimit(int limit) { + // Only safe to push when Spark has no filters left to apply after the scan; + // otherwise the limit would be applied before that filtering. + if (hasResidualFilters) { + return false; + } + this.limit = OptionalLong.of(limit); + return true; + } + + @Override + public Scan build() { + List projection = + Arrays.equals(requiredSchema.fieldNames(), fullSchema.fieldNames()) + ? null + : Arrays.asList(requiredSchema.fieldNames()); + return new AdbcScanImpl(requiredSchema, options, projection, pushedFilters, limit); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java new file mode 100644 index 0000000..6c4a4cf --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.OptionalLong; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.core.PartitionDescriptor; +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.datafusion.spark.AdbcInputPartition.Kind; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; + +/** + * A planned ADBC-backed scan as a Spark {@link Scan}/{@link Batch}. + * + *

{@link #planInputPartitions()} runs on the driver and probes the driver's capabilities, since + * the arrow-adbc JNI manager and the underlying driver implement different subsets of ADBC: + * + *

    + *
  1. Pick the query wire: prefer a typed Substrait plan; if the driver reports {@code + * NOT_IMPLEMENTED} for {@code setSubstraitPlan}, fall back to a SQL string. + *
  2. Pick the partitioning: try {@code executePartitioned()} for one descriptor per output + * partition; if {@code NOT_IMPLEMENTED}, emit a single partition carrying the chosen query. + *
+ * + * No native handle ever crosses the wire -- each {@link AdbcInputPartition} carries only opaque + * bytes, and each executor reopens its own ADBC connection. + */ +final class AdbcScanImpl implements Scan, Batch { + + private final StructType readSchema; + private final AdbcOptions options; + // null projection means all columns; column names in output order otherwise. + private final List projection; + private final List pushedFilters; + private final OptionalLong limit; + + AdbcScanImpl( + StructType readSchema, + AdbcOptions options, + List projection, + List pushedFilters, + OptionalLong limit) { + this.readSchema = readSchema; + this.options = options; + this.projection = projection; + this.pushedFilters = pushedFilters; + this.limit = limit; + } + + @Override + public StructType readSchema() { + return readSchema; + } + + @Override + public Batch toBatch() { + return this; + } + + @Override + public InputPartition[] planInputPartitions() { + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase db = new JniDriver(allocator).open(options.driverParameters()); + AdbcConnection conn = db.connect()) { + // Set DataFusion's target_partitions on the planning session. execute_partitions + // pins it into each descriptor, so executors re-plan into the same partitioning. + // Defaults to the cluster's parallelism (total executor cores). Repartition-aware + // providers (e.g. file scans) use this to choose N; fixed-partition providers + // (e.g. an in-memory table) keep their intrinsic partition count. + int targetPartitions = options.targetPartitions().orElseGet(AdbcScanImpl::clusterParallelism); + applyTargetPartitions(conn, targetPartitions); + + Schema arrow = conn.getTableSchema(null, null, options.table()); + byte[] substrait = + SubstraitPlan.build(options.table(), arrow, projection, pushedFilters, limit); + String sql = SqlQuery.build(options.table(), projection, pushedFilters, limit); + return plan(conn, substrait, sql); + } catch (Exception e) { + throw new RuntimeException("failed to plan ADBC scan for table " + options.table(), e); + } + } + + private static void applyTargetPartitions(AdbcConnection conn, int targetPartitions) + throws Exception { + try (AdbcStatement stmt = conn.createStatement()) { + stmt.setSqlQuery("SET datafusion.execution.target_partitions = " + targetPartitions); + stmt.executeUpdate(); + } + } + + /** Total executor cores via the active SparkSession; falls back to local cores. */ + private static int clusterParallelism() { + try { + return SparkSession.active().sparkContext().defaultParallelism(); + } catch (Throwable t) { + return Runtime.getRuntime().availableProcessors(); + } + } + + /** Probe the query wire and partitioning, returning the input partitions. */ + private InputPartition[] plan(AdbcConnection conn, byte[] substrait, String sql) + throws Exception { + Kind singleKind; + byte[] singlePayload; + + // Escape hatch: force the SQL wire (e.g. when the Substrait round-trip plans + // to fewer partitions than SQL). Defaults to preferring Substrait. + boolean forceSql = "sql".equalsIgnoreCase(System.getProperty("adbc.wire", "")); + + AdbcStatement stmt = conn.createStatement(); + try { + if (forceSql) { + stmt.setSqlQuery(sql); + singleKind = Kind.SQL; + singlePayload = sql.getBytes(StandardCharsets.UTF_8); + } else { + try { + stmt.setSubstraitPlan(directBuffer(substrait)); + singleKind = Kind.SUBSTRAIT; + singlePayload = substrait; + } catch (AdbcException e) { + if (!notImplemented(e)) { + throw e; + } + // Driver/JNI without Substrait support -> SQL. + AutoCloseables.close(stmt); + stmt = conn.createStatement(); + stmt.setSqlQuery(sql); + singleKind = Kind.SQL; + singlePayload = sql.getBytes(StandardCharsets.UTF_8); + } + } + + List descriptors = tryPartition(stmt); + if (descriptors != null) { + return descriptors.toArray(new InputPartition[0]); + } + } finally { + AutoCloseables.close(stmt); + } + // Single-partition fallback carrying the chosen query. + return new InputPartition[] {new AdbcInputPartition(options, singlePayload, singleKind)}; + } + + /** + * Attempt ADBC partitioned execution on the (already-query-set) statement. Returns one partition + * per descriptor, or {@code null} if the driver does not implement partitioning. + */ + private List tryPartition(AdbcStatement stmt) throws Exception { + try { + AdbcStatement.PartitionResult result = stmt.executePartitioned(); + List descriptors = result.getPartitionDescriptors(); + if (descriptors.isEmpty()) { + return null; + } + List partitions = new ArrayList<>(descriptors.size()); + for (PartitionDescriptor descriptor : descriptors) { + partitions.add( + new AdbcInputPartition(options, toBytes(descriptor.getDescriptor()), Kind.DESCRIPTOR)); + } + return partitions; + } catch (AdbcException e) { + if (notImplemented(e)) { + return null; + } + throw e; + } + } + + private static boolean notImplemented(AdbcException e) { + return e.getStatus() == AdbcStatusCode.NOT_IMPLEMENTED + || e.getStatus() == AdbcStatusCode.NOT_FOUND; + } + + private static ByteBuffer directBuffer(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length); + buffer.put(bytes).flip(); + return buffer; + } + + private static byte[] toBytes(ByteBuffer buffer) { + ByteBuffer dup = buffer.duplicate(); + byte[] bytes = new byte[dup.remaining()]; + dup.get(bytes); + return bytes; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new AdbcPartitionReaderFactory(); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java new file mode 100644 index 0000000..6467bc8 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.EnumSet; +import java.util.Set; + +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A readable table over a DataFusion provider exposed via ADBC; produces {@link AdbcScanBuilder}s. + */ +final class AdbcTable implements SupportsRead { + + private final StructType schema; + private final AdbcOptions options; + + AdbcTable(StructType schema, AdbcOptions options) { + this.schema = schema; + this.options = options; + } + + @Override + public String name() { + return "adbc-datafusion"; + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Set capabilities() { + return EnumSet.of(TableCapability.BATCH_READ); + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap scanOptions) { + return new AdbcScanBuilder(schema, options); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java new file mode 100644 index 0000000..d299d79 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; + +/** + * Converts an Arrow schema (produced by the ADBC scan) into a Spark {@link StructType}. + * + *

Done directly rather than through Spark's {@code ArrowUtils} so the connector depends only on + * our Arrow version, never Spark's bundled one. Covers the primitive types the columnar reader + * produces; unsupported types fail fast. + */ +final class SchemaConverter { + + private SchemaConverter() {} + + static StructType toSparkSchema(Schema arrowSchema) { + StructType struct = new StructType(); + for (Field field : arrowSchema.getFields()) { + struct = struct.add(field.getName(), toSparkType(field), field.isNullable()); + } + return struct; + } + + static DataType toSparkType(Field field) { + ArrowType type = field.getType(); + if (type instanceof ArrowType.Int i) { + if (!i.getIsSigned()) { + throw unsupported(field); + } + return switch (i.getBitWidth()) { + case 8 -> DataTypes.ByteType; + case 16 -> DataTypes.ShortType; + case 32 -> DataTypes.IntegerType; + case 64 -> DataTypes.LongType; + default -> throw unsupported(field); + }; + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.DOUBLE + ? DataTypes.DoubleType + : DataTypes.FloatType; + } + if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { + return DataTypes.StringType; + } + if (type instanceof ArrowType.Bool) { + return DataTypes.BooleanType; + } + throw unsupported(field); + } + + private static IllegalArgumentException unsupported(Field field) { + return new IllegalArgumentException( + "unsupported Arrow type for column '" + field.getName() + "': " + field.getType()); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java b/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java new file mode 100644 index 0000000..07cfb0d --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.List; +import java.util.OptionalLong; +import java.util.stream.Collectors; + +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; + +/** + * Renders the pushed-down scan as an ANSI SQL {@code SELECT}. + * + *

This is the fallback wire for ADBC clients that don't accept Substrait. The arrow-adbc JNI + * driver manager (0.23) forwards {@code SetSqlQuery} but not {@code SetSubstraitPlan}, so the + * connector prefers Substrait and falls back to this. The pushable predicate set is identical to + * {@link SubstraitPlan} (so the {@link AdbcScanBuilder} decision is encoding-independent) -- which + * is why {@link SubstraitPlan#canPush} excludes non-finite floats: they have no SQL literal. + * + *

Identifiers are double-quoted (ANSI, DataFusion's default) and string literals single-quoted, + * both with doubling-based escaping. + */ +final class SqlQuery { + + private SqlQuery() {} + + static String build( + String table, List projection, List filters, OptionalLong limit) { + StringBuilder sql = new StringBuilder("SELECT "); + if (projection == null || projection.isEmpty()) { + sql.append("*"); + } else { + sql.append(projection.stream().map(SqlQuery::quoteId).collect(Collectors.joining(", "))); + } + sql.append(" FROM ").append(quoteId(table)); + if (!filters.isEmpty()) { + sql.append(" WHERE ") + .append(filters.stream().map(SqlQuery::predicate).collect(Collectors.joining(" AND "))); + } + if (limit.isPresent()) { + sql.append(" LIMIT ").append(limit.getAsLong()); + } + return sql.toString(); + } + + private static String predicate(Filter f) { + if (f instanceof EqualTo e) { + return binary(e.attribute(), "=", e.value()); + } else if (f instanceof GreaterThan e) { + return binary(e.attribute(), ">", e.value()); + } else if (f instanceof GreaterThanOrEqual e) { + return binary(e.attribute(), ">=", e.value()); + } else if (f instanceof LessThan e) { + return binary(e.attribute(), "<", e.value()); + } else if (f instanceof LessThanOrEqual e) { + return binary(e.attribute(), "<=", e.value()); + } else if (f instanceof IsNull e) { + return "(" + quoteId(e.attribute()) + " IS NULL)"; + } else if (f instanceof IsNotNull e) { + return "(" + quoteId(e.attribute()) + " IS NOT NULL)"; + } else if (f instanceof And e) { + return "(" + predicate(e.left()) + " AND " + predicate(e.right()) + ")"; + } else if (f instanceof Or e) { + return "(" + predicate(e.left()) + " OR " + predicate(e.right()) + ")"; + } else if (f instanceof Not e) { + return "(NOT " + predicate(e.child()) + ")"; + } + throw new IllegalArgumentException("filter is not pushable to SQL: " + f); + } + + private static String binary(String attribute, String op, Object value) { + return "(" + quoteId(attribute) + " " + op + " " + literal(value) + ")"; + } + + private static String literal(Object v) { + if (v instanceof Boolean b) { + return b ? "TRUE" : "FALSE"; + } else if (v instanceof String s) { + return "'" + s.replace("'", "''") + "'"; + } else if (v instanceof Integer + || v instanceof Long + || v instanceof Short + || v instanceof Byte) { + return v.toString(); + } else if (v instanceof Double || v instanceof Float) { + // SubstraitPlan.canPush has already rejected NaN/Infinity, so toString is a + // valid SQL numeric literal here. + return v.toString(); + } + throw new IllegalArgumentException("unsupported SQL literal type: " + v.getClass()); + } + + private static String quoteId(String identifier) { + return "\"" + identifier.replace("\"", "\"\"") + "\""; + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java b/spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java new file mode 100644 index 0000000..83a24c0 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.OptionalLong; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.plan.Plan; +import io.substrait.plan.PlanProtoConverter; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; + +/** + * Builds the Substrait plan pushed to the ADBC statement ({@code setSubstraitPlan}). + * + *

Shape: {@code NamedScan -> [Filter] -> [Project(emit)] -> [Fetch]}. The plan is built with + * substrait-java's {@link SubstraitBuilder}, and every pushed predicate resolves against the + * standard Substrait function catalog ({@link + * DefaultExtensionCatalog#FUNCTIONS_COMPARISON} / {@link + * DefaultExtensionCatalog#FUNCTIONS_BOOLEAN}). That is the key to interop: those are the function + * URIs/signatures DataFusion's Substrait consumer ({@code from_substrait_plan}) recognizes, so we + * never emit a custom extension the two sides could disagree on. Anything outside the whitelist in + * {@link #canPush} is left to Spark. + * + *

Pushdown is carried as typed Substrait rather than a SQL string, so literal fidelity + * (float/decimal/NaN/null-safe) is not lost in a text round-trip. + */ +final class SubstraitPlan { + + /** Load the standard extension declarations once; immutable and shareable. */ + private static final SimpleExtension.ExtensionCollection EXTENSIONS = + SimpleExtension.loadDefaults(); + + /** Predicates produce a (nullable, three-valued) boolean. */ + private static final Type BOOL = TypeCreator.NULLABLE.BOOLEAN; + + private SubstraitPlan() {} + + /** + * Build the scan plan with the given pushdown. + * + * @param table the table name the provider exposes + * @param schema the table's full Arrow schema + * @param projection kept column names (in output order), or {@code null} for all columns + * @param filters the pushed predicates (must all satisfy {@link #canPush}); ANDed together + * @param limit an optional row limit applied after filtering + */ + static byte[] build( + String table, + Schema schema, + List projection, + List filters, + OptionalLong limit) { + SubstraitBuilder b = new SubstraitBuilder(EXTENSIONS); + + List names = new ArrayList<>(schema.getFields().size()); + List types = new ArrayList<>(schema.getFields().size()); + Map columnIndex = new java.util.HashMap<>(); + int i = 0; + for (Field field : schema.getFields()) { + names.add(field.getName()); + types.add(toSubstraitType(field)); + columnIndex.put(field.getName(), i++); + } + + NamedScan scan = b.namedScan(List.of(table), names, types); + Rel rel = scan; + + if (!filters.isEmpty()) { + rel = b.filter(input -> conjunction(b, input, columnIndex, filters), rel); + } + + List outputNames = names; + if (projection != null && !projection.equals(names)) { + List keep = projection.stream().map(columnIndex::get).collect(Collectors.toList()); + int inputWidth = names.size(); + rel = + b.project( + input -> + keep.stream() + .map(k -> (Expression) b.fieldReference(input, k)) + .collect(Collectors.toList()), + // Project appends expressions after the input columns; emit only the + // appended ones so the output is exactly the kept columns. + Rel.Remap.offset(inputWidth, keep.size()), + rel); + outputNames = projection; + } + + if (limit.isPresent()) { + rel = b.limit(limit.getAsLong(), rel); + } + + Plan plan = b.plan(Plan.Root.builder().input(rel).addAllNames(outputNames).build()); + return new PlanProtoConverter().toProto(plan).toByteArray(); + } + + // --- Spark Filter -> Substrait Expression --------------------------------- + + /** Whether {@code f} maps to a standard-catalog Substrait predicate over known columns. */ + static boolean canPush(Filter f, Set columns) { + if (f instanceof EqualTo e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof GreaterThan e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof GreaterThanOrEqual e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof LessThan e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof LessThanOrEqual e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof IsNull e) { + return columns.contains(e.attribute()); + } else if (f instanceof IsNotNull e) { + return columns.contains(e.attribute()); + } else if (f instanceof And e) { + return canPush(e.left(), columns) && canPush(e.right(), columns); + } else if (f instanceof Or e) { + return canPush(e.left(), columns) && canPush(e.right(), columns); + } else if (f instanceof Not e) { + return canPush(e.child(), columns); + } + return false; + } + + private static Expression conjunction( + SubstraitBuilder b, Rel input, Map idx, List filters) { + List terms = + filters.stream().map(f -> translate(b, input, idx, f)).collect(Collectors.toList()); + if (terms.size() == 1) { + return terms.get(0); + } + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "and:bool", + BOOL, + terms.toArray(new Expression[0])); + } + + private static Expression translate( + SubstraitBuilder b, Rel input, Map idx, Filter f) { + if (f instanceof EqualTo e) { + return cmp(b, input, idx, "equal:any_any", e.attribute(), e.value()); + } else if (f instanceof GreaterThan e) { + return cmp(b, input, idx, "gt:any_any", e.attribute(), e.value()); + } else if (f instanceof GreaterThanOrEqual e) { + return cmp(b, input, idx, "gte:any_any", e.attribute(), e.value()); + } else if (f instanceof LessThan e) { + return cmp(b, input, idx, "lt:any_any", e.attribute(), e.value()); + } else if (f instanceof LessThanOrEqual e) { + return cmp(b, input, idx, "lte:any_any", e.attribute(), e.value()); + } else if (f instanceof IsNull e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + "is_null:any", + BOOL, + b.fieldReference(input, idx.get(e.attribute()))); + } else if (f instanceof IsNotNull e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + "is_not_null:any", + BOOL, + b.fieldReference(input, idx.get(e.attribute()))); + } else if (f instanceof And e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "and:bool", + BOOL, + translate(b, input, idx, e.left()), + translate(b, input, idx, e.right())); + } else if (f instanceof Or e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "or:bool", + BOOL, + translate(b, input, idx, e.left()), + translate(b, input, idx, e.right())); + } else if (f instanceof Not e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "not:bool", + BOOL, + translate(b, input, idx, e.child())); + } + throw new IllegalArgumentException("filter is not pushable: " + f); + } + + private static Expression cmp( + SubstraitBuilder b, + Rel input, + Map idx, + String key, + String attribute, + Object value) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + key, + BOOL, + b.fieldReference(input, idx.get(attribute)), + literal(value)); + } + + private static boolean isPushableLiteral(Object v) { + // Exclude non-finite floats: they have no SQL literal, and the connector may + // encode this predicate as SQL (when the driver lacks Substrait support), so + // the pushable set must be expressible in both wires. + if (v instanceof Double d) { + return !d.isNaN() && !d.isInfinite(); + } + if (v instanceof Float f) { + return !f.isNaN() && !f.isInfinite(); + } + return v instanceof Integer + || v instanceof Long + || v instanceof Short + || v instanceof Byte + || v instanceof Boolean + || v instanceof String; + } + + private static Expression literal(Object v) { + if (v instanceof Integer n) { + return ExpressionCreator.i32(true, n); + } else if (v instanceof Long n) { + return ExpressionCreator.i64(true, n); + } else if (v instanceof Short n) { + return ExpressionCreator.i16(true, n); + } else if (v instanceof Byte n) { + return ExpressionCreator.i8(true, n); + } else if (v instanceof Double n) { + return ExpressionCreator.fp64(true, n); + } else if (v instanceof Float n) { + return ExpressionCreator.fp32(true, n); + } else if (v instanceof Boolean n) { + return ExpressionCreator.bool(true, n); + } else if (v instanceof String n) { + return ExpressionCreator.string(true, n); + } + throw new IllegalArgumentException("unsupported literal type: " + v.getClass()); + } + + // --- Arrow type -> Substrait type ----------------------------------------- + + private static Type toSubstraitType(Field field) { + ArrowType type = field.getType(); + TypeCreator t = field.isNullable() ? TypeCreator.NULLABLE : TypeCreator.REQUIRED; + if (type instanceof ArrowType.Int n) { + if (!n.getIsSigned()) { + throw unsupported(field); + } + return switch (n.getBitWidth()) { + case 8 -> t.I8; + case 16 -> t.I16; + case 32 -> t.I32; + case 64 -> t.I64; + default -> throw unsupported(field); + }; + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? t.FP64 : t.FP32; + } + if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { + return t.STRING; + } + if (type instanceof ArrowType.Bool) { + return t.BOOLEAN; + } + throw unsupported(field); + } + + private static IllegalArgumentException unsupported(Field field) { + return new IllegalArgumentException( + "unsupported Arrow type for column '" + field.getName() + "': " + field.getType()); + } +} diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..9237b20 --- /dev/null +++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.datafusion.spark.AdbcDatafusionTableProvider diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java new file mode 100644 index 0000000..75332cc --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +/** Unit + concurrency tests for {@link AdbcConnectionPool}; no native driver required. */ +class AdbcConnectionPoolTest { + + @AfterEach + void tearDown() { + AdbcConnectionPool.resetForTesting(); + } + + private static AdbcOptions options(Map raw) { + return AdbcOptions.fromOptions(new CaseInsensitiveStringMap(new LinkedHashMap<>(raw))); + } + + private static LinkedHashMap base() { + LinkedHashMap m = new LinkedHashMap<>(); + m.put("driver", "/path/to/lib.so"); + m.put("table", "t"); + return m; + } + + // ---- cacheKey() ---- + + @Test + void cacheKeyIgnoresOptionOrder() { + LinkedHashMap a = base(); + a.put("foo", "1"); + a.put("bar", "2"); + LinkedHashMap b = base(); + b.put("bar", "2"); + b.put("foo", "1"); + + AdbcConnectionPool.Key ka = options(a).cacheKey(); + AdbcConnectionPool.Key kb = options(b).cacheKey(); + assertEquals(ka, kb); + assertEquals(ka.hashCode(), kb.hashCode()); + } + + @Test + void cacheKeyDistinguishesDriverAndOptionValues() { + LinkedHashMap a = base(); + a.put("foo", "1"); + LinkedHashMap diffValue = base(); + diffValue.put("foo", "2"); + LinkedHashMap diffDriver = base(); + diffDriver.put("driver", "/other/lib.so"); + diffDriver.put("foo", "1"); + + assertNotEquals(options(a).cacheKey(), options(diffValue).cacheKey()); + assertNotEquals(options(a).cacheKey(), options(diffDriver).cacheKey()); + } + + @Test + void cacheKeyIgnoresControlKeys() { + // table and target_partitions are connector-control, not forwarded to the native database, + // so they must not affect the cache key. + LinkedHashMap a = base(); + a.put("target_partitions", "4"); + LinkedHashMap b = base(); + b.put("table", "other-table"); + b.put("target_partitions", "16"); + + assertEquals(options(a).cacheKey(), options(b).cacheKey()); + } + + // ---- computeIfAbsent: at-most-once build under concurrency ---- + + @Test + void concurrentAcquireBuildsExactlyOnce() throws Exception { + AtomicInteger builds = new AtomicInteger(); + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> { + builds.incrementAndGet(); + return new AdbcConnectionPool.CachedDatabase( + new RootAllocator(), mock(AdbcDatabase.class)); + }); + + AdbcOptions opts = options(base()); + int threads = 16; + CyclicBarrier barrier = new CyclicBarrier(threads); + ConcurrentLinkedQueue leases = new ConcurrentLinkedQueue<>(); + ConcurrentLinkedQueue errors = new ConcurrentLinkedQueue<>(); + + runConcurrently( + threads, + () -> { + try { + barrier.await(); + leases.add(AdbcConnectionPool.acquire(opts)); + } catch (Throwable t) { + errors.add(t); + } + }); + + assertTrue(errors.isEmpty(), () -> "unexpected errors: " + errors); + assertEquals(1, builds.get(), "database built exactly once for one key"); + assertEquals(1, AdbcConnectionPool.cacheSizeForTesting()); + assertEquals(threads, leases.size()); + for (AdbcConnectionPool.Lease lease : leases) { + lease.close(); + } + } + + @Test + void distinctKeysBuildDistinctEntries() throws Exception { + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> + new AdbcConnectionPool.CachedDatabase(new RootAllocator(), mock(AdbcDatabase.class))); + + LinkedHashMap a = base(); + a.put("foo", "1"); + LinkedHashMap b = base(); + b.put("foo", "2"); + + AdbcConnectionPool.Lease la = AdbcConnectionPool.acquire(options(a)); + AdbcConnectionPool.Lease lb = AdbcConnectionPool.acquire(options(b)); + assertEquals(2, AdbcConnectionPool.cacheSizeForTesting()); + la.close(); + lb.close(); + } + + @Test + void buildFailureSurfacesAdbcExceptionAndLeavesCacheEmpty() throws Exception { + AtomicInteger attempts = new AtomicInteger(); + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> { + attempts.incrementAndGet(); + throw AdbcException.io("boom"); + }); + + AdbcOptions opts = options(base()); + int threads = 8; + CyclicBarrier barrier = new CyclicBarrier(threads); + ConcurrentLinkedQueue errors = new ConcurrentLinkedQueue<>(); + + runConcurrently( + threads, + () -> { + try { + barrier.await(); + AdbcConnectionPool.acquire(opts); + } catch (AdbcException e) { + errors.add(e); // expected + } catch (Throwable t) { + errors.add(t); + } + }); + + assertEquals(threads, errors.size()); + for (Throwable t : errors) { + assertTrue(t instanceof AdbcException, () -> "expected AdbcException, got " + t); + } + assertEquals(0, AdbcConnectionPool.cacheSizeForTesting(), "failed build leaves no entry"); + + // A subsequent good build inserts exactly one entry (retry works). + AdbcConnectionPool.setBuilderForTesting( + (key, o) -> + new AdbcConnectionPool.CachedDatabase(new RootAllocator(), mock(AdbcDatabase.class))); + AdbcConnectionPool.Lease lease = AdbcConnectionPool.acquire(opts); + assertEquals(1, AdbcConnectionPool.cacheSizeForTesting()); + lease.close(); + } + + // ---- Lease close semantics ---- + + @Test + void leaseCloseClosesTaskConnectionButNotDatabase() throws Exception { + try (BufferAllocator root = new RootAllocator()) { + AdbcDatabase db = mock(AdbcDatabase.class); + AdbcConnection taskConn = mock(AdbcConnection.class); + AdbcConnectionPool.CachedDatabase cached = new AdbcConnectionPool.CachedDatabase(root, db); + BufferAllocator child = root.newChildAllocator("task", 0, Long.MAX_VALUE); + + AdbcConnectionPool.Lease lease = new AdbcConnectionPool.Lease(cached, child, taskConn); + lease.close(); + + verify(taskConn, times(1)).close(); + verify(db, never()).close(); + } + } + + // ---- acquire opens a per-task connection off the one cached database ---- + + @Test + void acquireOpensConnectionPerTask() throws Exception { + AdbcDatabase db = mock(AdbcDatabase.class); + when(db.connect()).thenAnswer(invocation -> mock(AdbcConnection.class)); + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> new AdbcConnectionPool.CachedDatabase(new RootAllocator(), db)); + + AdbcOptions opts = options(base()); + AdbcConnectionPool.Lease l1 = AdbcConnectionPool.acquire(opts); + AdbcConnectionPool.Lease l2 = AdbcConnectionPool.acquire(opts); + + assertEquals(1, AdbcConnectionPool.cacheSizeForTesting(), "one cached database"); + verify(db, times(2)).connect(); // one connection per task + assertNotEquals(l1.connection(), l2.connection()); + l1.close(); + l2.close(); + } + + private static void runConcurrently(int threads, Runnable body) throws InterruptedException { + Thread[] pool = new Thread[threads]; + for (int i = 0; i < threads; i++) { + pool[i] = new Thread(body, "acquire-" + i); + pool[i].start(); + } + for (Thread t : pool) { + t.join(); + } + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java new file mode 100644 index 0000000..bde17e9 --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; + +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.Test; + +class AdbcOptionsTest { + + private static CaseInsensitiveStringMap opts(Map m) { + return new CaseInsensitiveStringMap(m); + } + + @Test + void parsesTargetPartitions() { + AdbcOptions o = + AdbcOptions.fromOptions( + opts(Map.of("driver", "/lib.so", "table", "t", "target_partitions", "16"))); + assertTrue(o.targetPartitions().isPresent()); + assertEquals(16, o.targetPartitions().getAsInt()); + } + + @Test + void targetPartitionsAbsentByDefault() { + AdbcOptions o = AdbcOptions.fromOptions(opts(Map.of("driver", "/lib.so", "table", "t"))); + assertFalse(o.targetPartitions().isPresent()); + } + + @Test + void rejectsNonPositiveOrNonNumeric() { + assertThrows( + IllegalArgumentException.class, + () -> + AdbcOptions.fromOptions( + opts(Map.of("driver", "/lib.so", "table", "t", "target_partitions", "0")))); + assertThrows( + IllegalArgumentException.class, + () -> + AdbcOptions.fromOptions( + opts(Map.of("driver", "/lib.so", "table", "t", "target_partitions", "abc")))); + } + + @Test + void controlKeysNotForwardedAsDatabaseOptions() { + AdbcOptions o = + AdbcOptions.fromOptions( + opts( + Map.of( + "driver", "/lib.so", + "table", "t", + "target_partitions", "8", + "my.provider.opt", "v"))); + Map params = o.driverParameters(); + // driver -> jni.driver; provider opt passed through; control keys absent. + assertTrue(params.containsKey("my.provider.opt")); + assertFalse(params.containsKey("target_partitions")); + assertFalse(params.containsKey("table")); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java new file mode 100644 index 0000000..b94d136 --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.arrow.adbc.core.TypedKey; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** + * End-to-end test of the {@code adbc-datafusion} source against the example driver cdylib (built + * from {@code examples/adbc-datafusion-driver}). It drives the full stack: Spark DataSourceV2 -> + * arrow-adbc JNI driver manager -> the DataFusion ADBC cdylib -> our custom provider. + * + *

Requires the cdylib path in the {@code adbc.example.driver.path} system property; the test is + * skipped (not failed) when it is absent, so the build is green without the native artifact. To + * run: + * + *

+ *   (cd examples/adbc-datafusion-driver && cargo build --release)
+ *   mvn -pl spark -am test -Dtest=AdbcSourceTest \
+ *     -Dadbc.example.driver.path=$PWD/rust-target/release/libadbc_datafusion_example_driver.dylib
+ * 
+ */ +class AdbcSourceTest { + + private static final String ENTRYPOINT = "AdbcDatafusionExampleInit"; + private static final String TABLE = "example"; + + private static SparkSession spark; + private static String driverPath; + + @BeforeAll + static void setUp() { + driverPath = System.getProperty("adbc.example.driver.path"); + assumeTrue( + driverPath != null && Files.exists(Path.of(driverPath)), + "set -Dadbc.example.driver.path to the example driver cdylib to run this test"); + // local[8]: several task slots in one executor JVM, so the per-executor connection cache + // (AdbcConnectionPool) is exercised across concurrent tasks. + spark = SparkSession.builder().appName("adbc-source-test").master("local[8]").getOrCreate(); + } + + @AfterAll + static void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + private Dataset load() { + return spark + .read() + .format("adbc-datafusion") + .option("driver", driverPath) + .option("entrypoint", ENTRYPOINT) + .option("table", TABLE) + .load(); + } + + @Test + void fullScanAcrossPartitions() { + Dataset df = load(); + + // Schema inferred via ADBC getTableSchema. + assertEquals("id", df.schema().fields()[0].name()); + assertEquals("name", df.schema().fields()[1].name()); + + // The provider has three partitions; with arrow-adbc >= 0.24 (executePartitioned + // forwarded by the JNI bridge) they surface as multiple Spark input partitions. + assertTrue( + df.rdd().getNumPartitions() >= 2, + "expected multiple Spark partitions, got " + df.rdd().getNumPartitions()); + + List ids = + df.collectAsList().stream().map(r -> r.getLong(0)).sorted().collect(Collectors.toList()); + assertEquals(List.of(1L, 2L, 3L), ids); + } + + @Test + void projectionPushdown() { + List columns = + load().select("name").collectAsList().stream() + .map(r -> r.getString(0)) + .sorted() + .collect(Collectors.toList()); + assertEquals(List.of("alice", "bob", "carol"), columns); + } + + @Test + void filterPushdown() { + List ids = + load().filter("id > 1").collectAsList().stream() + .map(r -> r.getLong(0)) + .sorted() + .collect(Collectors.toList()); + assertEquals(List.of(2L, 3L), ids); + } + + /** + * Regression for the per-executor cache: across a multi-partition scan over many task slots, the + * pool builds the native database (and so runs the driver's {@code ContextInit} provider + * registration) exactly once per executor JVM -- not once per task. Each task still opens its own + * connection off that one cached database. + */ + @Test + void providerRegisteredOncePerExecutor() { + // Force reader tasks to run across the executor's slots. + long rows = load().collectAsList().size(); + assertTrue(rows > 0, "scan returned rows"); + + // All tasks share one driver+options key, so the pool builds exactly one native database for + // this executor JVM regardless of partition/task count. + assertEquals( + 1, + AdbcConnectionPool.databasesBuiltForTesting(), + "expected one native database per executor JVM, got " + + AdbcConnectionPool.databasesBuiltForTesting()); + + // One cached database, but each task opens its own connection off it. + assertTrue( + AdbcConnectionPool.taskConnectionsOpenedForTesting() > 0, + "expected a per-task connection off the one cached database"); + } + + /** + * Proves the driver's database-scoped plan cache is actually used. Every partition descriptor of + * one query carries the same serialized physical plan; the driver deserializes it once and caches + * it across all connections opened from the (per-executor) cached database. So the {@code N} + * per-task connections of one scan must trigger exactly one deserialize, not {@code N}. + * + *

The driver exposes its deserialize count as the read-only int option {@code + * adbc.datafusion.plan_deserialize_count}. The count is JVM-lifetime and shared across tests, so + * we assert on the delta around a single scan, and use a distinct predicate so the plan bytes are + * fresh (not already cached by another test) -- making the expected delta exactly 1. + */ + @Test + void planDeserializedOncePerExecutorNotPerTask() throws Exception { + TypedKey deserializeCount = + new TypedKey<>("adbc.datafusion.plan_deserialize_count", Long.class); + + // Distinct predicate -> distinct physical plan -> plan bytes not cached by another test. + Dataset df = load().filter("id >= 1"); + int partitions = df.rdd().getNumPartitions(); + assertTrue(partitions >= 2, "need multiple partitions to test the collapse, got " + partitions); + + // Same options the scan uses -> same pool cache key -> same cached database (and counter). + LinkedHashMap raw = new LinkedHashMap<>(); + raw.put("driver", driverPath); + raw.put("entrypoint", ENTRYPOINT); + raw.put("table", TABLE); + AdbcOptions opts = AdbcOptions.fromOptions(new CaseInsensitiveStringMap(raw)); + + long before; + try (AdbcConnectionPool.Lease lease = AdbcConnectionPool.acquire(opts)) { + before = lease.connection().getOption(deserializeCount); + } + + df.collectAsList(); // runs read_partition on `partitions` separate task connections + + long after; + try (AdbcConnectionPool.Lease lease = AdbcConnectionPool.acquire(opts)) { + after = lease.connection().getOption(deserializeCount); + } + + assertEquals( + 1L, + after - before, + "the " + + partitions + + " partitions of one query must share ONE plan deserialize (database-scoped cache); " + + "a per-connection cache would deserialize once per task = " + + partitions); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java b/spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java new file mode 100644 index 0000000..5c8b44f --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.OptionalLong; +import java.util.Set; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.StringContains; +import org.junit.jupiter.api.Test; + +import io.substrait.relation.Fetch; +import io.substrait.relation.Filter; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; + +/** + * Verifies the Substrait pushdown encoding without needing the native driver: build a plan, then + * re-parse the proto bytes and assert the relation tree and the standard function extensions. + */ +class SubstraitPlanTest { + + private static Schema schema() { + return new Schema( + List.of( + Field.nullable("id", new ArrowType.Int(64, true)), + Field.nullable("name", ArrowType.Utf8.INSTANCE))); + } + + @Test + void buildsScanFilterProjectFetch() throws Exception { + List filters = + List.of(new GreaterThan("id", 1L), new EqualTo("name", "bob")); + byte[] bytes = SubstraitPlan.build("t", schema(), List.of("name"), filters, OptionalLong.of(2)); + + // Re-parse the proto and walk it back into the substrait-java model. + io.substrait.proto.Plan proto = io.substrait.proto.Plan.parseFrom(bytes); + io.substrait.plan.Plan plan = new io.substrait.plan.ProtoPlanConverter().from(proto); + + Rel top = plan.getRoots().get(0).getInput(); + // Fetch(limit) -> Project(projection) -> Filter(predicates) -> NamedScan(table). + Fetch fetch = (Fetch) top; + assertEquals(OptionalLong.of(2), fetch.getCount()); + Project project = (Project) fetch.getInput(); + assertEquals(1, project.getExpressions().size()); // only the kept column "name" + Filter filter = (Filter) project.getInput(); + NamedScan scan = (NamedScan) filter.getInput(); + assertEquals(List.of("t"), scan.getNames()); + + // The interop guarantee: predicates resolve against the STANDARD catalog, the + // same function URIs DataFusion's Substrait consumer recognizes. + String text = proto.toString(); + assertTrue(text.contains("functions_comparison"), "uses standard comparison functions"); + assertTrue(text.contains("functions_boolean"), "uses standard boolean functions"); + } + + @Test + void wholeTableScanWhenNoPushdown() throws Exception { + byte[] bytes = SubstraitPlan.build("t", schema(), null, List.of(), OptionalLong.empty()); + io.substrait.proto.Plan proto = io.substrait.proto.Plan.parseFrom(bytes); + io.substrait.plan.Plan plan = new io.substrait.plan.ProtoPlanConverter().from(proto); + assertTrue(plan.getRoots().get(0).getInput() instanceof NamedScan); + } + + @Test + void canPushWhitelist() { + Set columns = Set.of("id", "name"); + assertTrue(SubstraitPlan.canPush(new GreaterThan("id", 1L), columns)); + assertTrue(SubstraitPlan.canPush(new EqualTo("name", "bob"), columns)); + // Unknown column -> not pushable. + assertFalse(SubstraitPlan.canPush(new GreaterThan("missing", 1L), columns)); + // Unsupported predicate shape -> not pushable. + assertFalse(SubstraitPlan.canPush(new StringContains("name", "b"), columns)); + } +} From ff9a205d10e74cfe40a82e8db72a8e9477c0197d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 1 Jul 2026 15:26:23 -0400 Subject: [PATCH 2/5] feat(spark): cast non-Spark-native Arrow types source-side in ADBC scan The connector's SchemaConverter mapped only a handful of Arrow types and failed load() on anything else, since inferSchema probes the full table schema up front. Extend it to cover every type Spark's ArrowColumnVector can read, and for the types it cannot read directly, push a cast into the scan so executors emit Spark-native Arrow (zero-copy import stays intact). - SchemaConverter: full recursive Arrow -> Spark map. Directly representable types (binary, nested list/struct/map, date, decimal, us timestamp, null) pass through; cast-required types (unsigned ints, Float16, non-us timestamps, time) map to their widened Spark type and expose an arrow_cast target string. u64 -> Decimal(20,0) to stay lossless past i64::MAX. - SqlQuery: wrap cast columns in arrow_cast(col, '') AS col. - AdbcScanImpl: force the SQL wire when the schema needs a cast (Substrait cannot encode unsigned/Float16); partitioning is unaffected (ADBC executePartitioned works on the physical plan, and the cast is a partition-preserving projection). - AdbcScanBuilder: keep filter pushdown off cast columns (a pushed predicate runs against the pre-cast source domain). Tests: SchemaConverterTest and SqlQueryTest cover the type map and the arrow_cast string generation. The example driver gains a second `types` table spanning these Arrow types with known values (built as self-contained single-row batches so sliced-array offsets do not corrupt variable-width columns across the C-data boundary), and AdbcSourceTest asserts the schema, per-value round-trip through the casts, and multi-partition execution. Co-Authored-By: Claude Opus 4.8 --- examples/adbc-datafusion-driver/Cargo.toml | 2 + .../adbc-datafusion-driver/pyspark_e2e.py | 73 +++- examples/adbc-datafusion-driver/src/lib.rs | 9 +- .../adbc-datafusion-driver/src/provider.rs | 200 ++++++++++- .../tests/partitions.rs | 13 +- .../datafusion/spark/AdbcScanBuilder.java | 12 +- .../apache/datafusion/spark/AdbcScanImpl.java | 36 +- .../datafusion/spark/SchemaConverter.java | 337 +++++++++++++++++- .../org/apache/datafusion/spark/SqlQuery.java | 19 +- .../datafusion/spark/AdbcSourceTest.java | 112 +++++- .../datafusion/spark/SchemaConverterTest.java | 208 +++++++++++ .../apache/datafusion/spark/SqlQueryTest.java | 54 +++ 12 files changed, 1036 insertions(+), 39 deletions(-) create mode 100644 spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java create mode 100644 spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java diff --git a/examples/adbc-datafusion-driver/Cargo.toml b/examples/adbc-datafusion-driver/Cargo.toml index 58f9d81..dc213c9 100644 --- a/examples/adbc-datafusion-driver/Cargo.toml +++ b/examples/adbc-datafusion-driver/Cargo.toml @@ -37,6 +37,8 @@ adbc_core = "0.23" adbc_ffi = "0.23" datafusion = "53" async-trait = "0.1" +# For building Float16 example data (arrow's half-precision element type). +half = "2" # Build both an rlib (so the provider/driver can be unit-tested or embedded) and # a cdylib (the artifact the arrow-adbc Java driver manager loads at runtime). diff --git a/examples/adbc-datafusion-driver/pyspark_e2e.py b/examples/adbc-datafusion-driver/pyspark_e2e.py index 6cbf80f..edd5405 100644 --- a/examples/adbc-datafusion-driver/pyspark_e2e.py +++ b/examples/adbc-datafusion-driver/pyspark_e2e.py @@ -35,12 +35,18 @@ ADBC_DRIVER_LIB path to libadbc_datafusion_example_driver.{dylib,so} """ +import datetime import glob import os import sys +from decimal import Decimal from pyspark.sql import SparkSession +# Field metadata flag the connector sets on columns it casts source-side to a Spark-native +# layout (SchemaConverter.CAST_METADATA_KEY). +CAST_META_KEY = "org.apache.datafusion.spark.adbc.cast" + REPO = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) LIBEXT = "dylib" if sys.platform == "darwin" else "so" @@ -72,16 +78,21 @@ .getOrCreate() ) -try: - df = ( +def read(table): + return ( spark.read.format("adbc-datafusion") .option("driver", driver_lib) .option("entrypoint", "AdbcDatafusionExampleInit") - .option("table", "example") + .option("table", table) .load() ) - print("=== schema ===") + +try: + # --- `example`: two Spark-native columns, three partitions ------------------------------- + df = read("example") + + print("=== example schema ===") df.printSchema() num_partitions = df.rdd.getNumPartitions() @@ -99,6 +110,58 @@ assert filtered == [2, 3], filtered assert num_partitions >= 2, f"expected multi-partition, got {num_partitions}" - print("\nPYSPARK E2E OK (multi-partition + projection + filter pushdown)") + # --- `types`: the schema-conversion + source-side arrow_cast coverage -------------------- + dft = read("types") + print("=== types schema ===") + dft.printSchema() + + # cast columns are flagged (so filter pushdown stays off them); pass-through ones are not. + cast_cols = {f.name for f in dft.schema.fields if f.metadata.get(CAST_META_KEY)} + print("cast columns:", sorted(cast_cols)) + assert cast_cols == {"channel", "big", "event_time", "score", "tags"}, cast_cols + + # value correctness: collecting whole rows forces the vectorized reader to decode each + # ArrowColumnVector, so a bad cast (wrong layout, unit relabel instead of rescale, unsigned + # overflow) fails here. + rows = {r["id"]: r for r in dft.collect()} + r1, r2, r3 = rows[1], rows[2], rows[3] + + # unsigned UInt16 -> Integer, widened past i16::MAX + assert [r1["channel"], r2["channel"], r3["channel"]] == [100, 40000, 65535] + + # unsigned UInt64 -> Decimal(20,0): lossless for values past i64::MAX (a Long would overflow) + assert r1["big"] == Decimal("18446744073709551615") + assert r2["big"] == Decimal(0) + assert r3["big"] == Decimal("9223372036854775808") + + # nanosecond Timestamp -> microsecond TimestampNTZ, rescaled (not relabeled -> would be ~1970) + assert r1["event_time"] == datetime.datetime(2020, 9, 13, 12, 26, 40) + assert r2["event_time"] == datetime.datetime(2021, 1, 7, 6, 13, 20) + assert r3["event_time"] == datetime.datetime(2021, 5, 3, 0, 0, 0) + + # Float16 -> Float + assert [r1["score"], r2["score"], r3["score"]] == [1.5, 2.5, 3.5] + + # Binary passes through + assert bytes(r1["payload"]) == b"\x01\x02" + assert bytes(r2["payload"]) == b"" + assert bytes(r3["payload"]) == b"\xff\xfe" + + # nested List -> Array (recursive cast) + assert r1["tags"] == [1, 2] + assert r2["tags"] == [] + assert r3["tags"] == [3] + + # nested List> passes through + assert [(x["key"], x["val"]) for x in r1["attrs"]] == [("a", "1")] + assert r2["attrs"] == [] + assert [(x["key"], x["val"]) for x in r3["attrs"]] == [("b", "2"), ("c", "3")] + + # --- filter on a cast column: excluded from pushdown, evaluated by Spark on the cast value - + channel_gt = sorted(r["id"] for r in dft.filter(dft.channel > 100).collect()) + print("filter channel>100 ids:", channel_gt) + assert channel_gt == [2, 3], channel_gt + + print("\nPYSPARK E2E OK (example: multi-partition + pushdown; types: casts + nested + filter)") finally: spark.stop() diff --git a/examples/adbc-datafusion-driver/src/lib.rs b/examples/adbc-datafusion-driver/src/lib.rs index f57556d..e2c89f7 100644 --- a/examples/adbc-datafusion-driver/src/lib.rs +++ b/examples/adbc-datafusion-driver/src/lib.rs @@ -39,7 +39,7 @@ use adbc_core::Driver; use adbc_driver_datafusion::{ContextInit, DataFusionDatabase, DataFusionDriver}; use datafusion::prelude::SessionContext; -pub use provider::ExampleTableProvider; +pub use provider::{ExampleTableProvider, TypesTableProvider}; /// An ADBC driver that registers [`ExampleTableProvider`] into each session. /// @@ -57,6 +57,13 @@ impl Default for ExampleDriver { ExampleTableProvider::TABLE_NAME, Arc::new(ExampleTableProvider::new()), )?; + // A second table whose schema spans the Arrow types the Spark connector must cast + // or map (unsigned, ns timestamp, binary, nested list/struct); the E2E test scans + // it to check schema conversion and source-side arrow_cast with known values. + ctx.register_table( + TypesTableProvider::TABLE_NAME, + Arc::new(TypesTableProvider::new()), + )?; Ok(()) }); ExampleDriver(DataFusionDriver::new_with_context_init(None, init)) diff --git a/examples/adbc-datafusion-driver/src/provider.rs b/examples/adbc-datafusion-driver/src/provider.rs index 175b59a..ff18fe8 100644 --- a/examples/adbc-datafusion-driver/src/provider.rs +++ b/examples/adbc-datafusion-driver/src/provider.rs @@ -15,21 +15,42 @@ // specific language governing permissions and limitations // under the License. -//! The "custom" `TableProvider` this example exposes. +//! The "custom" `TableProvider`s this example exposes. //! //! In a real deployment this is *your* provider crate -- reading your store, -//! your format, your catalog. Here it is a tiny fixed table so the example is -//! self-contained; the only thing that matters for the integration is that it -//! is an ordinary [`datafusion::catalog::TableProvider`]. Execution is delegated -//! to an in-memory [`MemTable`]; swap [`ExampleTableProvider::scan`] for your -//! real source and the rest of the pipeline (ADBC driver, Spark connector) is -//! unchanged. +//! your format, your catalog. Here they are fixed tables so the example is +//! self-contained; the only thing that matters for the integration is that they +//! are ordinary [`datafusion::catalog::TableProvider`]s. Execution is delegated +//! to in-memory [`MemTable`]s; swap the `scan` for your real source and the rest +//! of the pipeline (ADBC driver, Spark connector) is unchanged. +//! +//! Two tables: +//! +//! - [`ExampleTableProvider`] (`example`) -- two Spark-native columns across +//! three partitions. Keeps the partitioned-execution system test simple and +//! lets the connector use its preferred Substrait wire. +//! - [`TypesTableProvider`] (`types`) -- a schema spanning both categories the +//! Spark connector's `SchemaConverter` must handle, so the end-to-end test +//! covers them with known values and no external server: +//! - *directly representable* (pass through to `ArrowColumnVector`): +//! `Int64`, `Utf8`, `Binary`, `List>`; +//! - *cast-required* (source-side `arrow_cast` to a Spark-native layout): +//! unsigned `UInt16`/`UInt64`, nanosecond `Timestamp`, `Float16`, nested +//! `List`. +//! Each row is built as its own self-contained single-row batch (rather than +//! slicing one batch), so every partition's arrays start at offset 0 -- a +//! sliced array carries a non-zero offset into shared buffers, which the Arrow +//! C-data export at the ADBC boundary does not rebase, corrupting +//! variable-width (binary/list) columns on the Spark side. use std::any::Any; use std::sync::Arc; use async_trait::async_trait; -use datafusion::arrow::array::{Int64Array, StringArray}; +use datafusion::arrow::array::{ + ArrayRef, BinaryArray, Float16Array, Int64Array, ListBuilder, StringArray, StringBuilder, + StructBuilder, TimestampNanosecondArray, UInt16Array, UInt16Builder, UInt64Array, +}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::catalog::{Session, TableProvider}; @@ -38,6 +59,7 @@ use datafusion::error::Result; use datafusion::logical_expr::TableType; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::Expr; +use half::f16; /// A custom provider exposing a single fixed table named [`Self::TABLE_NAME`]. #[derive(Debug)] @@ -110,3 +132,165 @@ impl TableProvider for ExampleTableProvider { self.inner.scan(state, projection, filters, limit).await } } + +/// One row of the [`TypesTableProvider`] table, as plain Rust values. +struct Row { + id: i64, + name: &'static str, + payload: &'static [u8], + attrs: &'static [(&'static str, &'static str)], + channel: u16, + big: u64, + /// Nanoseconds since the Unix epoch (no timezone). + event_time: i64, + score: f32, + tags: &'static [u16], +} + +// Values are chosen to exercise the widening edges: channel spans past i16::MAX, big includes +// u64 values past i64::MAX (must stay lossless via Decimal(20,0)), and event_time is nanoseconds +// that must rescale to microseconds rather than be relabeled. +const TYPE_ROWS: [Row; 3] = [ + Row { + id: 1, + name: "alice", + payload: &[0x01, 0x02], + attrs: &[("a", "1")], + channel: 100, + big: u64::MAX, + event_time: 1_600_000_000_000_000_000, // 2020-09-13 + score: 1.5, + tags: &[1, 2], + }, + Row { + id: 2, + name: "bob", + payload: &[], + attrs: &[], + channel: 40_000, + big: 0, + event_time: 1_610_000_000_000_000_000, // 2021-01-07 + score: 2.5, + tags: &[], + }, + Row { + id: 3, + name: "carol", + payload: &[0xff, 0xfe], + attrs: &[("b", "2"), ("c", "3")], + channel: 65_535, + big: 9_223_372_036_854_775_808, // 2^63, past i64::MAX + event_time: 1_620_000_000_000_000_000, // 2021-05-03 + score: 3.5, + tags: &[3], + }, +]; + +/// A provider whose schema exercises every Arrow type the Spark connector's converter handles. +#[derive(Debug)] +pub struct TypesTableProvider { + inner: MemTable, +} + +impl TypesTableProvider { + pub const TABLE_NAME: &'static str = "types"; + + pub fn new() -> Self { + // Build the schema once from single-row prototype arrays so the declared types (in + // particular the nested list/struct types) always match the produced data exactly. + let proto = row_columns(&TYPE_ROWS[0]); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + Field::new("payload", proto[2].data_type().clone(), true), + Field::new("attrs", proto[3].data_type().clone(), true), + Field::new("channel", proto[4].data_type().clone(), true), + Field::new("big", proto[5].data_type().clone(), true), + Field::new("event_time", proto[6].data_type().clone(), true), + Field::new("score", proto[7].data_type().clone(), true), + Field::new("tags", proto[8].data_type().clone(), true), + ])); + + // One self-contained single-row batch per partition (see the module docs on why we do + // not slice a shared batch). + let partitions: Vec> = TYPE_ROWS + .iter() + .map(|row| { + let batch = RecordBatch::try_new(schema.clone(), row_columns(row)) + .expect("types record batch"); + vec![batch] + }) + .collect(); + let inner = MemTable::try_new(schema, partitions).expect("types in-memory table"); + Self { inner } + } +} + +impl Default for TypesTableProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TableProvider for TypesTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.inner.schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + self.inner.scan(state, projection, filters, limit).await + } +} + +/// Build the nine single-row column arrays for one [`Row`], in schema order. +fn row_columns(row: &Row) -> Vec { + let mut tags = ListBuilder::new(UInt16Builder::new()); + for t in row.tags { + tags.values().append_value(*t); + } + tags.append(true); + + let attr_fields = vec![ + Field::new("key", DataType::Utf8, true), + Field::new("val", DataType::Utf8, true), + ]; + let mut attrs = ListBuilder::new(StructBuilder::from_fields(attr_fields, 0)); + for (key, val) in row.attrs { + let s = attrs.values(); + s.field_builder::(0) + .unwrap() + .append_value(key); + s.field_builder::(1) + .unwrap() + .append_value(val); + s.append(true); + } + attrs.append(true); + + vec![ + Arc::new(Int64Array::from(vec![row.id])), + Arc::new(StringArray::from(vec![row.name])), + Arc::new(BinaryArray::from_iter_values([row.payload])), + Arc::new(attrs.finish()) as ArrayRef, + Arc::new(UInt16Array::from(vec![row.channel])), + Arc::new(UInt64Array::from(vec![row.big])), + Arc::new(TimestampNanosecondArray::from(vec![row.event_time])), + Arc::new(Float16Array::from(vec![f16::from_f32(row.score)])), + Arc::new(tags.finish()) as ArrayRef, + ] +} diff --git a/examples/adbc-datafusion-driver/tests/partitions.rs b/examples/adbc-datafusion-driver/tests/partitions.rs index 747e515..c76c2e6 100644 --- a/examples/adbc-datafusion-driver/tests/partitions.rs +++ b/examples/adbc-datafusion-driver/tests/partitions.rs @@ -35,8 +35,11 @@ fn execute_partitions_then_read_each_partition() { // Plan the scan; the provider has three partitions. let mut stmt = conn.new_statement().expect("new_statement"); - stmt.set_sql_query(format!("SELECT id, name FROM {}", ExampleTableProvider::TABLE_NAME)) - .expect("set_sql_query"); + stmt.set_sql_query(format!( + "SELECT id, name FROM {}", + ExampleTableProvider::TABLE_NAME + )) + .expect("set_sql_query"); let result = stmt.execute_partitions().expect("execute_partitions"); assert!( @@ -64,5 +67,9 @@ fn execute_partitions_then_read_each_partition() { } ids.sort(); - assert_eq!(ids, vec![1, 2, 3], "every row read exactly once across partitions"); + assert_eq!( + ids, + vec![1, 2, 3], + "every row read exactly once across partitions" + ); } diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java index b60e6ac..b18fa80 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.List; import java.util.OptionalLong; import java.util.Set; @@ -31,6 +32,7 @@ import org.apache.spark.sql.connector.read.SupportsPushDownLimit; import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; /** @@ -68,7 +70,15 @@ public void pruneColumns(StructType requiredSchema) { @Override public Filter[] pushFilters(Filter[] filters) { - Set columns = Set.of(fullSchema.fieldNames()); + // Push filters only against Spark-native (non-cast) columns: a cast column is cast in the + // scan projection, but a pushed predicate runs against the pre-cast source column, so its + // Spark-domain literal would not match. Those filters are left to Spark (see SchemaConverter). + Set columns = new LinkedHashSet<>(Arrays.asList(fullSchema.fieldNames())); + for (StructField field : fullSchema.fields()) { + if (field.metadata().contains(SchemaConverter.CAST_METADATA_KEY)) { + columns.remove(field.name()); + } + } List pushable = new ArrayList<>(); List residual = new ArrayList<>(); for (Filter f : filters) { diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java index 6c4a4cf..7f2596d 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java @@ -107,9 +107,31 @@ public InputPartition[] planInputPartitions() { applyTargetPartitions(conn, targetPartitions); Schema arrow = conn.getTableSchema(null, null, options.table()); - byte[] substrait = - SubstraitPlan.build(options.table(), arrow, projection, pushedFilters, limit); - String sql = SqlQuery.build(options.table(), projection, pushedFilters, limit); + + List columns = + SchemaConverter.projectionColumns(arrow, projection); + boolean anyCast = columns.stream().anyMatch(c -> c.castType() != null); + // SELECT * only when no columns are projected away and none need a cast; otherwise the + // columns must be listed so the casts can be injected. + List sqlColumns = + (projection == null && !anyCast) ? null : columns; + String sql = SqlQuery.build(options.table(), sqlColumns, pushedFilters, limit); + + // The casts (unsigned, Float16, non-µs timestamps, time) live only in the SQL projection, so + // any schema needing one must use the SQL wire. The gate is the full schema, not just the + // projection: the Substrait NamedScan declares every field's type, so an unprojected cast + // column would still misdeclare the base schema. Substrait also can't encode several other + // Spark-native Arrow types (binary, nested, decimal, ...); build() throws for those, which we + // likewise treat as "not Substrait-representable" and fall back. + boolean schemaNeedsCast = arrow.getFields().stream().anyMatch(SchemaConverter::needsCast); + byte[] substrait = null; + if (!schemaNeedsCast) { + try { + substrait = SubstraitPlan.build(options.table(), arrow, projection, pushedFilters, limit); + } catch (RuntimeException e) { + substrait = null; + } + } return plan(conn, substrait, sql); } catch (Exception e) { throw new RuntimeException("failed to plan ADBC scan for table " + options.table(), e); @@ -139,9 +161,11 @@ private InputPartition[] plan(AdbcConnection conn, byte[] substrait, String sql) Kind singleKind; byte[] singlePayload; - // Escape hatch: force the SQL wire (e.g. when the Substrait round-trip plans - // to fewer partitions than SQL). Defaults to preferring Substrait. - boolean forceSql = "sql".equalsIgnoreCase(System.getProperty("adbc.wire", "")); + // Force the SQL wire when Substrait can't encode this scan (null plan), or via the escape + // hatch (e.g. when the Substrait round-trip plans to fewer partitions than SQL). Defaults to + // preferring Substrait. + boolean forceSql = + substrait == null || "sql".equalsIgnoreCase(System.getProperty("adbc.wire", "")); AdbcStatement stmt = conn.createStatement(); try { diff --git a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java index d299d79..6f2ce51 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java +++ b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java @@ -19,48 +19,130 @@ package org.apache.datafusion.spark; +import java.util.ArrayList; +import java.util.List; + import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; /** - * Converts an Arrow schema (produced by the ADBC scan) into a Spark {@link StructType}. + * Converts an Arrow schema (produced by the ADBC scan) into a Spark {@link StructType}, and plans + * the source-side casts that make the scan emit Spark-native Arrow. * *

Done directly rather than through Spark's {@code ArrowUtils} so the connector depends only on - * our Arrow version, never Spark's bundled one. Covers the primitive types the columnar reader - * produces; unsupported types fail fast. + * our Arrow version, never Spark's bundled one. + * + *

Spark's vectorized {@code ArrowColumnVector} reads a fixed set of Arrow layouts: signed ints, + * 32/64-bit floats, microsecond timestamps, string/binary, decimal, date, and nested + * list/struct/map of those. Two categories of source type need handling: + * + *

    + *
  • Directly representable -- the layout already matches; only the type mapping was + * missing (binary, nested list/struct/map, date, decimal, µs timestamp, null). These pass + * through untouched. + *
  • Cast required -- the layout differs from what {@code ArrowColumnVector} expects, so + * the scan must cast at the source (unsigned ints, Float16, non-µs timestamps, time). We map + * these to the Spark type they will be cast to, and {@link #castTargetString} names + * the Arrow target so {@link SqlQuery} can wrap the column in {@code arrow_cast}. + *
+ * + *

The cast is pushed into the scan (see {@link SqlQuery}), so this converter and the reader only + * ever agree on Spark-native types: the reported Spark type of a column always equals what {@code + * ArrowColumnVector} produces from the (possibly cast) Arrow output. */ final class SchemaConverter { + /** + * Metadata flag set on a top-level Spark field whose source column needs a cast. Read by {@link + * AdbcScanBuilder} to keep filter pushdown off these columns: a pushed predicate runs against the + * pre-cast source column, so its literal would be in the wrong (source, not Spark) domain. + */ + static final String CAST_METADATA_KEY = "org.apache.datafusion.spark.adbc.cast"; + private SchemaConverter() {} + /** A projected output column: its name and, when a cast is required, the Arrow target type. */ + record ProjectionColumn(String name, String castType) {} + static StructType toSparkSchema(Schema arrowSchema) { StructType struct = new StructType(); for (Field field : arrowSchema.getFields()) { - struct = struct.add(field.getName(), toSparkType(field), field.isNullable()); + Metadata metadata = + needsCast(field) + ? new MetadataBuilder().putBoolean(CAST_METADATA_KEY, true).build() + : Metadata.empty(); + struct = struct.add(field.getName(), toSparkType(field), field.isNullable(), metadata); } return struct; } + /** + * Plan the projected columns for the pushed scan: the kept columns in output order, each with the + * Arrow cast target (or {@code null} to pass through). + * + * @param schema the full source Arrow schema + * @param projection kept column names in output order, or {@code null} for all columns + */ + static List projectionColumns(Schema schema, List projection) { + List fields = new ArrayList<>(); + if (projection == null) { + fields.addAll(schema.getFields()); + } else { + for (String name : projection) { + fields.add(schema.findField(name)); + } + } + List columns = new ArrayList<>(fields.size()); + for (Field field : fields) { + columns.add(new ProjectionColumn(field.getName(), castTargetString(field))); + } + return columns; + } + + /** The Arrow type string for {@code arrow_cast}, or {@code null} if the column needs no cast. */ + static String castTargetString(Field field) { + return needsCast(field) ? renderArrowType(field) : null; + } + + // --- Arrow type -> Spark type --------------------------------------------- + static DataType toSparkType(Field field) { ArrowType type = field.getType(); + if (type instanceof ArrowType.Bool) { + return DataTypes.BooleanType; + } if (type instanceof ArrowType.Int i) { - if (!i.getIsSigned()) { - throw unsupported(field); + if (i.getIsSigned()) { + return switch (i.getBitWidth()) { + case 8 -> DataTypes.ByteType; + case 16 -> DataTypes.ShortType; + case 32 -> DataTypes.IntegerType; + case 64 -> DataTypes.LongType; + default -> throw unsupported(field); + }; } + // Unsigned: widened to the next signed width it will be cast to (u64 has no lossless + // signed 64-bit target, so it becomes Decimal(20,0)). return switch (i.getBitWidth()) { - case 8 -> DataTypes.ByteType; - case 16 -> DataTypes.ShortType; - case 32 -> DataTypes.IntegerType; - case 64 -> DataTypes.LongType; + case 8 -> DataTypes.ShortType; + case 16 -> DataTypes.IntegerType; + case 32 -> DataTypes.LongType; + case 64 -> DataTypes.createDecimalType(20, 0); default -> throw unsupported(field); }; } if (type instanceof ArrowType.FloatingPoint fp) { + // Float16 has no Spark type; it is widened to Float. return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? DataTypes.DoubleType : DataTypes.FloatType; @@ -68,12 +150,245 @@ static DataType toSparkType(Field field) { if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { return DataTypes.StringType; } + if (type instanceof ArrowType.Binary + || type instanceof ArrowType.LargeBinary + || type instanceof ArrowType.FixedSizeBinary) { + return DataTypes.BinaryType; + } + if (type instanceof ArrowType.Date) { + return DataTypes.DateType; + } + if (type instanceof ArrowType.Timestamp ts) { + // Unit is normalized to microseconds by the cast; the timezone decides NTZ vs zoned. + return ts.getTimezone() == null ? DataTypes.TimestampNTZType : DataTypes.TimestampType; + } + if (type instanceof ArrowType.Time t) { + // Spark has no time-of-day accessor; the value is cast to its raw integer of ticks. + return t.getBitWidth() == 32 ? DataTypes.IntegerType : DataTypes.LongType; + } + if (type instanceof ArrowType.Decimal d) { + return DataTypes.createDecimalType(d.getPrecision(), d.getScale()); + } + if (type instanceof ArrowType.Null) { + return DataTypes.NullType; + } + if (type instanceof ArrowType.Duration) { + return DataTypes.createDayTimeIntervalType(); + } + if (type instanceof ArrowType.Interval iv) { + return switch (iv.getUnit()) { + case YEAR_MONTH -> DataTypes.createYearMonthIntervalType(); + case DAY_TIME -> DataTypes.createDayTimeIntervalType(); + default -> throw unsupported(field); + }; + } + if (type instanceof ArrowType.List + || type instanceof ArrowType.LargeList + || type instanceof ArrowType.FixedSizeList) { + Field element = field.getChildren().get(0); + return DataTypes.createArrayType(toSparkType(element), element.isNullable()); + } + if (type instanceof ArrowType.Struct) { + List children = new ArrayList<>(); + for (Field child : field.getChildren()) { + children.add( + DataTypes.createStructField(child.getName(), toSparkType(child), child.isNullable())); + } + return DataTypes.createStructType(children); + } + if (type instanceof ArrowType.Map) { + Field entries = field.getChildren().get(0); + Field key = entries.getChildren().get(0); + Field value = entries.getChildren().get(1); + return DataTypes.createMapType(toSparkType(key), toSparkType(value), value.isNullable()); + } + throw unsupported(field); + } + + // --- Cast planning -------------------------------------------------------- + + /** + * Whether the column (recursively) has any layout that Spark's reader cannot consume directly. + */ + static boolean needsCast(Field field) { + ArrowType type = field.getType(); + if (type instanceof ArrowType.Int i && !i.getIsSigned()) { + return true; + } + if (type instanceof ArrowType.FloatingPoint fp + && fp.getPrecision() == FloatingPointPrecision.HALF) { + return true; + } + if (type instanceof ArrowType.Timestamp ts && ts.getUnit() != TimeUnit.MICROSECOND) { + return true; + } + if (type instanceof ArrowType.Time) { + return true; + } + for (Field child : field.getChildren()) { + if (needsCast(child)) { + return true; + } + } + return false; + } + + /** + * Render the widened Arrow type as an {@code arrow_cast} type string (the reversible {@code + * arrow::datatypes::DataType} display form that DataFusion's {@code arrow_cast} parses). Cast + * leaves become their widened target; everything else is rendered as-is so a nested cast carries + * its unchanged siblings along. + */ + private static String renderArrowType(Field field) { + ArrowType type = field.getType(); if (type instanceof ArrowType.Bool) { - return DataTypes.BooleanType; + return "Boolean"; + } + if (type instanceof ArrowType.Int i) { + if (i.getIsSigned()) { + return "Int" + i.getBitWidth(); + } + return switch (i.getBitWidth()) { + case 8 -> "Int16"; + case 16 -> "Int32"; + case 32 -> "Int64"; + case 64 -> "Decimal128(20, 0)"; + default -> throw unsupported(field); + }; + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? "Float64" : "Float32"; + } + if (type instanceof ArrowType.Utf8) { + return "Utf8"; + } + if (type instanceof ArrowType.LargeUtf8) { + return "LargeUtf8"; + } + if (type instanceof ArrowType.Binary) { + return "Binary"; + } + if (type instanceof ArrowType.LargeBinary) { + return "LargeBinary"; + } + if (type instanceof ArrowType.FixedSizeBinary fb) { + return "FixedSizeBinary(" + fb.getByteWidth() + ")"; + } + if (type instanceof ArrowType.Date d) { + return switch (d.getUnit()) { + case DAY -> "Date32"; + case MILLISECOND -> "Date64"; + }; + } + if (type instanceof ArrowType.Timestamp ts) { + // Normalize to microseconds, preserving the timezone. + return ts.getTimezone() == null + ? "Timestamp(Microsecond)" + : "Timestamp(Microsecond, \"" + ts.getTimezone() + "\")"; + } + if (type instanceof ArrowType.Time t) { + return t.getBitWidth() == 32 ? "Int32" : "Int64"; + } + if (type instanceof ArrowType.Decimal d) { + String kind = d.getBitWidth() == 256 ? "Decimal256" : "Decimal128"; + return kind + "(" + d.getPrecision() + ", " + d.getScale() + ")"; + } + if (type instanceof ArrowType.Duration dur) { + return "Duration(" + timeUnitName(dur.getUnit()) + ")"; + } + if (type instanceof ArrowType.Interval iv) { + return "Interval(" + intervalUnitName(iv.getUnit()) + ")"; + } + if (type instanceof ArrowType.Null) { + return "Null"; + } + if (type instanceof ArrowType.List) { + return "List(" + listChild(field.getChildren().get(0)) + ")"; + } + if (type instanceof ArrowType.LargeList) { + return "LargeList(" + listChild(field.getChildren().get(0)) + ")"; + } + if (type instanceof ArrowType.FixedSizeList fsl) { + return "FixedSizeList(" + + fsl.getListSize() + + " x " + + listChild(field.getChildren().get(0)) + + ")"; + } + if (type instanceof ArrowType.Struct) { + StringBuilder sb = new StringBuilder("Struct("); + List children = field.getChildren(); + for (int i = 0; i < children.size(); i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(structField(children.get(i))); + } + return sb.append(")").toString(); + } + if (type instanceof ArrowType.Map m) { + Field entries = field.getChildren().get(0); + return "Map(" + + structField(entries) + + ", " + + (m.getKeysSorted() ? "sorted" : "unsorted") + + ")"; } throw unsupported(field); } + /** {@code [, field: 'name']} -- the list/fixed-size-list child form. */ + private static String listChild(Field field) { + String rendered = nullability(field) + renderArrowType(field); + // The default list-field name ("item") is elided by the display form. + return "item".equals(field.getName()) + ? rendered + : rendered + ", field: '" + field.getName() + "'"; + } + + /** {@code "name": } -- the struct/map field form. */ + private static String structField(Field field) { + return debugQuote(field.getName()) + ": " + nullability(field) + renderArrowType(field); + } + + private static String nullability(Field field) { + return field.isNullable() ? "" : "non-null "; + } + + private static String timeUnitName(TimeUnit unit) { + return switch (unit) { + case SECOND -> "Second"; + case MILLISECOND -> "Millisecond"; + case MICROSECOND -> "Microsecond"; + case NANOSECOND -> "Nanosecond"; + }; + } + + private static String intervalUnitName(IntervalUnit unit) { + return switch (unit) { + case YEAR_MONTH -> "YearMonth"; + case DAY_TIME -> "DayTime"; + case MONTH_DAY_NANO -> "MonthDayNano"; + }; + } + + /** Reproduce Rust's {@code {:?}} string quoting used by the Arrow display form. */ + private static String debugQuote(String s) { + StringBuilder sb = new StringBuilder("\""); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"' -> sb.append("\\\""); + case '\\' -> sb.append("\\\\"); + case '\n' -> sb.append("\\n"); + case '\r' -> sb.append("\\r"); + case '\t' -> sb.append("\\t"); + default -> sb.append(c); + } + } + return sb.append("\"").toString(); + } + private static IllegalArgumentException unsupported(Field field) { return new IllegalArgumentException( "unsupported Arrow type for column '" + field.getName() + "': " + field.getType()); diff --git a/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java b/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java index 07cfb0d..92879c5 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java +++ b/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java @@ -23,6 +23,7 @@ import java.util.OptionalLong; import java.util.stream.Collectors; +import org.apache.datafusion.spark.SchemaConverter.ProjectionColumn; import org.apache.spark.sql.sources.And; import org.apache.spark.sql.sources.EqualTo; import org.apache.spark.sql.sources.Filter; @@ -46,18 +47,22 @@ * *

Identifiers are double-quoted (ANSI, DataFusion's default) and string literals single-quoted, * both with doubling-based escaping. + * + *

Columns whose source Arrow type is not Spark-native (see {@link SchemaConverter#needsCast}) + * are wrapped in {@code arrow_cast(col, '')} and re-aliased to their original name, so + * the scan emits Spark-native Arrow and the output column names still match the reported schema. */ final class SqlQuery { private SqlQuery() {} static String build( - String table, List projection, List filters, OptionalLong limit) { + String table, List columns, List filters, OptionalLong limit) { StringBuilder sql = new StringBuilder("SELECT "); - if (projection == null || projection.isEmpty()) { + if (columns == null || columns.isEmpty()) { sql.append("*"); } else { - sql.append(projection.stream().map(SqlQuery::quoteId).collect(Collectors.joining(", "))); + sql.append(columns.stream().map(SqlQuery::column).collect(Collectors.joining(", "))); } sql.append(" FROM ").append(quoteId(table)); if (!filters.isEmpty()) { @@ -70,6 +75,14 @@ static String build( return sql.toString(); } + private static String column(ProjectionColumn c) { + if (c.castType() == null) { + return quoteId(c.name()); + } + // arrow_cast renames its output, so alias back to the source name. + return "arrow_cast(" + quoteId(c.name()) + ", '" + c.castType() + "') AS " + quoteId(c.name()); + } + private static String predicate(Filter f) { if (f instanceof EqualTo e) { return binary(e.attribute(), "=", e.value()); diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java index b94d136..8fffd09 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java @@ -19,20 +19,28 @@ package org.apache.datafusion.spark; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import java.math.BigDecimal; import java.nio.file.Files; import java.nio.file.Path; +import java.time.LocalDateTime; import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import org.apache.arrow.adbc.core.TypedKey; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -57,6 +65,8 @@ class AdbcSourceTest { private static final String ENTRYPOINT = "AdbcDatafusionExampleInit"; private static final String TABLE = "example"; + // A second table whose schema spans the Arrow types the connector maps or casts. + private static final String TYPES_TABLE = "types"; private static SparkSession spark; private static String driverPath; @@ -80,15 +90,115 @@ static void tearDown() { } private Dataset load() { + return load(TABLE); + } + + private Dataset load(String table) { return spark .read() .format("adbc-datafusion") .option("driver", driverPath) .option("entrypoint", ENTRYPOINT) - .option("table", TABLE) + .option("table", table) .load(); } + /** + * The {@code types} table exercises the full converter: directly-representable columns pass + * through, cast-required columns (unsigned, ns timestamp, Float16, nested {@code List}) + * are cast to a Spark-native layout by a source-side {@code arrow_cast} pushed into the scan. + * This asserts the reported Spark types and the cast-column metadata flags. + */ + @Test + void typesSchemaMapsAndFlagsCasts() { + StructType schema = load(TYPES_TABLE).schema(); + + assertEquals(DataTypes.BinaryType, schema.apply("payload").dataType()); + assertEquals(DataTypes.IntegerType, schema.apply("channel").dataType()); + assertEquals(DataTypes.createDecimalType(20, 0), schema.apply("big").dataType()); + assertEquals(DataTypes.TimestampNTZType, schema.apply("event_time").dataType()); + assertEquals(DataTypes.FloatType, schema.apply("score").dataType()); + assertEquals( + DataTypes.createArrayType(DataTypes.IntegerType, true), schema.apply("tags").dataType()); + + // Cast columns are flagged (so filter pushdown stays off them); pass-through columns are not. + assertTrue(schema.apply("channel").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("big").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("event_time").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("score").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("tags").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertFalse(schema.apply("payload").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertFalse(schema.apply("attrs").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + } + + /** + * The cast-requiring scan still parallelizes. Casts force the SQL wire (Substrait can't encode + * unsigned/Float16), but ADBC {@code executePartitioned} is wire-independent -- it partitions on + * the physical plan's output partitioning, and the source-side {@code arrow_cast} is a + * partition-preserving projection -- so the {@code types} scan gets one Spark partition per + * driver partition, exactly like the Substrait-wire {@code example} scan. + */ + @Test + void typesScanIsMultiPartitionDespiteCasts() { + assertTrue( + load(TYPES_TABLE).rdd().getNumPartitions() >= 2, + "expected multiple Spark partitions for the cast (SQL-wire) scan, got " + + load(TYPES_TABLE).rdd().getNumPartitions()); + } + + /** Every column decodes to the expected Spark-native value -- casts are value-correct. */ + @Test + void typesValuesRoundTripThroughCasts() { + Map byId = + load(TYPES_TABLE).collectAsList().stream() + .collect(Collectors.toMap(r -> r.getLong(r.fieldIndex("id")), r -> r)); + assertEquals(Set.of(1L, 2L, 3L), byId.keySet()); + Row r1 = byId.get(1L); + Row r2 = byId.get(2L); + Row r3 = byId.get(3L); + + // unsigned UInt16 -> Integer, widened past i16::MAX. + assertEquals(100, r1.getInt(r1.fieldIndex("channel"))); + assertEquals(40_000, r2.getInt(r2.fieldIndex("channel"))); + assertEquals(65_535, r3.getInt(r3.fieldIndex("channel"))); + + // unsigned UInt64 -> Decimal(20,0): lossless for values past i64::MAX (a Long would overflow). + assertEquals(0, new BigDecimal("18446744073709551615").compareTo(bigValue(r1))); + assertEquals(0, BigDecimal.ZERO.compareTo(bigValue(r2))); + assertEquals(0, new BigDecimal("9223372036854775808").compareTo(bigValue(r3))); + + // nanosecond Timestamp -> microsecond TimestampNTZ, rescaled (a relabel would land near 1970). + assertEquals(LocalDateTime.of(2020, 9, 13, 12, 26, 40), r1.getAs("event_time")); + assertEquals(LocalDateTime.of(2021, 1, 7, 6, 13, 20), r2.getAs("event_time")); + assertEquals(LocalDateTime.of(2021, 5, 3, 0, 0, 0), r3.getAs("event_time")); + + // Float16 -> Float. + assertEquals(1.5f, r1.getFloat(r1.fieldIndex("score"))); + assertEquals(2.5f, r2.getFloat(r2.fieldIndex("score"))); + + // Binary passes through. + assertArrayEquals(new byte[] {0x01, 0x02}, (byte[]) r1.getAs("payload")); + assertArrayEquals(new byte[] {}, (byte[]) r2.getAs("payload")); + assertArrayEquals(new byte[] {(byte) 0xff, (byte) 0xfe}, (byte[]) r3.getAs("payload")); + + // nested List -> Array (recursive cast). + assertEquals(List.of(1, 2), r1.getList(r1.fieldIndex("tags"))); + assertEquals(List.of(), r2.getList(r2.fieldIndex("tags"))); + assertEquals(List.of(3), r3.getList(r3.fieldIndex("tags"))); + + // nested List> passes through. + List attrs3 = r3.getList(r3.fieldIndex("attrs")); + assertEquals(2, attrs3.size()); + assertEquals("b", attrs3.get(0).getAs("key")); + assertEquals("2", attrs3.get(0).getAs("val")); + assertEquals("c", attrs3.get(1).getAs("key")); + assertEquals(List.of(), r2.getList(r2.fieldIndex("attrs"))); + } + + private static BigDecimal bigValue(Row row) { + return row.getDecimal(row.fieldIndex("big")); + } + @Test void fullScanAcrossPartitions() { Dataset df = load(); diff --git a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java new file mode 100644 index 0000000..e092704 --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.datafusion.spark.SchemaConverter.ProjectionColumn; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for the Arrow -> Spark schema mapping and the {@code arrow_cast} target planning: one + * case per representable Arrow type, plus the cast-requiring types (unsigned, Float16, non-µs + * timestamps) and nested forms. + */ +class SchemaConverterTest { + + private static Field nullable(String name, ArrowType type) { + return Field.nullable(name, type); + } + + private static Field nullable(String name, ArrowType type, Field... children) { + return new Field(name, FieldType.nullable(type), List.of(children)); + } + + private static DataType spark(Field field) { + return SchemaConverter.toSparkType(field); + } + + // --- directly representable (A): mapped, no cast -------------------------- + + @Test + void primitivesMapWithoutCast() { + assertEquals(DataTypes.ByteType, spark(nullable("c", new ArrowType.Int(8, true)))); + assertEquals(DataTypes.ShortType, spark(nullable("c", new ArrowType.Int(16, true)))); + assertEquals(DataTypes.IntegerType, spark(nullable("c", new ArrowType.Int(32, true)))); + assertEquals(DataTypes.LongType, spark(nullable("c", new ArrowType.Int(64, true)))); + assertEquals( + DataTypes.DoubleType, + spark(nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)))); + assertEquals(DataTypes.StringType, spark(nullable("c", ArrowType.Utf8.INSTANCE))); + assertEquals(DataTypes.BooleanType, spark(nullable("c", ArrowType.Bool.INSTANCE))); + + assertNull(SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(32, true)))); + } + + @Test + void binaryMapsToBinary() { + assertEquals(DataTypes.BinaryType, spark(nullable("c", ArrowType.Binary.INSTANCE))); + assertEquals(DataTypes.BinaryType, spark(nullable("c", ArrowType.LargeBinary.INSTANCE))); + assertEquals(DataTypes.BinaryType, spark(nullable("c", new ArrowType.FixedSizeBinary(16)))); + assertNull(SchemaConverter.castTargetString(nullable("c", ArrowType.Binary.INSTANCE))); + } + + @Test + void dateDecimalNull() { + assertEquals(DataTypes.DateType, spark(nullable("c", new ArrowType.Date(DateUnit.DAY)))); + assertEquals( + DataTypes.createDecimalType(10, 2), + spark(nullable("c", new ArrowType.Decimal(10, 2, 128)))); + assertEquals(DataTypes.NullType, spark(nullable("c", ArrowType.Null.INSTANCE))); + } + + @Test + void microsecondTimestampNeedsNoCast() { + Field ntz = nullable("c", new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); + assertEquals(DataTypes.TimestampNTZType, spark(ntz)); + assertNull(SchemaConverter.castTargetString(ntz)); + + Field zoned = nullable("c", new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")); + assertEquals(DataTypes.TimestampType, spark(zoned)); + assertNull(SchemaConverter.castTargetString(zoned)); + } + + @Test + void nestedListStructMap() { + Field struct = + nullable( + "s", + ArrowType.Struct.INSTANCE, + nullable("first", ArrowType.Utf8.INSTANCE), + nullable("second", ArrowType.Utf8.INSTANCE)); + Field listOfStruct = nullable("metadata", new ArrowType.List(), struct); + + DataType mapped = spark(listOfStruct); + assertTrue(mapped instanceof org.apache.spark.sql.types.ArrayType); + // A pure-(A) nested column needs no cast. + assertNull(SchemaConverter.castTargetString(listOfStruct)); + } + + // --- cast required (B): mapped to widened type + arrow_cast target -------- + + @Test + void unsignedIntsWidenAndCast() { + assertEquals(DataTypes.ShortType, spark(nullable("c", new ArrowType.Int(8, false)))); + assertEquals(DataTypes.IntegerType, spark(nullable("c", new ArrowType.Int(16, false)))); + assertEquals(DataTypes.LongType, spark(nullable("c", new ArrowType.Int(32, false)))); + assertEquals( + DataTypes.createDecimalType(20, 0), spark(nullable("c", new ArrowType.Int(64, false)))); + + assertEquals( + "Int16", SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(8, false)))); + assertEquals( + "Int32", SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(16, false)))); + assertEquals( + "Int64", SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(32, false)))); + assertEquals( + "Decimal128(20, 0)", + SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(64, false)))); + } + + @Test + void float16WidensToFloat() { + Field f = nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)); + assertEquals(DataTypes.FloatType, spark(f)); + assertEquals("Float32", SchemaConverter.castTargetString(f)); + } + + @Test + void nonMicrosecondTimestampRescales() { + Field ns = nullable("t", new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)); + assertEquals(DataTypes.TimestampNTZType, spark(ns)); + assertEquals("Timestamp(Microsecond)", SchemaConverter.castTargetString(ns)); + + Field zoned = nullable("t", new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")); + assertEquals(DataTypes.TimestampType, spark(zoned)); + assertEquals("Timestamp(Microsecond, \"UTC\")", SchemaConverter.castTargetString(zoned)); + } + + @Test + void nestedListOfUnsignedCastsRecursively() { + Field list = + nullable("ids", new ArrowType.List(), nullable("item", new ArrowType.Int(16, false))); + DataType mapped = spark(list); + assertEquals(DataTypes.createArrayType(DataTypes.IntegerType, true), mapped); + assertEquals("List(Int32)", SchemaConverter.castTargetString(list)); + } + + // --- schema-level: metadata flag + projection planning ------------------- + + @Test + void castColumnsTaggedInSchemaMetadata() { + Schema schema = + new Schema( + List.of( + Field.nullable("id", new ArrowType.Int(64, true)), + Field.nullable("channel", new ArrowType.Int(16, false)), + Field.nullable("ts", new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)))); + StructType spark = SchemaConverter.toSparkSchema(schema); + + assertTrue(spark.apply("id").metadata().isEmpty()); + assertTrue(spark.apply("channel").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(spark.apply("ts").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + } + + @Test + void projectionColumnsCarryCastTargets() { + Schema schema = + new Schema( + List.of( + Field.nullable("id", new ArrowType.Int(64, true)), + Field.nullable("channel", new ArrowType.Int(16, false)))); + + // All columns. + List all = SchemaConverter.projectionColumns(schema, null); + assertEquals(2, all.size()); + assertEquals("id", all.get(0).name()); + assertNull(all.get(0).castType()); + assertEquals("channel", all.get(1).name()); + assertEquals("Int32", all.get(1).castType()); + + // Projected subset, order preserved. + List pruned = SchemaConverter.projectionColumns(schema, List.of("channel")); + assertEquals(1, pruned.size()); + assertEquals("channel", pruned.get(0).name()); + assertEquals("Int32", pruned.get(0).castType()); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java b/spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java new file mode 100644 index 0000000..e9b6aa9 --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import java.util.OptionalLong; + +import org.apache.datafusion.spark.SchemaConverter.ProjectionColumn; +import org.apache.spark.sql.sources.Filter; +import org.junit.jupiter.api.Test; + +/** Verifies the SQL fallback wire, including {@code arrow_cast} injection for cast columns. */ +class SqlQueryTest { + + private static final List NO_FILTERS = List.of(); + + @Test + void selectStarWhenNoColumns() { + assertEquals( + "SELECT * FROM \"t\"", SqlQuery.build("t", null, NO_FILTERS, OptionalLong.empty())); + } + + @Test + void castColumnsWrappedInArrowCastAndAliased() { + List columns = + List.of( + new ProjectionColumn("id", null), + new ProjectionColumn("channel", "Int32"), + new ProjectionColumn("ts", "Timestamp(Microsecond)")); + assertEquals( + "SELECT \"id\", arrow_cast(\"channel\", 'Int32') AS \"channel\", " + + "arrow_cast(\"ts\", 'Timestamp(Microsecond)') AS \"ts\" FROM \"t\"", + SqlQuery.build("t", columns, NO_FILTERS, OptionalLong.empty())); + } +} From 3ecb9715d2bdd9f48f83b0bf48804dbf57ff7f40 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 1 Jul 2026 15:57:27 -0400 Subject: [PATCH 3/5] fix(spark): apply casts on count()/empty-projection scans count() prunes the pushed projection to empty, which rendered a bare SELECT * -- returning the raw, uncast schema. A non-Spark-native column (e.g. Timestamp(NANOSECOND), unsigned) then reached the reader and failed with UNSUPPORTED_ARROWTYPE. Only the empty-projection case leaked; the all-columns (null projection) path already injected casts. When the projection prunes to empty and the table has any cast column, emit a single readable probe column instead of SELECT * (SchemaConverter.probeColumn: prefer a castless column, else the first column with its cast applied). A count/column-less scan only needs the row count, and the emitted stream stays Spark-native. All-native tables keep plain SELECT *. Tests: SchemaConverterTest.probeColumn* (unit) and AdbcSourceTest.typesCountWorks* (E2E, reproduces the failure); the example pyspark harness gains a types-table count() assertion. Co-Authored-By: Claude Opus 4.8 --- .../adbc-datafusion-driver/pyspark_e2e.py | 6 +++++ .../apache/datafusion/spark/AdbcScanImpl.java | 22 ++++++++++++------ .../datafusion/spark/SchemaConverter.java | 17 ++++++++++++++ .../datafusion/spark/AdbcSourceTest.java | 11 +++++++++ .../datafusion/spark/SchemaConverterTest.java | 23 +++++++++++++++++++ 5 files changed, 72 insertions(+), 7 deletions(-) diff --git a/examples/adbc-datafusion-driver/pyspark_e2e.py b/examples/adbc-datafusion-driver/pyspark_e2e.py index edd5405..9c20ecf 100644 --- a/examples/adbc-datafusion-driver/pyspark_e2e.py +++ b/examples/adbc-datafusion-driver/pyspark_e2e.py @@ -120,6 +120,12 @@ def read(table): print("cast columns:", sorted(cast_cols)) assert cast_cols == {"channel", "big", "event_time", "score", "tags"}, cast_cols + # count() prunes the projection to empty; the connector must not emit a bare SELECT * (which + # would return the raw, uncast schema and fail the reader on the ns timestamp / unsigned ids). + types_count = dft.count() + print("types count:", types_count) + assert types_count == 3, types_count + # value correctness: collecting whole rows forces the vectorized reader to decode each # ArrowColumnVector, so a bad cast (wrong layout, unit relabel instead of rescale, unsigned # overflow) fails here. diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java index 7f2596d..e2f589a 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java @@ -108,8 +108,23 @@ public InputPartition[] planInputPartitions() { Schema arrow = conn.getTableSchema(null, null, options.table()); + // The casts (unsigned, Float16, non-µs timestamps, time) live only in the SQL projection, so + // any schema needing one must use the SQL wire. The gate is the full schema, not just the + // projection: the Substrait NamedScan declares every field's type, so an unprojected cast + // column would still misdeclare the base schema. Substrait also can't encode several other + // Spark-native Arrow types (binary, nested, decimal, ...); build() throws for those, which we + // likewise treat as "not Substrait-representable" and fall back. + boolean schemaNeedsCast = arrow.getFields().stream().anyMatch(SchemaConverter::needsCast); + List columns = SchemaConverter.projectionColumns(arrow, projection); + // count() (and other column-less reads) prunes the projection to empty. A bare SELECT * + // would then return the raw, uncast schema and the reader would fail on a non-Spark-native + // column, so when the table has any cast column, emit a single readable probe column + // instead -- the row count is all such a scan needs. + if (columns.isEmpty() && schemaNeedsCast) { + columns = List.of(SchemaConverter.probeColumn(arrow)); + } boolean anyCast = columns.stream().anyMatch(c -> c.castType() != null); // SELECT * only when no columns are projected away and none need a cast; otherwise the // columns must be listed so the casts can be injected. @@ -117,13 +132,6 @@ public InputPartition[] planInputPartitions() { (projection == null && !anyCast) ? null : columns; String sql = SqlQuery.build(options.table(), sqlColumns, pushedFilters, limit); - // The casts (unsigned, Float16, non-µs timestamps, time) live only in the SQL projection, so - // any schema needing one must use the SQL wire. The gate is the full schema, not just the - // projection: the Substrait NamedScan declares every field's type, so an unprojected cast - // column would still misdeclare the base schema. Substrait also can't encode several other - // Spark-native Arrow types (binary, nested, decimal, ...); build() throws for those, which we - // likewise treat as "not Substrait-representable" and fall back. - boolean schemaNeedsCast = arrow.getFields().stream().anyMatch(SchemaConverter::needsCast); byte[] substrait = null; if (!schemaNeedsCast) { try { diff --git a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java index 6f2ce51..b8c7357 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java +++ b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java @@ -114,6 +114,23 @@ static String castTargetString(Field field) { return needsCast(field) ? renderArrowType(field) : null; } + /** + * A single Spark-readable column for a column-less scan (e.g. {@code count()}, whose projection + * Catalyst prunes to empty). Such a scan only needs a row count, but the emitted stream must + * still be Spark-native -- a bare {@code SELECT *} would return the raw, uncast schema and the + * reader would fail on the first non-Spark-native column. Prefers a column that needs no cast + * (cheapest); falls back to the first column with its cast applied when every column needs one. + */ + static ProjectionColumn probeColumn(Schema schema) { + for (Field field : schema.getFields()) { + if (!needsCast(field)) { + return new ProjectionColumn(field.getName(), null); + } + } + Field first = schema.getFields().get(0); + return new ProjectionColumn(first.getName(), castTargetString(first)); + } + // --- Arrow type -> Spark type --------------------------------------------- static DataType toSparkType(Field field) { diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java index 8fffd09..c055739 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java @@ -146,6 +146,17 @@ void typesScanIsMultiPartitionDespiteCasts() { + load(TYPES_TABLE).rdd().getNumPartitions()); } + /** + * count() over a cast-requiring table. Catalyst prunes the projection to empty for a count, and a + * bare {@code SELECT *} would return the raw (uncast) schema -- the reader would then reject a + * non-Spark-native column (e.g. {@code Timestamp(NANOSECOND)}). The connector emits a single + * readable probe column instead, so the row count still comes back. + */ + @Test + void typesCountWorksWithoutMaterializingCastColumns() { + assertEquals(3L, load(TYPES_TABLE).count()); + } + /** Every column decodes to the expected Spark-native value -- casts are value-correct. */ @Test void typesValuesRoundTripThroughCasts() { diff --git a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java index e092704..c6f23f9 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java @@ -183,6 +183,29 @@ void castColumnsTaggedInSchemaMetadata() { assertTrue(spark.apply("ts").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); } + @Test + void probeColumnPrefersCastlessThenFallsBackToCast() { + // A castless column is preferred (cheapest, no cast) for a column-less scan. + Schema mixed = + new Schema( + List.of( + Field.nullable("channel", new ArrowType.Int(16, false)), + Field.nullable("id", new ArrowType.Int(64, true)))); + ProjectionColumn probe = SchemaConverter.probeColumn(mixed); + assertEquals("id", probe.name()); + assertNull(probe.castType()); + + // When every column needs a cast, the first column is used with its cast applied. + Schema allCast = + new Schema( + List.of( + Field.nullable("channel", new ArrowType.Int(16, false)), + Field.nullable("ts", new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)))); + ProjectionColumn casted = SchemaConverter.probeColumn(allCast); + assertEquals("channel", casted.name()); + assertEquals("Int32", casted.castType()); + } + @Test void projectionColumnsCarryCastTargets() { Schema schema = From 390691bcc9a1f5633317c1ddb3d919e813fd0f34 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 1 Jul 2026 16:05:43 -0400 Subject: [PATCH 4/5] fix(spark): cast FixedSizeList to a variable List for Spark Spark's ArrowColumnVector backs ArrayType only from a variable ListVector, never a FixedSizeListVector, so a fixed-size list reached the reader and failed with UNSUPPORTED_ARROWTYPE FixedSizeList. SchemaConverter mapped it to ArrayType (schema probe passed) but needsCast did not flag it, so no arrow_cast was emitted; renderArrowType would also have produced an identity FixedSizeList(N) target. - needsCast: flag FixedSizeList unconditionally (the fixed layout is unreadable regardless of element type). - renderArrowType: render FixedSizeList as a variable List() so arrow_cast converts fixed->variable, with the element rendered cast-aware (e.g. FixedSizeList -> List(Float32)). Tests: SchemaConverterTest.fixedSizeListCastsToVariableList (unit); the example `types` table gains a FixedSizeList `vec` column and AdbcSourceTest / the pyspark harness assert it round-trips to Array with the element widened. Co-Authored-By: Claude Opus 4.8 --- .../adbc-datafusion-driver/pyspark_e2e.py | 7 +++++- .../adbc-datafusion-driver/src/provider.rs | 21 +++++++++++++++--- .../datafusion/spark/SchemaConverter.java | 15 ++++++++----- .../datafusion/spark/AdbcSourceTest.java | 9 ++++++++ .../datafusion/spark/SchemaConverterTest.java | 22 +++++++++++++++++++ 5 files changed, 64 insertions(+), 10 deletions(-) diff --git a/examples/adbc-datafusion-driver/pyspark_e2e.py b/examples/adbc-datafusion-driver/pyspark_e2e.py index 9c20ecf..c5cca04 100644 --- a/examples/adbc-datafusion-driver/pyspark_e2e.py +++ b/examples/adbc-datafusion-driver/pyspark_e2e.py @@ -118,7 +118,7 @@ def read(table): # cast columns are flagged (so filter pushdown stays off them); pass-through ones are not. cast_cols = {f.name for f in dft.schema.fields if f.metadata.get(CAST_META_KEY)} print("cast columns:", sorted(cast_cols)) - assert cast_cols == {"channel", "big", "event_time", "score", "tags"}, cast_cols + assert cast_cols == {"channel", "big", "event_time", "score", "tags", "vec"}, cast_cols # count() prunes the projection to empty; the connector must not emit a bare SELECT * (which # would return the raw, uncast schema and fail the reader on the ns timestamp / unsigned ids). @@ -158,6 +158,11 @@ def read(table): assert r2["tags"] == [] assert r3["tags"] == [3] + # FixedSizeList -> Array (fixed->variable + element widening) + assert r1["vec"] == [10, 20] + assert r2["vec"] == [30, 40] + assert r3["vec"] == [50, 60] + # nested List> passes through assert [(x["key"], x["val"]) for x in r1["attrs"]] == [("a", "1")] assert r2["attrs"] == [] diff --git a/examples/adbc-datafusion-driver/src/provider.rs b/examples/adbc-datafusion-driver/src/provider.rs index ff18fe8..854a694 100644 --- a/examples/adbc-datafusion-driver/src/provider.rs +++ b/examples/adbc-datafusion-driver/src/provider.rs @@ -48,8 +48,9 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::array::{ - ArrayRef, BinaryArray, Float16Array, Int64Array, ListBuilder, StringArray, StringBuilder, - StructBuilder, TimestampNanosecondArray, UInt16Array, UInt16Builder, UInt64Array, + ArrayRef, BinaryArray, FixedSizeListBuilder, Float16Array, Int64Array, ListBuilder, + StringArray, StringBuilder, StructBuilder, TimestampNanosecondArray, UInt16Array, + UInt16Builder, UInt64Array, }; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; @@ -145,6 +146,9 @@ struct Row { event_time: i64, score: f32, tags: &'static [u16], + /// A fixed-size list of two unsigned ints -- Spark cannot read FixedSizeList and must have it + /// cast to a variable List (with the element widened). Always length 2. + vec: [u16; 2], } // Values are chosen to exercise the widening edges: channel spans past i16::MAX, big includes @@ -161,6 +165,7 @@ const TYPE_ROWS: [Row; 3] = [ event_time: 1_600_000_000_000_000_000, // 2020-09-13 score: 1.5, tags: &[1, 2], + vec: [10, 20], }, Row { id: 2, @@ -172,6 +177,7 @@ const TYPE_ROWS: [Row; 3] = [ event_time: 1_610_000_000_000_000_000, // 2021-01-07 score: 2.5, tags: &[], + vec: [30, 40], }, Row { id: 3, @@ -183,6 +189,7 @@ const TYPE_ROWS: [Row; 3] = [ event_time: 1_620_000_000_000_000_000, // 2021-05-03 score: 3.5, tags: &[3], + vec: [50, 60], }, ]; @@ -209,6 +216,7 @@ impl TypesTableProvider { Field::new("event_time", proto[6].data_type().clone(), true), Field::new("score", proto[7].data_type().clone(), true), Field::new("tags", proto[8].data_type().clone(), true), + Field::new("vec", proto[9].data_type().clone(), true), ])); // One self-contained single-row batch per partition (see the module docs on why we do @@ -257,7 +265,7 @@ impl TableProvider for TypesTableProvider { } } -/// Build the nine single-row column arrays for one [`Row`], in schema order. +/// Build the ten single-row column arrays for one [`Row`], in schema order. fn row_columns(row: &Row) -> Vec { let mut tags = ListBuilder::new(UInt16Builder::new()); for t in row.tags { @@ -265,6 +273,12 @@ fn row_columns(row: &Row) -> Vec { } tags.append(true); + let mut vec = FixedSizeListBuilder::new(UInt16Builder::new(), row.vec.len() as i32); + for v in row.vec { + vec.values().append_value(v); + } + vec.append(true); + let attr_fields = vec![ Field::new("key", DataType::Utf8, true), Field::new("val", DataType::Utf8, true), @@ -292,5 +306,6 @@ fn row_columns(row: &Row) -> Vec { Arc::new(TimestampNanosecondArray::from(vec![row.event_time])), Arc::new(Float16Array::from(vec![f16::from_f32(row.score)])), Arc::new(tags.finish()) as ArrayRef, + Arc::new(vec.finish()) as ArrayRef, ] } diff --git a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java index b8c7357..30cf072 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java +++ b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java @@ -242,6 +242,11 @@ static boolean needsCast(Field field) { if (type instanceof ArrowType.Time) { return true; } + // Spark's ArrowColumnVector backs ArrayType only from a variable ListVector, never a + // FixedSizeListVector, so a fixed-size list must always be cast to a variable list. + if (type instanceof ArrowType.FixedSizeList) { + return true; + } for (Field child : field.getChildren()) { if (needsCast(child)) { return true; @@ -325,12 +330,10 @@ private static String renderArrowType(Field field) { if (type instanceof ArrowType.LargeList) { return "LargeList(" + listChild(field.getChildren().get(0)) + ")"; } - if (type instanceof ArrowType.FixedSizeList fsl) { - return "FixedSizeList(" - + fsl.getListSize() - + " x " - + listChild(field.getChildren().get(0)) - + ")"; + if (type instanceof ArrowType.FixedSizeList) { + // Cast to a variable list: Spark can only read ArrayType from a ListVector. The element is + // rendered cast-aware, so e.g. FixedSizeList becomes List(Float32). + return "List(" + listChild(field.getChildren().get(0)) + ")"; } if (type instanceof ArrowType.Struct) { StringBuilder sb = new StringBuilder("Struct("); diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java index c055739..f686e64 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java @@ -120,6 +120,9 @@ void typesSchemaMapsAndFlagsCasts() { assertEquals(DataTypes.FloatType, schema.apply("score").dataType()); assertEquals( DataTypes.createArrayType(DataTypes.IntegerType, true), schema.apply("tags").dataType()); + // FixedSizeList -> variable Array (fixed layout not readable by Spark). + assertEquals( + DataTypes.createArrayType(DataTypes.IntegerType, true), schema.apply("vec").dataType()); // Cast columns are flagged (so filter pushdown stays off them); pass-through columns are not. assertTrue(schema.apply("channel").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); @@ -127,6 +130,7 @@ void typesSchemaMapsAndFlagsCasts() { assertTrue(schema.apply("event_time").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertTrue(schema.apply("score").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertTrue(schema.apply("tags").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("vec").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertFalse(schema.apply("payload").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertFalse(schema.apply("attrs").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); } @@ -197,6 +201,11 @@ void typesValuesRoundTripThroughCasts() { assertEquals(List.of(), r2.getList(r2.fieldIndex("tags"))); assertEquals(List.of(3), r3.getList(r3.fieldIndex("tags"))); + // FixedSizeList -> Array (fixed->variable + element widening). + assertEquals(List.of(10, 20), r1.getList(r1.fieldIndex("vec"))); + assertEquals(List.of(30, 40), r2.getList(r2.fieldIndex("vec"))); + assertEquals(List.of(50, 60), r3.getList(r3.fieldIndex("vec"))); + // nested List> passes through. List attrs3 = r3.getList(r3.fieldIndex("attrs")); assertEquals(2, attrs3.size()); diff --git a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java index c6f23f9..bc6006e 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java @@ -166,6 +166,28 @@ void nestedListOfUnsignedCastsRecursively() { assertEquals("List(Int32)", SchemaConverter.castTargetString(list)); } + @Test + void fixedSizeListCastsToVariableList() { + // Spark reads ArrayType only from a variable ListVector, so a fixed-size list is always cast + // to a variable list -- even when its element needs no cast. + Field intFixed = + new Field( + "v", + FieldType.nullable(new ArrowType.FixedSizeList(2)), + List.of(nullable("item", new ArrowType.Int(32, true)))); + assertEquals(DataTypes.createArrayType(DataTypes.IntegerType, true), spark(intFixed)); + assertEquals("List(Int32)", SchemaConverter.castTargetString(intFixed)); + + // The element is rendered cast-aware: FixedSizeList -> List(Float32). + Field halfFixed = + new Field( + "v", + FieldType.nullable(new ArrowType.FixedSizeList(3)), + List.of(nullable("item", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)))); + assertEquals(DataTypes.createArrayType(DataTypes.FloatType, true), spark(halfFixed)); + assertEquals("List(Float32)", SchemaConverter.castTargetString(halfFixed)); + } + // --- schema-level: metadata flag + projection planning ------------------- @Test From 9f6e89d24c595f93a70032ca62865719080fa985 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 1 Jul 2026 16:34:50 -0400 Subject: [PATCH 5/5] refactor(spark): derive schema, cast decision, and cast target from one authority Enumerating "maps to a Spark type" and "readable by ArrowColumnVector" by hand let them drift: a type could map to a Spark DataType yet have no vectorized accessor, crashing at task time (FixedSizeList, and latent for LargeList / FixedSizeBinary / Date64). Replace the three hand-kept switches with a single authority. - sparkConsumable(ArrowType): the read gate, mirroring Spark 4.0's ArrowColumnVector.initAccessor accessor set (verified against the class), deliberately the narrower gate. - sparkTarget(Field): the nearest consumable field, recursing into children (identity when already consumable). One place normalizes unsigned->signed, Float16->Float32, non-us timestamp->us, Date64->Date32, Time->int, FixedSizeBinary->Binary, FixedSizeList/LargeList->List. toSparkType, needsCast, and castTargetString now all derive from sparkTarget, so the gates cannot disagree and adding a type is one case. toSparkSchema asserts at plan time that every target is consumable, so an unsupported type (e.g. Interval) fails in planning naming the column, instead of an opaque executor UNSUPPORTED_ARROWTYPE. Closes three latent bugs alongside FixedSizeList: LargeList, FixedSizeBinary, and Date64 previously mapped to a Spark type without a cast. Tests: SchemaConverterTest gains cases for LargeList/FixedSizeBinary/Date64, a "sparkTarget is always consumable" invariant, and a fail-fast case; the example `types` table gains labels (LargeList), digest (FixedSizeBinary), and day (Date64) columns, asserted end-to-end in AdbcSourceTest and the pyspark harness. The E2E session enables the Java 8 date API so DateType decodes without needing sun.util.calendar opened. Co-Authored-By: Claude Opus 4.8 --- .../adbc-datafusion-driver/pyspark_e2e.py | 26 +- .../adbc-datafusion-driver/src/provider.rs | 38 +- .../datafusion/spark/SchemaConverter.java | 334 ++++++++++-------- .../datafusion/spark/AdbcSourceTest.java | 39 +- .../datafusion/spark/SchemaConverterTest.java | 79 +++++ 5 files changed, 366 insertions(+), 150 deletions(-) diff --git a/examples/adbc-datafusion-driver/pyspark_e2e.py b/examples/adbc-datafusion-driver/pyspark_e2e.py index c5cca04..279f909 100644 --- a/examples/adbc-datafusion-driver/pyspark_e2e.py +++ b/examples/adbc-datafusion-driver/pyspark_e2e.py @@ -118,7 +118,17 @@ def read(table): # cast columns are flagged (so filter pushdown stays off them); pass-through ones are not. cast_cols = {f.name for f in dft.schema.fields if f.metadata.get(CAST_META_KEY)} print("cast columns:", sorted(cast_cols)) - assert cast_cols == {"channel", "big", "event_time", "score", "tags", "vec"}, cast_cols + assert cast_cols == { + "channel", + "big", + "event_time", + "score", + "tags", + "vec", + "labels", + "digest", + "day", + }, cast_cols # count() prunes the projection to empty; the connector must not emit a bare SELECT * (which # would return the raw, uncast schema and fail the reader on the ns timestamp / unsigned ids). @@ -163,6 +173,20 @@ def read(table): assert r2["vec"] == [30, 40] assert r3["vec"] == [50, 60] + # LargeList -> Array + assert r1["labels"] == ["a", "b"] + assert r2["labels"] == [] + assert r3["labels"] == ["c"] + + # FixedSizeBinary -> Binary + assert bytes(r1["digest"]) == b"\x01\x02\x03\x04" + assert bytes(r3["digest"]) == b"\xff\xfe\xfd\xfc" + + # Date64 -> Date32 (day-aligned) + assert r1["day"] == datetime.date(2020, 9, 13) + assert r2["day"] == datetime.date(2021, 1, 7) + assert r3["day"] == datetime.date(2021, 5, 3) + # nested List> passes through assert [(x["key"], x["val"]) for x in r1["attrs"]] == [("a", "1")] assert r2["attrs"] == [] diff --git a/examples/adbc-datafusion-driver/src/provider.rs b/examples/adbc-datafusion-driver/src/provider.rs index 854a694..2c18fe7 100644 --- a/examples/adbc-datafusion-driver/src/provider.rs +++ b/examples/adbc-datafusion-driver/src/provider.rs @@ -48,9 +48,9 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::array::{ - ArrayRef, BinaryArray, FixedSizeListBuilder, Float16Array, Int64Array, ListBuilder, - StringArray, StringBuilder, StructBuilder, TimestampNanosecondArray, UInt16Array, - UInt16Builder, UInt64Array, + ArrayRef, BinaryArray, Date64Array, FixedSizeBinaryBuilder, FixedSizeListBuilder, Float16Array, + Int64Array, LargeListBuilder, ListBuilder, StringArray, StringBuilder, StructBuilder, + TimestampNanosecondArray, UInt16Array, UInt16Builder, UInt64Array, }; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; @@ -149,6 +149,12 @@ struct Row { /// A fixed-size list of two unsigned ints -- Spark cannot read FixedSizeList and must have it /// cast to a variable List (with the element widened). Always length 2. vec: [u16; 2], + /// A LargeList of strings -- maps to ArrayType but has no accessor, so it casts to List. + labels: &'static [&'static str], + /// FixedSizeBinary(4) -- maps to BinaryType but has no accessor, so it casts to Binary. + digest: [u8; 4], + /// Date64 (ms since epoch, day-aligned) -- no accessor, casts to Date32. + day_ms: i64, } // Values are chosen to exercise the widening edges: channel spans past i16::MAX, big includes @@ -166,6 +172,9 @@ const TYPE_ROWS: [Row; 3] = [ score: 1.5, tags: &[1, 2], vec: [10, 20], + labels: &["a", "b"], + digest: [0x01, 0x02, 0x03, 0x04], + day_ms: 1_599_955_200_000, // 2020-09-13 (18518 days) }, Row { id: 2, @@ -178,6 +187,9 @@ const TYPE_ROWS: [Row; 3] = [ score: 2.5, tags: &[], vec: [30, 40], + labels: &[], + digest: [0x00, 0x00, 0x00, 0x00], + day_ms: 1_609_977_600_000, // 2021-01-07 (18634 days) }, Row { id: 3, @@ -190,6 +202,9 @@ const TYPE_ROWS: [Row; 3] = [ score: 3.5, tags: &[3], vec: [50, 60], + labels: &["c"], + digest: [0xff, 0xfe, 0xfd, 0xfc], + day_ms: 1_620_000_000_000, // 2021-05-03 (18750 days) }, ]; @@ -217,6 +232,9 @@ impl TypesTableProvider { Field::new("score", proto[7].data_type().clone(), true), Field::new("tags", proto[8].data_type().clone(), true), Field::new("vec", proto[9].data_type().clone(), true), + Field::new("labels", proto[10].data_type().clone(), true), + Field::new("digest", proto[11].data_type().clone(), true), + Field::new("day", proto[12].data_type().clone(), true), ])); // One self-contained single-row batch per partition (see the module docs on why we do @@ -265,7 +283,7 @@ impl TableProvider for TypesTableProvider { } } -/// Build the ten single-row column arrays for one [`Row`], in schema order. +/// Build the single-row column arrays for one [`Row`], in schema order. fn row_columns(row: &Row) -> Vec { let mut tags = ListBuilder::new(UInt16Builder::new()); for t in row.tags { @@ -279,6 +297,15 @@ fn row_columns(row: &Row) -> Vec { } vec.append(true); + let mut labels = LargeListBuilder::new(StringBuilder::new()); + for l in row.labels { + labels.values().append_value(l); + } + labels.append(true); + + let mut digest = FixedSizeBinaryBuilder::new(row.digest.len() as i32); + digest.append_value(row.digest).expect("digest bytes"); + let attr_fields = vec![ Field::new("key", DataType::Utf8, true), Field::new("val", DataType::Utf8, true), @@ -307,5 +334,8 @@ fn row_columns(row: &Row) -> Vec { Arc::new(Float16Array::from(vec![f16::from_f32(row.score)])), Arc::new(tags.finish()) as ArrayRef, Arc::new(vec.finish()) as ArrayRef, + Arc::new(labels.finish()) as ArrayRef, + Arc::new(digest.finish()) as ArrayRef, + Arc::new(Date64Array::from(vec![row.day_ms])), ] } diff --git a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java index 30cf072..4ef537c 100644 --- a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java +++ b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java @@ -22,11 +22,12 @@ import java.util.ArrayList; import java.util.List; +import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; -import org.apache.arrow.vector.types.IntervalUnit; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; @@ -37,28 +38,38 @@ /** * Converts an Arrow schema (produced by the ADBC scan) into a Spark {@link StructType}, and plans - * the source-side casts that make the scan emit Spark-native Arrow. + * the source-side casts that make the scan emit Arrow that Spark's vectorized reader can consume. * - *

Done directly rather than through Spark's {@code ArrowUtils} so the connector depends only on - * our Arrow version, never Spark's bundled one. + *

An Arrow column must clear two independent gates: * - *

Spark's vectorized {@code ArrowColumnVector} reads a fixed set of Arrow layouts: signed ints, - * 32/64-bit floats, microsecond timestamps, string/binary, decimal, date, and nested - * list/struct/map of those. Two categories of source type need handling: + *

    + *
  1. Schema mapping -- an Arrow type maps to some Spark {@code DataType} (used at {@code + * inferSchema}). + *
  2. Vectorized read -- Spark's {@code ArrowColumnVector} has an accessor for the vector. + * This gate is narrower than the first and is not a public API. + *
+ * + *

Enumerating the two gates by hand lets them drift: a type can map to a Spark type yet have no + * accessor (e.g. {@code FixedSizeList}/{@code LargeList} both map to {@code ArrayType} but only a + * variable {@code ListVector} is readable), producing an {@code UNSUPPORTED_ARROWTYPE} at task time + * rather than plan time. To avoid that, everything derives from one authority: * *

    - *
  • Directly representable -- the layout already matches; only the type mapping was - * missing (binary, nested list/struct/map, date, decimal, µs timestamp, null). These pass - * through untouched. - *
  • Cast required -- the layout differs from what {@code ArrowColumnVector} expects, so - * the scan must cast at the source (unsigned ints, Float16, non-µs timestamps, time). We map - * these to the Spark type they will be cast to, and {@link #castTargetString} names - * the Arrow target so {@link SqlQuery} can wrap the column in {@code arrow_cast}. + *
  • {@link #sparkConsumable} -- the read gate: {@code true} iff {@code ArrowColumnVector} has + * an accessor for the type. Mirrors the accessors in Spark 4.0's {@code + * ArrowColumnVector.initAccessor} (deliberately the narrower gate). + *
  • {@link #sparkTarget} -- the nearest consumable Arrow type for a field (identity if already + * consumable), recursing into children. Non-consumable leaves widen (unsigned -> signed, + * Float16 -> Float32, non-µs timestamp -> µs, Date64 -> Date32, Time -> int, + * FixedSizeBinary -> Binary) and non-consumable containers convert ({@code + * FixedSizeList}/{@code LargeList} -> {@code List}). *
* - *

The cast is pushed into the scan (see {@link SqlQuery}), so this converter and the reader only - * ever agree on Spark-native types: the reported Spark type of a column always equals what {@code - * ArrowColumnVector} produces from the (possibly cast) Arrow output. + *

Then the reported Spark type, the cast decision, and the cast target all come from {@code + * sparkTarget}, so the two gates cannot drift and adding a type is one case, not several. The cast + * is pushed into the scan (see {@link SqlQuery}); {@link #toSparkSchema} asserts at plan time that + * every target is consumable, so an unsupported type fails in planning with a clear message rather + * than as an opaque executor crash. */ final class SchemaConverter { @@ -77,11 +88,15 @@ record ProjectionColumn(String name, String castType) {} static StructType toSparkSchema(Schema arrowSchema) { StructType struct = new StructType(); for (Field field : arrowSchema.getFields()) { + Field target = sparkTarget(field); + // Plan-time guard: if a type cannot be normalized to something Spark can read, fail here + // (naming the column) instead of crashing on an executor with UNSUPPORTED_ARROWTYPE. + ensureConsumable(field.getName(), target); Metadata metadata = - needsCast(field) - ? new MetadataBuilder().putBoolean(CAST_METADATA_KEY, true).build() - : Metadata.empty(); - struct = struct.add(field.getName(), toSparkType(field), field.isNullable(), metadata); + target == field + ? Metadata.empty() + : new MetadataBuilder().putBoolean(CAST_METADATA_KEY, true).build(); + struct = struct.add(field.getName(), mapConsumable(target), field.isNullable(), metadata); } return struct; } @@ -109,9 +124,15 @@ static List projectionColumns(Schema schema, List proj return columns; } + /** Whether the column (recursively) is not something Spark's reader can consume as-is. */ + static boolean needsCast(Field field) { + return sparkTarget(field) != field; + } + /** The Arrow type string for {@code arrow_cast}, or {@code null} if the column needs no cast. */ static String castTargetString(Field field) { - return needsCast(field) ? renderArrowType(field) : null; + Field target = sparkTarget(field); + return target == field ? null : render(target); } /** @@ -131,35 +152,148 @@ static ProjectionColumn probeColumn(Schema schema) { return new ProjectionColumn(first.getName(), castTargetString(first)); } - // --- Arrow type -> Spark type --------------------------------------------- + // --- The two-gate authority ----------------------------------------------- + + /** + * Whether Spark's vectorized {@code ArrowColumnVector} has an accessor for this Arrow type. + * Mirrors {@code ArrowColumnVector.initAccessor} in Spark 4.0 -- deliberately the narrower gate, + * so a type that maps to a Spark {@code DataType} but has no accessor (unsigned, Float16, non-µs + * timestamp, Date64, {@code Time*}, {@code FixedSizeList}/{@code LargeList}, {@code + * FixedSizeBinary}, {@code Interval}, {@code Dictionary}, ...) is reported as non-consumable. + */ + static boolean sparkConsumable(ArrowType type) { + if (type instanceof ArrowType.Bool) { + return true; + } + if (type instanceof ArrowType.Int i) { + return i.getIsSigned() + && (i.getBitWidth() == 8 + || i.getBitWidth() == 16 + || i.getBitWidth() == 32 + || i.getBitWidth() == 64); + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.SINGLE + || fp.getPrecision() == FloatingPointPrecision.DOUBLE; + } + if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { + return true; + } + if (type instanceof ArrowType.Binary || type instanceof ArrowType.LargeBinary) { + return true; // FixedSizeBinary is NOT readable. + } + if (type instanceof ArrowType.Decimal d) { + return d.getBitWidth() == 128; // Spark's DecimalVector is 128-bit; Decimal256 is not read. + } + if (type instanceof ArrowType.Date d) { + return d.getUnit() == DateUnit.DAY; // Date64 (MILLISECOND) is not readable. + } + if (type instanceof ArrowType.Timestamp ts) { + return ts.getUnit() == TimeUnit.MICROSECOND; // only µs, any timezone. + } + if (type instanceof ArrowType.Duration) { + return true; + } + if (type instanceof ArrowType.Null) { + return true; + } + // Only the variable-offset containers are readable (not FixedSizeList / LargeList). + return type instanceof ArrowType.List + || type instanceof ArrowType.Struct + || type instanceof ArrowType.Map; + } + + /** + * The nearest Spark-consumable field for {@code field}, recursing into children. Returns the same + * instance when nothing changes (so {@code sparkTarget(f) == f} means "no cast needed"). + */ + static Field sparkTarget(Field field) { + ArrowType type = field.getType(); + ArrowType targetType = type; + + if (type instanceof ArrowType.Int i && !i.getIsSigned()) { + targetType = + switch (i.getBitWidth()) { + case 8 -> new ArrowType.Int(16, true); + case 16 -> new ArrowType.Int(32, true); + case 32 -> new ArrowType.Int(64, true); + case 64 -> new ArrowType.Decimal(20, 0, 128); // no lossless signed 64-bit target + default -> type; + }; + } else if (type instanceof ArrowType.FloatingPoint fp + && fp.getPrecision() == FloatingPointPrecision.HALF) { + targetType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + } else if (type instanceof ArrowType.Timestamp ts && ts.getUnit() != TimeUnit.MICROSECOND) { + targetType = new ArrowType.Timestamp(TimeUnit.MICROSECOND, ts.getTimezone()); + } else if (type instanceof ArrowType.Date d && d.getUnit() != DateUnit.DAY) { + targetType = new ArrowType.Date(DateUnit.DAY); + } else if (type instanceof ArrowType.Time t) { + // Spark has no time-of-day accessor; carry the raw ticks as the matching-width signed int. + targetType = new ArrowType.Int(t.getBitWidth() == 32 ? 32 : 64, true); + } else if (type instanceof ArrowType.FixedSizeBinary) { + targetType = new ArrowType.Binary(); + } else if (type instanceof ArrowType.FixedSizeList || type instanceof ArrowType.LargeList) { + // Spark reads ArrayType only from a variable ListVector. + targetType = new ArrowType.List(); + } + + // Recurse into children (list element, struct fields, map entries), widening each. + List children = field.getChildren(); + List targetChildren = new ArrayList<>(children.size()); + boolean childChanged = false; + for (Field child : children) { + Field targetChild = sparkTarget(child); + childChanged |= targetChild != child; + targetChildren.add(targetChild); + } + + if (targetType == type && !childChanged) { + return field; + } + FieldType ft = new FieldType(field.isNullable(), targetType, field.getDictionary()); + return new Field(field.getName(), ft, targetChildren); + } + + /** Assert every node of a {@code sparkTarget} result is consumable, else fail with the column. */ + private static void ensureConsumable(String column, Field target) { + if (!sparkConsumable(target.getType())) { + throw new IllegalArgumentException( + "column '" + + column + + "': Arrow type " + + target.getType() + + " has no Spark reader and " + + "no supported cast; the connector cannot expose it to Spark"); + } + for (Field child : target.getChildren()) { + ensureConsumable(column, child); + } + } + // --- Arrow (consumable) type -> Spark type --------------------------------- + + /** Map an already-{@link #sparkConsumable} field to its Spark {@link DataType}. */ static DataType toSparkType(Field field) { + Field target = sparkTarget(field); + ensureConsumable(field.getName(), target); + return mapConsumable(target); + } + + private static DataType mapConsumable(Field field) { ArrowType type = field.getType(); if (type instanceof ArrowType.Bool) { return DataTypes.BooleanType; } if (type instanceof ArrowType.Int i) { - if (i.getIsSigned()) { - return switch (i.getBitWidth()) { - case 8 -> DataTypes.ByteType; - case 16 -> DataTypes.ShortType; - case 32 -> DataTypes.IntegerType; - case 64 -> DataTypes.LongType; - default -> throw unsupported(field); - }; - } - // Unsigned: widened to the next signed width it will be cast to (u64 has no lossless - // signed 64-bit target, so it becomes Decimal(20,0)). return switch (i.getBitWidth()) { - case 8 -> DataTypes.ShortType; - case 16 -> DataTypes.IntegerType; - case 32 -> DataTypes.LongType; - case 64 -> DataTypes.createDecimalType(20, 0); + case 8 -> DataTypes.ByteType; + case 16 -> DataTypes.ShortType; + case 32 -> DataTypes.IntegerType; + case 64 -> DataTypes.LongType; default -> throw unsupported(field); }; } if (type instanceof ArrowType.FloatingPoint fp) { - // Float16 has no Spark type; it is widened to Float. return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? DataTypes.DoubleType : DataTypes.FloatType; @@ -167,22 +301,15 @@ static DataType toSparkType(Field field) { if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { return DataTypes.StringType; } - if (type instanceof ArrowType.Binary - || type instanceof ArrowType.LargeBinary - || type instanceof ArrowType.FixedSizeBinary) { + if (type instanceof ArrowType.Binary || type instanceof ArrowType.LargeBinary) { return DataTypes.BinaryType; } if (type instanceof ArrowType.Date) { return DataTypes.DateType; } if (type instanceof ArrowType.Timestamp ts) { - // Unit is normalized to microseconds by the cast; the timezone decides NTZ vs zoned. return ts.getTimezone() == null ? DataTypes.TimestampNTZType : DataTypes.TimestampType; } - if (type instanceof ArrowType.Time t) { - // Spark has no time-of-day accessor; the value is cast to its raw integer of ticks. - return t.getBitWidth() == 32 ? DataTypes.IntegerType : DataTypes.LongType; - } if (type instanceof ArrowType.Decimal d) { return DataTypes.createDecimalType(d.getPrecision(), d.getScale()); } @@ -192,24 +319,15 @@ static DataType toSparkType(Field field) { if (type instanceof ArrowType.Duration) { return DataTypes.createDayTimeIntervalType(); } - if (type instanceof ArrowType.Interval iv) { - return switch (iv.getUnit()) { - case YEAR_MONTH -> DataTypes.createYearMonthIntervalType(); - case DAY_TIME -> DataTypes.createDayTimeIntervalType(); - default -> throw unsupported(field); - }; - } - if (type instanceof ArrowType.List - || type instanceof ArrowType.LargeList - || type instanceof ArrowType.FixedSizeList) { + if (type instanceof ArrowType.List) { Field element = field.getChildren().get(0); - return DataTypes.createArrayType(toSparkType(element), element.isNullable()); + return DataTypes.createArrayType(mapConsumable(element), element.isNullable()); } if (type instanceof ArrowType.Struct) { List children = new ArrayList<>(); for (Field child : field.getChildren()) { children.add( - DataTypes.createStructField(child.getName(), toSparkType(child), child.isNullable())); + DataTypes.createStructField(child.getName(), mapConsumable(child), child.isNullable())); } return DataTypes.createStructType(children); } @@ -217,66 +335,25 @@ static DataType toSparkType(Field field) { Field entries = field.getChildren().get(0); Field key = entries.getChildren().get(0); Field value = entries.getChildren().get(1); - return DataTypes.createMapType(toSparkType(key), toSparkType(value), value.isNullable()); + return DataTypes.createMapType(mapConsumable(key), mapConsumable(value), value.isNullable()); } throw unsupported(field); } - // --- Cast planning -------------------------------------------------------- - - /** - * Whether the column (recursively) has any layout that Spark's reader cannot consume directly. - */ - static boolean needsCast(Field field) { - ArrowType type = field.getType(); - if (type instanceof ArrowType.Int i && !i.getIsSigned()) { - return true; - } - if (type instanceof ArrowType.FloatingPoint fp - && fp.getPrecision() == FloatingPointPrecision.HALF) { - return true; - } - if (type instanceof ArrowType.Timestamp ts && ts.getUnit() != TimeUnit.MICROSECOND) { - return true; - } - if (type instanceof ArrowType.Time) { - return true; - } - // Spark's ArrowColumnVector backs ArrayType only from a variable ListVector, never a - // FixedSizeListVector, so a fixed-size list must always be cast to a variable list. - if (type instanceof ArrowType.FixedSizeList) { - return true; - } - for (Field child : field.getChildren()) { - if (needsCast(child)) { - return true; - } - } - return false; - } + // --- Arrow (consumable) type -> arrow_cast type string --------------------- /** - * Render the widened Arrow type as an {@code arrow_cast} type string (the reversible {@code - * arrow::datatypes::DataType} display form that DataFusion's {@code arrow_cast} parses). Cast - * leaves become their widened target; everything else is rendered as-is so a nested cast carries - * its unchanged siblings along. + * Render an already-{@link #sparkConsumable} field as an {@code arrow_cast} type string (the + * reversible {@code arrow::datatypes::DataType} display form DataFusion's {@code arrow_cast} + * parses). Called only on {@link #sparkTarget} output, which is consumable by construction. */ - private static String renderArrowType(Field field) { + private static String render(Field field) { ArrowType type = field.getType(); if (type instanceof ArrowType.Bool) { return "Boolean"; } if (type instanceof ArrowType.Int i) { - if (i.getIsSigned()) { - return "Int" + i.getBitWidth(); - } - return switch (i.getBitWidth()) { - case 8 -> "Int16"; - case 16 -> "Int32"; - case 32 -> "Int64"; - case 64 -> "Decimal128(20, 0)"; - default -> throw unsupported(field); - }; + return "Int" + i.getBitWidth(); } if (type instanceof ArrowType.FloatingPoint fp) { return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? "Float64" : "Float32"; @@ -293,48 +370,26 @@ private static String renderArrowType(Field field) { if (type instanceof ArrowType.LargeBinary) { return "LargeBinary"; } - if (type instanceof ArrowType.FixedSizeBinary fb) { - return "FixedSizeBinary(" + fb.getByteWidth() + ")"; - } - if (type instanceof ArrowType.Date d) { - return switch (d.getUnit()) { - case DAY -> "Date32"; - case MILLISECOND -> "Date64"; - }; + if (type instanceof ArrowType.Date) { + return "Date32"; } if (type instanceof ArrowType.Timestamp ts) { - // Normalize to microseconds, preserving the timezone. return ts.getTimezone() == null ? "Timestamp(Microsecond)" : "Timestamp(Microsecond, \"" + ts.getTimezone() + "\")"; } - if (type instanceof ArrowType.Time t) { - return t.getBitWidth() == 32 ? "Int32" : "Int64"; - } if (type instanceof ArrowType.Decimal d) { - String kind = d.getBitWidth() == 256 ? "Decimal256" : "Decimal128"; - return kind + "(" + d.getPrecision() + ", " + d.getScale() + ")"; + return "Decimal128(" + d.getPrecision() + ", " + d.getScale() + ")"; } if (type instanceof ArrowType.Duration dur) { return "Duration(" + timeUnitName(dur.getUnit()) + ")"; } - if (type instanceof ArrowType.Interval iv) { - return "Interval(" + intervalUnitName(iv.getUnit()) + ")"; - } if (type instanceof ArrowType.Null) { return "Null"; } if (type instanceof ArrowType.List) { return "List(" + listChild(field.getChildren().get(0)) + ")"; } - if (type instanceof ArrowType.LargeList) { - return "LargeList(" + listChild(field.getChildren().get(0)) + ")"; - } - if (type instanceof ArrowType.FixedSizeList) { - // Cast to a variable list: Spark can only read ArrayType from a ListVector. The element is - // rendered cast-aware, so e.g. FixedSizeList becomes List(Float32). - return "List(" + listChild(field.getChildren().get(0)) + ")"; - } if (type instanceof ArrowType.Struct) { StringBuilder sb = new StringBuilder("Struct("); List children = field.getChildren(); @@ -357,10 +412,9 @@ private static String renderArrowType(Field field) { throw unsupported(field); } - /** {@code [, field: 'name']} -- the list/fixed-size-list child form. */ + /** {@code [, field: 'name']} -- the list child form. */ private static String listChild(Field field) { - String rendered = nullability(field) + renderArrowType(field); - // The default list-field name ("item") is elided by the display form. + String rendered = nullability(field) + render(field); return "item".equals(field.getName()) ? rendered : rendered + ", field: '" + field.getName() + "'"; @@ -368,7 +422,7 @@ private static String listChild(Field field) { /** {@code "name": } -- the struct/map field form. */ private static String structField(Field field) { - return debugQuote(field.getName()) + ": " + nullability(field) + renderArrowType(field); + return debugQuote(field.getName()) + ": " + nullability(field) + render(field); } private static String nullability(Field field) { @@ -384,14 +438,6 @@ private static String timeUnitName(TimeUnit unit) { }; } - private static String intervalUnitName(IntervalUnit unit) { - return switch (unit) { - case YEAR_MONTH -> "YearMonth"; - case DAY_TIME -> "DayTime"; - case MONTH_DAY_NANO -> "MonthDayNano"; - }; - } - /** Reproduce Rust's {@code {:?}} string quoting used by the Arrow display form. */ private static String debugQuote(String s) { StringBuilder sb = new StringBuilder("\""); diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java index f686e64..fc9105e 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java @@ -28,6 +28,7 @@ import java.math.BigDecimal; import java.nio.file.Files; import java.nio.file.Path; +import java.time.LocalDate; import java.time.LocalDateTime; import java.util.LinkedHashMap; import java.util.List; @@ -79,7 +80,14 @@ static void setUp() { "set -Dadbc.example.driver.path to the example driver cdylib to run this test"); // local[8]: several task slots in one executor JVM, so the per-executor connection cache // (AdbcConnectionPool) is exercised across concurrent tasks. - spark = SparkSession.builder().appName("adbc-source-test").master("local[8]").getOrCreate(); + // Java 8 date/time API: DateType -> java.time.LocalDate, avoiding Spark's legacy + // java.sql.Date conversion which needs sun.util.calendar opened on JDK 17. + spark = + SparkSession.builder() + .appName("adbc-source-test") + .master("local[8]") + .config("spark.sql.datetime.java8API.enabled", "true") + .getOrCreate(); } @AfterAll @@ -123,6 +131,11 @@ void typesSchemaMapsAndFlagsCasts() { // FixedSizeList -> variable Array (fixed layout not readable by Spark). assertEquals( DataTypes.createArrayType(DataTypes.IntegerType, true), schema.apply("vec").dataType()); + // LargeList -> Array; FixedSizeBinary -> Binary; Date64 -> Date. + assertEquals( + DataTypes.createArrayType(DataTypes.StringType, true), schema.apply("labels").dataType()); + assertEquals(DataTypes.BinaryType, schema.apply("digest").dataType()); + assertEquals(DataTypes.DateType, schema.apply("day").dataType()); // Cast columns are flagged (so filter pushdown stays off them); pass-through columns are not. assertTrue(schema.apply("channel").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); @@ -131,6 +144,9 @@ void typesSchemaMapsAndFlagsCasts() { assertTrue(schema.apply("score").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertTrue(schema.apply("tags").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertTrue(schema.apply("vec").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("labels").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("digest").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("day").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertFalse(schema.apply("payload").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); assertFalse(schema.apply("attrs").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); } @@ -206,6 +222,22 @@ void typesValuesRoundTripThroughCasts() { assertEquals(List.of(30, 40), r2.getList(r2.fieldIndex("vec"))); assertEquals(List.of(50, 60), r3.getList(r3.fieldIndex("vec"))); + // LargeList -> Array (large list not readable -> cast to variable list). + assertEquals(List.of("a", "b"), r1.getList(r1.fieldIndex("labels"))); + assertEquals(List.of(), r2.getList(r2.fieldIndex("labels"))); + assertEquals(List.of("c"), r3.getList(r3.fieldIndex("labels"))); + + // FixedSizeBinary -> Binary. + assertArrayEquals(new byte[] {0x01, 0x02, 0x03, 0x04}, (byte[]) r1.getAs("digest")); + assertArrayEquals( + new byte[] {(byte) 0xff, (byte) 0xfe, (byte) 0xfd, (byte) 0xfc}, + (byte[]) r3.getAs("digest")); + + // Date64 -> Date32 (day-aligned). + assertEquals(LocalDate.ofEpochDay(18518), toLocalDate(r1.getAs("day"))); + assertEquals(LocalDate.ofEpochDay(18634), toLocalDate(r2.getAs("day"))); + assertEquals(LocalDate.ofEpochDay(18750), toLocalDate(r3.getAs("day"))); + // nested List> passes through. List attrs3 = r3.getList(r3.fieldIndex("attrs")); assertEquals(2, attrs3.size()); @@ -219,6 +251,11 @@ private static BigDecimal bigValue(Row row) { return row.getDecimal(row.fieldIndex("big")); } + /** Spark returns DateType as java.sql.Date, or java.time.LocalDate under the Java 8 date API. */ + private static LocalDate toLocalDate(Object value) { + return value instanceof LocalDate d ? d : ((java.sql.Date) value).toLocalDate(); + } + @Test void fullScanAcrossPartitions() { Dataset df = load(); diff --git a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java index bc6006e..fe535f0 100644 --- a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java +++ b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java @@ -21,12 +21,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.List; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -188,6 +190,83 @@ void fixedSizeListCastsToVariableList() { assertEquals("List(Float32)", SchemaConverter.castTargetString(halfFixed)); } + @Test + void largeListCastsToVariableList() { + // LargeList maps to ArrayType but has no ArrowColumnVector accessor -> must cast to List. + Field large = + new Field( + "v", + FieldType.nullable(new ArrowType.LargeList()), + List.of(nullable("item", new ArrowType.Int(32, true)))); + assertEquals(DataTypes.createArrayType(DataTypes.IntegerType, true), spark(large)); + assertEquals("List(Int32)", SchemaConverter.castTargetString(large)); + } + + @Test + void fixedSizeBinaryCastsToBinary() { + // FixedSizeBinary maps to BinaryType but has no accessor -> must cast to variable Binary. + Field fsb = nullable("c", new ArrowType.FixedSizeBinary(16)); + assertEquals(DataTypes.BinaryType, spark(fsb)); + assertEquals("Binary", SchemaConverter.castTargetString(fsb)); + } + + @Test + void date64CastsToDate32() { + Field d64 = nullable("c", new ArrowType.Date(DateUnit.MILLISECOND)); + assertEquals(DataTypes.DateType, spark(d64)); + assertEquals("Date32", SchemaConverter.castTargetString(d64)); + + // Date32 is consumable -> no cast. + assertNull(SchemaConverter.castTargetString(nullable("c", new ArrowType.Date(DateUnit.DAY)))); + } + + @Test + void sparkTargetIsAlwaysConsumable() { + // The core invariant behind the plan-time guard: every widened target is Spark-consumable. + List samples = + List.of( + nullable("a", new ArrowType.Int(8, false)), + nullable("b", new ArrowType.Int(64, false)), + nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)), + nullable("d", new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")), + nullable("e", new ArrowType.Date(DateUnit.MILLISECOND)), + nullable("f", new ArrowType.Time(TimeUnit.MICROSECOND, 64)), + nullable("g", new ArrowType.FixedSizeBinary(4)), + new Field( + "h", + FieldType.nullable(new ArrowType.FixedSizeList(2)), + List.of( + nullable("item", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)))), + new Field( + "i", + FieldType.nullable(new ArrowType.LargeList()), + List.of(nullable("item", new ArrowType.Int(16, false)))), + new Field( + "j", + FieldType.nullable(ArrowType.Struct.INSTANCE), + List.of(nullable("x", new ArrowType.Int(32, false))))); + for (Field f : samples) { + assertConsumable(SchemaConverter.sparkTarget(f)); + } + } + + private static void assertConsumable(Field field) { + assertTrue( + SchemaConverter.sparkConsumable(field.getType()), + "sparkTarget produced a non-consumable type: " + field.getType()); + field.getChildren().forEach(SchemaConverterTest::assertConsumable); + } + + @Test + void unsupportedTypeFailsAtSchemaBuild() { + // An Interval has no accessor and no supported cast -> fail fast at inferSchema, naming it. + Schema schema = + new Schema(List.of(Field.nullable("dur", new ArrowType.Interval(IntervalUnit.DAY_TIME)))); + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> SchemaConverter.toSparkSchema(schema)); + assertTrue(e.getMessage().contains("dur"), e.getMessage()); + } + // --- schema-level: metadata flag + projection planning ------------------- @Test