diff --git a/.gitignore b/.gitignore index 25c9216..c6a4806 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ rust-target/ *.iml .DS_Store tpch-data/ +# Lockfile for the example ADBC driver crate (a demo cdylib, not a published lib). +examples/adbc-datafusion-driver/Cargo.lock .claude docs/superpowers docs/build/ diff --git a/docs/source/user-guide/adbc-spark-connector.md b/docs/source/user-guide/adbc-spark-connector.md new file mode 100644 index 0000000..601fd69 --- /dev/null +++ b/docs/source/user-guide/adbc-spark-connector.md @@ -0,0 +1,240 @@ + + +# DataFusion as a Spark data source (ADBC) + +This repository ships a Spark connector that lets a Spark job read from a +DataFusion `TableProvider` as a native Spark `DataSourceV2`. Spark sees an +ordinary external table; the data is produced by DataFusion in native code and +handed back as Apache Arrow batches. + +## What this repository provides + +Two pieces work together: + +- **The `adbc-datafusion` Spark connector** (`spark/`) — a Spark + `DataSourceV2` registered under the format name `adbc-datafusion`. It builds + the scan, pushes projection / filters / limit down, maps DataFusion's scan + partitions onto Spark input partitions, and imports each partition's Arrow + data into Spark **zero-copy** as an `ArrowColumnVector`. +- **An example DataFusion ADBC driver** (`examples/adbc-datafusion-driver/`) — + a small native library (cdylib) that exposes a DataFusion `TableProvider` + through the [ADBC](https://arrow.apache.org/adbc/) C API. It is a worked + example you can copy: your own driver registers whatever provider you want and + is loaded the same way. + +The connector is built on the standard Apache Arrow ADBC Java driver manager +(`adbc-core` + `adbc-driver-jni`), so it talks to the native driver through a +stable, public interface. 
Because that boundary is plain ADBC, the same +connector can front **any** ADBC-speaking source; the DataFusion driver shown +here is just one example, and you can point it at a different ADBC driver +without changing the connector. + +```text +Spark DataSourceV2 (adbc-datafusion) + → arrow-adbc Java driver manager (adbc-core + adbc-driver-jni) + → native ADBC driver (the DataFusion cdylib) + → your DataFusion TableProvider +``` + +## Reading from Spark + +Point the `driver` option at the built driver library and the `table` option at +a table the provider exposes: + +```scala +val df = spark.read + .format("adbc-datafusion") + .option("driver", "/abs/path/to/libadbc_datafusion_example_driver.so") + .option("entrypoint", "AdbcDatafusionExampleInit") + .option("table", "example") + .load() + +df.filter("id > 1").select("name").show() +``` + +The connector probes the schema once (on the driver) when the `DataFrame` is +created, then plans the scan. `df.filter(...)` and `df.select(...)` are pushed +into the DataFusion scan where possible (see below); the rest Spark evaluates +itself. + +## Two partitioning models that must be reconciled + +The heart of the connector is mapping DataFusion's idea of a partitioned scan +onto Spark's idea of tasks. The two engines size scans on **different +principles**, and a provider author who ignores the difference will get a job +that runs but performs badly. + +### DataFusion: parallelism-bound + +DataFusion partitions a scan to **saturate the cores of one machine**. Its +built-in `ListingTable`, for example, packs input files into roughly +`target_partitions` groups — where `target_partitions` defaults to the local +core count — using a minimum-size floor and optional single-file splitting only +to *reach* that count. The partition count is therefore approximately + +```text +N ≈ target_partitions ≈ cores +``` + +independent of how much data there is. More data means **bigger** partitions, +not more of them. 
+ +### Spark: one task per partition, byte-bound + +Spark turns **each input partition into one task**, and a native Spark file +source sizes partitions by **bytes**: it targets roughly one partition per +`spark.sql.files.maxPartitionBytes` (default 128 MB), cutting files into splits +and bin-packing them to that size. (Spark shrinks the target below 128 MB when +the data is small, so that there are at least enough partitions to keep every +core busy.) + +So large data means **many** partitions — typically far more than the cluster +has cores — which run in waves. Each task processes a bounded slice (≈128 MB), +which is what keeps memory bounded and lets Spark reschedule stragglers. + +### How the connector joins them + +For a scan that only pushes projection / filter / limit (no DataFusion-side +join or aggregation), the physical plan's output partitioning **is** the +provider's `scan()` partitioning, and the mapping to Spark is **1:1**: + +```text +provider.scan() output partitions + → DataFusion physical plan partition count N + → N ADBC partition descriptors + → N Spark input partitions = N Spark tasks (scan stage) +``` + +Spark never splits or merges a partition afterward — `df.rdd().getNumPartitions() +== N`, and the cluster core count only bounds how many of the `N` run at once. + +**The consequence you must plan for:** suppose a provider keeps DataFusion's +default of `N = target_partitions ≈ cores` partitions. On a large dataset that +produces only a handful of partitions, each holding about `total_bytes / cores` +of data — so Spark runs a few very large tasks. Because a task is the unit of +work, those oversized tasks bring memory pressure, stragglers that hold up the +whole stage, and no way for Spark to rebalance. To feed Spark well, a provider +should size partitions by **bytes** — the way a Spark file source does — rather +than by the `target_partitions` count. 
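To put numbers on that consequence, here is a small Python sketch comparing the two sizing rules. The 1 TiB dataset, the 64-core cluster, and both function names are illustrative assumptions, not part of the connector; only the 128 MB target comes from the text above. It also ignores Spark's shrinking of the target for small inputs.

```python
import math

MIB = 1024 ** 2
GIB = 1024 ** 3


def parallelism_bound_partitions(cores: int) -> int:
    """DataFusion-style sizing: N tracks the core count, whatever the data size."""
    return cores


def byte_bound_partitions(total_bytes: int, target_bytes: int = 128 * MIB) -> int:
    """Spark-file-source-style sizing: about one partition per target_bytes."""
    return max(1, math.ceil(total_bytes / target_bytes))


# Hypothetical workload: 1 TiB of data on a cluster with 64 executor cores.
total_bytes = 1024 * GIB
cores = 64

n_cores = parallelism_bound_partitions(cores)
n_bytes = byte_bound_partitions(total_bytes)

# Few, very large tasks versus many bounded ones.
print(n_cores, "partitions of", total_bytes // n_cores // GIB, "GiB each")
print(n_bytes, "partitions of", total_bytes // n_bytes // MIB, "MiB each")
```

Under these assumptions the core-bound rule yields 64 tasks of 16 GiB each, while the byte-bound rule yields 8192 tasks of 128 MiB each that run in waves.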
+ +## Sizing the scan + +Aim for the Spark sweet spot: + +- **Floor:** `N` ≥ total executor cores, or cores sit idle. `~2–4× cores` gives + slack for skew and stragglers. +- **Per-partition size:** big enough to amortize per-task overhead — the + canonical target is **128–256 MB** of data per partition (matching Spark's + `spark.sql.files.maxPartitionBytes` default of 128 MB). +- **Ceiling — and ADBC lowers it.** Beyond Spark's generic per-task cost, this + path adds two per-partition costs: each of the `N` ADBC descriptors carries a + copy of the **whole serialized physical plan**, and each task **deserializes** + that plan when it reads its partition. (The driver caches a deserialized plan + per connection, but that does not help here: each Spark task uses its own + connection and reads a single partition, so the plan is deserialized once per + task and never reused.) Tens of thousands of tiny partitions is far costlier + here than for a native file source. Keep `N` in the hundreds; prefer bigger + partitions over a huge count. 
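The floor, per-partition size, and ceiling above can be combined into one small function. This is a sketch only: `choose_partition_count`, `descriptor_plan_bytes`, the 256 MiB target, the ceiling of 512, and the 200 KiB plan size are all hypothetical choices made for illustration, not values the connector defines.

```python
import math

MIB = 1024 ** 2
GIB = 1024 ** 3


def choose_partition_count(total_bytes: int, executor_cores: int,
                           target_bytes: int = 256 * MIB,
                           ceiling: int = 512) -> int:
    """Byte-driven N, floored at the core count and capped at `ceiling`."""
    by_bytes = math.ceil(total_bytes / target_bytes)
    return min(max(by_bytes, executor_cores), ceiling)


def descriptor_plan_bytes(n_partitions: int, plan_bytes: int) -> int:
    """Total plan bytes shipped: every descriptor carries its own plan copy."""
    return n_partitions * plan_bytes


# Small input: the floor wins, so every core still gets a partition.
print(choose_partition_count(1 * GIB, executor_cores=64))
# Medium input: bytes drive the count.
print(choose_partition_count(40 * GIB, executor_cores=64))
# Very large input: the ceiling wins, so partitions grow past target_bytes.
print(choose_partition_count(1024 * GIB, executor_cores=64))
# Plan-copy overhead for 512 descriptors of an assumed 200 KiB plan.
print(descriptor_plan_bytes(512, 200 * 1024) // MIB, "MiB of serialized plans")
```

The last line shows why the ceiling matters on this path: the plan-copy cost grows linearly with `N`, independent of how much data each partition holds.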
+ +Rule of thumb, balanced by **bytes** rather than split count: + +```text +N ≈ clamp(total_bytes / target_bytes, floor=cores, ceiling≈hundreds) +``` + +## Writing a provider that feeds Spark well + +Your provider's `scan()` receives the session, so it can read the connector's +parallelism hint and then decide based on what it knows about the data: + +```text +T = state.config().target_partitions() // connector hint ≈ desired parallelism +target_bytes = ~128–256 MB // bias larger for ADBC to keep N down +min_bytes = floor so partitions never get tiny // ~ tens of MB + +if total_bytes known AND splittable (files / row groups / key ranges): + N = clamp(ceil(total_bytes / target_bytes), 1, CEILING) // byte-bound + bin-pack splits into N groups BALANCED BY BYTES // not by count + do not split below min_bytes; merge small splits + +elif split_count known but not bytes (shards / external partitions): + N = min(split_count, CEILING) // natural partitioning + if split_count >> T: coalesce shards into ~T balanced groups + +else (size and splits unknown; opaque / streaming): + N = T // fall back to the hint + lean slightly high — stragglers hurt more than overhead +``` + +- **Balance by bytes, not by split count.** A task is atomic; Spark cannot + rebalance a fat partition. One 10 GB partition among 100 small ones stalls the + whole stage. Bin-pack so partition bytes are even. +- **`target_partitions` is a hint, cap, and fallback** — not the data-bound + count. Use it when bytes are unknown or as a parallelism ceiling; when you know + the data size, let **bytes** drive `N`. + +### Planning happens once, on the driver + +Your provider's `scan()` runs **once**, on the driver, while the multi-partition +descriptors are built. The driver plans the query, fixes the partitioning, and +**serializes the whole physical plan** into each descriptor. Each executor task +then deserializes that plan and executes its partition index. 
So partition +`index i` always means the same slice: there is no second planning pass that +could disagree with the first, and no need to make `N` stable across a +driver/executor boundary. + +The one requirement this places on a provider is that its plan be +**serializable**: built-in DataFusion nodes round-trip through the default codec, +but a custom `ExecutionPlan` node needs a `PhysicalExtensionCodec` registered +with the driver so it can be encoded on the driver and decoded on the executor. + +## Connector options + +| Option | Required | Meaning | +| --- | --- | --- | +| `driver` | yes | Path to (or manifest name of) the native ADBC driver library. | +| `entrypoint` | depends on driver | Name of the driver's C init symbol (e.g. `AdbcDatafusionExampleInit`). | +| `table` | yes | Table the provider exposes; its schema is probed on the driver. | +| `target_partitions` | no | Parallelism hint. Defaults to `SparkContext.defaultParallelism` (total executor cores). The connector issues `SET datafusion.execution.target_partitions = N` on the planning session, so it influences the partition count when the driver builds the physical plan; that plan (with its partitioning fixed) is what gets serialized into the descriptors. Pass `k × cores` to raise parallelism. | +| `manifest.path` | no | Extra search path for ADBC driver manifests when `driver` is a name rather than an absolute path. | +| *(anything else)* | no | Forwarded verbatim as native ADBC database options, so provider-specific knobs pass straight through. | + +Two things are intentionally **not** connector options: + +- **`maxPartitionBytes`** — byte-sizing lives in the provider's `scan()`; the + connector does not know the data size and cannot derive it. Pass a sizing knob + as a provider-specific option if your provider supports one. 
+- **Reported partitioning / storage-partitioned joins + (`SupportsReportPartitioning`)** — not available over ADBC today: the + descriptor carries no distribution metadata, and DataFusion's hash + partitioning does not match Spark's key-grouped / bucketed model. + +## Summary + +- The connector maps the provider's `scan()` partition count **1:1** onto Spark + tasks. +- DataFusion sizes scans to **cores** (parallelism-bound); Spark sizes them to + **bytes** (one task per ~128 MB). To feed Spark well, size your provider's + partitions by bytes — typically 128–256 MB each, floored at executor cores and + capped in the hundreds because every descriptor carries the full serialized + plan and each task deserializes it. +- Planning happens once on the driver; the physical plan is serialized into each + descriptor and executors run it as-is, so a custom `ExecutionPlan` node just + needs a registered `PhysicalExtensionCodec` to be serializable. diff --git a/docs/source/user-guide/index.md b/docs/source/user-guide/index.md index a8d42be..68351b0 100644 --- a/docs/source/user-guide/index.md +++ b/docs/source/user-guide/index.md @@ -26,7 +26,8 @@ DataFrame queries execute in native Rust; results return to the JVM as Data Interface. This guide covers installation, the `SessionContext` and `DataFrame` APIs, -and Parquet ingestion. +Parquet ingestion, table providers, and using DataFusion as a Spark data +source over ADBC. ```{toctree} :maxdepth: 1 @@ -39,6 +40,7 @@ parquet proto-plans scalar-udf table-provider +adbc-spark-connector api-reference ``` diff --git a/examples/adbc-datafusion-driver/Cargo.toml b/examples/adbc-datafusion-driver/Cargo.toml new file mode 100644 index 0000000..dc213c9 --- /dev/null +++ b/examples/adbc-datafusion-driver/Cargo.toml @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "adbc-datafusion-example-driver" +version = "0.1.0" +edition = "2021" +publish = false +description = "End-to-end example: an ADBC driver that exposes a custom DataFusion TableProvider" + +# Standalone crate that stands in for a real, separate driver repo. The empty +# [workspace] table makes it its own workspace root so it is NOT pulled into the +# parent datafusion-java workspace (whose members are listed explicitly). +[workspace] + +[dependencies] +# The ready-made ADBC-over-DataFusion driver. We depend on it as a LIBRARY and +# only customize the SessionContext via its new_with_context_init hook -- no +# fork, no datafusion-ffi module loading. Pulled from its own repo (git) since +# that is where it lives. +adbc-driver-datafusion = { git = "https://github.com/adbc-drivers/datafusion" } +adbc_core = "0.23" +adbc_ffi = "0.23" +datafusion = "53" +async-trait = "0.1" +# For building Float16 example data (arrow's half-precision element type). +half = "2" + +# Build both an rlib (so the provider/driver can be unit-tested or embedded) and +# a cdylib (the artifact the arrow-adbc Java driver manager loads at runtime). 
+[lib] +crate-type = ["cdylib", "lib"] + +# Temporary: redirect adbc-driver-datafusion to the PR branch that implements +# execute_partitions / read_partition plus the shared (database-scoped) plan +# cache (adbc-drivers/datafusion#32), so we can run a full multi-partition +# system test before it merges. The crate is not on crates.io, so this patches +# the git source (not [patch.crates-io]). Pinned to a rev (not the branch) so +# builds are reproducible and the plan-cache assertion in AdbcSourceTest is +# stable. Remove once the PR is merged and the canonical dependency carries the +# change. +[patch."https://github.com/adbc-drivers/datafusion"] +adbc-driver-datafusion = { git = "https://github.com/timsaucer/adbc-driver-datafusion", rev = "9b51df2b13d20d81c563ef3ee818db1c02082fac" } diff --git a/examples/adbc-datafusion-driver/PARTITIONS.md b/examples/adbc-datafusion-driver/PARTITIONS.md new file mode 100644 index 0000000..1a2fe47 --- /dev/null +++ b/examples/adbc-datafusion-driver/PARTITIONS.md @@ -0,0 +1,208 @@ + + +# Multi-partition execution: upstream driver patch + +The Spark connector in this repo can spread a scan across executors using ADBC's +partitioned-execution API (`AdbcStatement.executePartitioned()` → +`AdbcConnection.readPartition(descriptor)`). The Java JNI driver manager binds +both calls, but the **DataFusion driver does not implement them yet**: + +```rust +// adbc-driver-datafusion/src/lib.rs (today) +fn execute_partitions(&mut self) -> Result { + Err(ErrorHelper::not_implemented().message("execute_partitions").to_adbc()) +} +fn read_partition(&self, _partition: impl AsRef<[u8]>) -> Result<...> { + Err(ErrorHelper::not_implemented().message("read_partition").to_adbc()) +} +``` + +So the connector calls `executePartitioned()`, catches `NOT_IMPLEMENTED`, and +falls back to a single `executeQuery()` partition. To get real parallelism, the +driver must produce one ADBC partition per DataFusion output partition. 
This is +an upstream change to [`adbc-driver-datafusion`](https://github.com/adbc-drivers/datafusion); +the patch below is PR-ready against the structures in its `src/lib.rs`. Until it +lands, point this example's git dependency at a branch carrying it. + +## Descriptor: self-contained, by construction + +A partition descriptor is shipped to another process and handed back via +`read_partition` with no live handle. It must therefore reconstruct the work on +its own. We encode three things: + +``` +[u32 LE target_partitions][u32 LE partition_index][u8 kind][query bytes...] +kind 0 = Substrait plan (prost-encoded substrait.proto.Plan) +kind 1 = SQL text (utf-8) +``` + +The `query bytes` rebuild the logical plan in the executor's `SessionContext` +(which has the same providers registered via `ContextInit`), and +`partition_index` selects the slice. + +## The correctness subtlety: deterministic partitioning + +`read_partition` re-plans from the query, so partition `i` must mean the **same** +slice it meant when `execute_partitions` counted the partitions. DataFusion +planning is deterministic given the same plan **and the same +`target_partitions`** — but that defaults to the machine's CPU count, which +differs between the driver host and an executor. Left unpinned, the executor +could re-plan into a different partition count and read the wrong (or an +out-of-range) slice. + +Fix: capture `target_partitions` at `execute_partitions` time, encode it in every +descriptor, and re-plan in `read_partition` with that value pinned. Both sides +then build the identical physical plan. 
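The descriptor layout described above is simple enough to exercise outside the driver. The following Python sketch packs and unpacks the same 9-byte header with the standard `struct` module; the function and constant names are assumptions chosen for illustration and it is not the driver's implementation.

```python
import struct

# Header: u32 LE target_partitions, u32 LE partition_index, u8 kind.
HEADER = struct.Struct("<IIB")
KIND_SUBSTRAIT = 0  # query bytes are an encoded Substrait plan
KIND_SQL = 1        # query bytes are UTF-8 SQL text


def encode_descriptor(target_partitions: int, partition_index: int,
                      kind: int, query: bytes) -> bytes:
    """Prefix the query bytes with the fixed-size header."""
    return HEADER.pack(target_partitions, partition_index, kind) + query


def decode_descriptor(blob: bytes):
    """Split a descriptor back into (target_partitions, index, kind, query)."""
    if len(blob) < HEADER.size:
        raise ValueError("short partition descriptor")
    target_partitions, partition_index, kind = HEADER.unpack_from(blob)
    if kind not in (KIND_SUBSTRAIT, KIND_SQL):
        raise ValueError(f"unknown descriptor kind {kind}")
    return target_partitions, partition_index, kind, blob[HEADER.size:]


# One descriptor per partition; each carries the same pinned target_partitions.
sql = "SELECT * FROM example".encode("utf-8")
descriptors = [encode_descriptor(8, i, KIND_SQL, sql) for i in range(3)]

for blob in descriptors:
    print(decode_descriptor(blob))
```

Every descriptor repeats the query and the pinned `target_partitions`; only the partition index differs, which is what lets an executor rebuild the same plan with no live handle.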
+ +## Patch + +```rust +// --- new: descriptor codec ------------------------------------------------- +use prost::Message; // already a dependency + +enum DescQuery { Substrait(Vec<u8>), Sql(String) } + +fn encode_query(q: &QueryState) -> adbc_core::error::Result<(u8, Vec<u8>)> { + match q { + QueryState::Substrait(plan) => Ok((0, plan.encode_to_vec())), + QueryState::Sql(sql) => Ok((1, sql.clone().into_bytes())), + QueryState::Prepared(_) => Err(ErrorHelper::not_implemented() + .message("partitioned execution of prepared statements") + .to_adbc()), + } +} + +fn encode_descriptor(target_partitions: u32, index: u32, kind: u8, query: &[u8]) -> Vec<u8> { + let mut out = Vec::with_capacity(9 + query.len()); + out.extend_from_slice(&target_partitions.to_le_bytes()); + out.extend_from_slice(&index.to_le_bytes()); + out.push(kind); + out.extend_from_slice(query); + out +} + +fn decode_descriptor(bytes: &[u8]) -> adbc_core::error::Result<(usize, u32, DescQuery)> { + if bytes.len() < 9 { + return Err(ErrorHelper::invalid_argument().message("short partition descriptor").to_adbc()); + } + let target_partitions = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize; + let index = u32::from_le_bytes(bytes[4..8].try_into().unwrap()); + let query = match bytes[8] { + 0 => DescQuery::Substrait(bytes[9..].to_vec()), + 1 => DescQuery::Sql(String::from_utf8(bytes[9..].to_vec()) + .map_err(|e| ErrorHelper::invalid_argument().message(e.to_string()).to_adbc())?), + other => return Err(ErrorHelper::invalid_argument() + .format(format_args!("unknown descriptor kind {other}")).to_adbc()), + }; + Ok((target_partitions, index, query)) +} + +// --- DataFusionReader: construct from a single-partition stream ------------ +impl DataFusionReader { + pub fn from_stream( + runtime: Arc<tokio::runtime::Runtime>, + stream: datafusion::execution::SendableRecordBatchStream, + schema: SchemaRef, + ) -> Self { + Self { runtime, stream, schema } + } +} + +// --- DataFusionStatement::execute_partitions ------------------------------- 
+fn execute_partitions(&mut self) -> adbc_core::error::Result<adbc_core::PartitionedResult> { + let query = self.query.as_ref().ok_or_else(|| { + ErrorHelper::invalid_state().message("no query or Substrait plan has been set").to_adbc() + })?; + self.runtime.block_on(async { + let df = query.execute(&self.ctx).await?; // registers object store, plans logically + let schema = df.schema().as_arrow().clone(); + let physical = df.create_physical_plan().await.map_err(ErrorHelper::from_datafusion)?; + let n = physical.output_partitioning().partition_count() as u32; + let target_partitions = self.ctx.copied_config().target_partitions() as u32; + let (kind, query_bytes) = encode_query(query)?; + let partitions = (0..n) + .map(|i| encode_descriptor(target_partitions, i, kind, &query_bytes)) + .collect::<Vec<_>>(); + Ok(adbc_core::PartitionedResult { partitions, schema, rows_affected: -1 }) + }) +} + +// --- DataFusionConnection::read_partition ---------------------------------- +fn read_partition( + &self, + partition: impl AsRef<[u8]>, +) -> adbc_core::error::Result<Box<dyn RecordBatchReader + Send>> { + let (target_partitions, index, query) = decode_descriptor(partition.as_ref())?; + self.runtime.block_on(async { + // Pin target_partitions so the physical plan matches execute_partitions'. + let state = datafusion::execution::session_state::SessionStateBuilder::new_from_existing( + self.ctx.state(), + ) + .with_config(self.ctx.copied_config().with_target_partitions(target_partitions)) + .build(); + + let plan = match query { + DescQuery::Substrait(bytes) => { + let proto = Plan::decode(bytes.as_slice()) + .map_err(|e| ErrorHelper::invalid_argument().message(e.to_string()).to_adbc())?; + from_substrait_plan(&state, &proto).await.map_err(ErrorHelper::from_datafusion)? + } + DescQuery::Sql(sql) => { + state.create_logical_plan(&sql).await.map_err(ErrorHelper::from_datafusion)? 
+ } + }; + register_object_store_for_plan(&self.ctx, &plan).await.map_err(ErrorHelper::from_datafusion)?; + let schema = plan.schema().as_arrow().clone(); + let physical = state.create_physical_plan(&plan).await.map_err(ErrorHelper::from_datafusion)?; + let stream = physical + .execute(index as usize, state.task_ctx()) + .map_err(ErrorHelper::from_datafusion)?; + Ok(Box::new(DataFusionReader::from_stream(self.runtime.clone(), stream, schema.into())) + as Box<dyn RecordBatchReader + Send>) + }) +} +``` + +## How the connector consumes it (already implemented here) + +`AdbcScanImpl.planInputPartitions` (driver side): + +1. build the pushed Substrait plan, `setSubstraitPlan`, call `executePartitioned()`; +2. on success, emit one `AdbcInputPartition` per `PartitionDescriptor` + (`descriptor == true`, payload = the descriptor bytes); +3. on `NOT_IMPLEMENTED` / `NOT_FOUND`, emit a single plan-carrying partition + (`descriptor == false`). + +`AdbcColumnarPartitionReader` (executor side): for a descriptor partition it calls +`connection.readPartition(payload)`; otherwise `setSubstraitPlan` + `executeQuery`. +Both yield an `ArrowReader` wrapped zero-copy into `ArrowColumnVector`s. + +So once this patch lands upstream, the connector lights up multi-partition with +no further change on the Java side. + +## Caveats / follow-ups + +- `Prepared` statements aren't covered (no portable plan serialization without + `datafusion-proto`); they return `NOT_IMPLEMENTED`, and the connector falls + back to single-partition. The connector only sends Substrait, so this is fine. +- `register_object_store_for_plan` is applied in `read_partition` too, so + object-store-backed providers work on the executor. +- Verify `output_partitioning().partition_count()` against the target source; for + some providers a `RepartitionExec` may be needed to get N > 1. 
diff --git a/examples/adbc-datafusion-driver/README.md b/examples/adbc-datafusion-driver/README.md new file mode 100644 index 0000000..1712b99 --- /dev/null +++ b/examples/adbc-datafusion-driver/README.md @@ -0,0 +1,113 @@ + + +# Example: an ADBC driver for a custom DataFusion provider + +End-to-end demonstration of the ADBC route ("Part A"): expose a custom +DataFusion `TableProvider` as a standard ADBC driver, then read it from the +`adbc-datafusion` Spark connector in this repo. + +This crate stands in for a **separate driver repo**. It is its own Cargo +workspace (note the empty `[workspace]` in `Cargo.toml`) and is *not* part of +the parent `datafusion-java` workspace. + +## What it does + +- `src/provider.rs` -- `ExampleTableProvider`, an ordinary `TableProvider` + exposing one fixed table `example(id BIGINT, name STRING)`. Replace this with + your real provider. +- `src/lib.rs` -- `ExampleDriver`, a thin wrapper over + [`adbc-driver-datafusion`](https://github.com/adbc-drivers/datafusion)'s + `DataFusionDriver`. Its `Default` builds the inner driver with a + `new_with_context_init` hook that registers the provider into every session. + `adbc_ffi::export_driver!` emits the C `AdbcDatafusionExampleInit` entrypoint. + +No fork of the driver, no `datafusion-ffi` module loading -- the provider is +compiled in. + +## Build + +```bash +cargo build --release +# produces target/release/libadbc_datafusion_example_driver.{so,dylib,dll} +``` + +## Use from the Spark connector + +The `adbc-datafusion` Spark DataSource in this repo loads the cdylib through the +arrow-adbc Java driver manager. 
Point the `driver` option at the built library +and the `table` option at the provider's table: + +```scala +val df = spark.read + .format("adbc-datafusion") + .option("driver", "/abs/path/to/libadbc_datafusion_example_driver.so") + .option("table", "example") + .load() + +df.show() +// +---+-----+ +// | id| name| +// +---+-----+ +// | 1|alice| +// | 2| bob| +// | 3|carol| +// +---+-----+ +``` + +`option("target_partitions", N)` tunes scan parallelism: the connector issues +`SET datafusion.execution.target_partitions = N` on the planning session, and the +driver pins it into each partition descriptor so executors re-plan identically. +It defaults to the cluster parallelism (`SparkContext.defaultParallelism`). +Repartition-aware providers (e.g. file scans) use it to choose the partition +count; fixed-partition providers (like this in-memory example) keep their +intrinsic count. + +Any other `option(...)` keys (besides `driver`, `table`, `target_partitions`) are +forwarded verbatim as ADBC database options, so a real provider can be configured +through them. + +## Use from any ADBC client + +The same cdylib works with any ADBC driver manager, e.g. Python: + +```python +import adbc_driver_manager.dbapi as dbapi + +with dbapi.connect( + driver="/abs/path/to/libadbc_datafusion_example_driver.so", + entrypoint="AdbcDatafusionExampleInit", +) as conn, conn.cursor() as cur: + cur.execute("SELECT * FROM example") + print(cur.fetch_arrow_table()) +``` + +## Notes + +- `datafusion` / `arrow` versions are pinned to match `adbc-driver-datafusion` + (datafusion 53, arrow 58) so the shared provider type is ABI-compatible across + the driver and any in-repo Rust that registers the same provider. +- Pushdown: ADBC clients send SQL or a Substrait plan; DataFusion's optimizer + applies projection/filter/limit against the provider. The Spark connector + pushes a Substrait plan with projection/filter/limit folded in. 
+- Multi-partition: with the `execute_partitions` / `read_partition` driver + support (adbc-drivers/datafusion#32) and arrow-adbc >= 0.24's JNI bridge, the + Spark connector reads one partition per provider partition; against older + arrow-adbc it degrades to a single partition. See + [PARTITIONS.md](PARTITIONS.md) for the driver-side code. diff --git a/examples/adbc-datafusion-driver/pyspark_e2e.py b/examples/adbc-datafusion-driver/pyspark_e2e.py new file mode 100644 index 0000000..279f909 --- /dev/null +++ b/examples/adbc-datafusion-driver/pyspark_e2e.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""End-to-end PySpark test of the adbc-datafusion connector + example driver. + +PySpark drives the same JVM Spark, so the connector is used exactly as from +Scala/Java: spark.read.format("adbc-datafusion"). No Python-side connector code. + +Prereqs: + - arrow-adbc `main` (0.24.0-SNAPSHOT) built into ~/.m2 (Substrait + partitions), + - the connector jar packaged (mvn -pl spark -Padbc-snapshot package), + - its runtime deps copied to a dir (mvn dependency:copy-dependencies), + - arrow-c-data (matching Spark's bundled Arrow, e.g. 
18.1.0) added to that dir + -- vanilla Spark bundles arrow-vector/arrow-memory but NOT arrow-c-data, + - the example cdylib built (cargo build --release). + +Configure via env vars (defaults assume this repo's build layout): + CONNECTOR_JAR path to datafusion-spark-*.jar + ADBC_JARS_DIR dir of runtime dep jars (adbc-*, substrait core, datafusion-java, ...) + ADBC_DRIVER_LIB path to libadbc_datafusion_example_driver.{dylib,so} +""" + +import datetime +import glob +import os +import sys +from decimal import Decimal + +from pyspark.sql import SparkSession + +# Field metadata flag the connector sets on columns it casts source-side to a Spark-native +# layout (SchemaConverter.CAST_METADATA_KEY). +CAST_META_KEY = "org.apache.datafusion.spark.adbc.cast" + +REPO = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +LIBEXT = "dylib" if sys.platform == "darwin" else "so" + +connector_jar = os.environ.get( + "CONNECTOR_JAR", + glob.glob(os.path.join(REPO, "spark/target/datafusion-spark-*.jar"))[0], +) +jars_dir = os.environ.get("ADBC_JARS_DIR", "/tmp/adbc-libs") +driver_lib = os.environ.get( + "ADBC_DRIVER_LIB", + os.path.join(REPO, f"rust-target/release/libadbc_datafusion_example_driver.{LIBEXT}"), +) + +jars = [connector_jar] + glob.glob(os.path.join(jars_dir, "*.jar")) +assert os.path.exists(driver_lib), f"driver cdylib not found: {driver_lib}" + +# Java 17 opens Spark/Arrow need (Spark 4.0 sets most; be explicit for the data path). 
+opens = " ".join( + f"--add-opens=java.base/{p}=ALL-UNNAMED" + for p in ("java.nio", "sun.nio.ch", "java.lang", "java.util") +) + +spark = ( + SparkSession.builder.master("local[2]") + .appName("adbc-datafusion-pyspark-e2e") + .config("spark.jars", ",".join(jars)) + .config("spark.driver.extraJavaOptions", opens) + .config("spark.executor.extraJavaOptions", opens) + .getOrCreate() +) + +def read(table): + return ( + spark.read.format("adbc-datafusion") + .option("driver", driver_lib) + .option("entrypoint", "AdbcDatafusionExampleInit") + .option("table", table) + .load() + ) + + +try: + # --- `example`: two Spark-native columns, three partitions ------------------------------- + df = read("example") + + print("=== example schema ===") + df.printSchema() + + num_partitions = df.rdd.getNumPartitions() + print("numPartitions:", num_partitions) + + ids = sorted(r["id"] for r in df.collect()) + names = sorted(r["name"] for r in df.select("name").collect()) + filtered = sorted(r["id"] for r in df.filter("id > 1").collect()) + print("full scan ids:", ids) + print("projection names:", names) + print("filter id>1:", filtered) + + assert ids == [1, 2, 3], ids + assert names == ["alice", "bob", "carol"], names + assert filtered == [2, 3], filtered + assert num_partitions >= 2, f"expected multi-partition, got {num_partitions}" + + # --- `types`: the schema-conversion + source-side arrow_cast coverage -------------------- + dft = read("types") + print("=== types schema ===") + dft.printSchema() + + # cast columns are flagged (so filter pushdown stays off them); pass-through ones are not. 
+ cast_cols = {f.name for f in dft.schema.fields if f.metadata.get(CAST_META_KEY)} + print("cast columns:", sorted(cast_cols)) + assert cast_cols == { + "channel", + "big", + "event_time", + "score", + "tags", + "vec", + "labels", + "digest", + "day", + }, cast_cols + + # count() prunes the projection to empty; the connector must not emit a bare SELECT * (which + # would return the raw, uncast schema and fail the reader on the ns timestamp / unsigned ids). + types_count = dft.count() + print("types count:", types_count) + assert types_count == 3, types_count + + # value correctness: collecting whole rows forces the vectorized reader to decode each + # ArrowColumnVector, so a bad cast (wrong layout, unit relabel instead of rescale, unsigned + # overflow) fails here. + rows = {r["id"]: r for r in dft.collect()} + r1, r2, r3 = rows[1], rows[2], rows[3] + + # unsigned UInt16 -> Integer, widened past i16::MAX + assert [r1["channel"], r2["channel"], r3["channel"]] == [100, 40000, 65535] + + # unsigned UInt64 -> Decimal(20,0): lossless for values past i64::MAX (a Long would overflow) + assert r1["big"] == Decimal("18446744073709551615") + assert r2["big"] == Decimal(0) + assert r3["big"] == Decimal("9223372036854775808") + + # nanosecond Timestamp -> microsecond TimestampNTZ, rescaled (not relabeled -> would be ~1970) + assert r1["event_time"] == datetime.datetime(2020, 9, 13, 12, 26, 40) + assert r2["event_time"] == datetime.datetime(2021, 1, 7, 6, 13, 20) + assert r3["event_time"] == datetime.datetime(2021, 5, 3, 0, 0, 0) + + # Float16 -> Float + assert [r1["score"], r2["score"], r3["score"]] == [1.5, 2.5, 3.5] + + # Binary passes through + assert bytes(r1["payload"]) == b"\x01\x02" + assert bytes(r2["payload"]) == b"" + assert bytes(r3["payload"]) == b"\xff\xfe" + + # nested List -> Array (recursive cast) + assert r1["tags"] == [1, 2] + assert r2["tags"] == [] + assert r3["tags"] == [3] + + # FixedSizeList -> Array (fixed->variable + element widening) + assert 
r1["vec"] == [10, 20] + assert r2["vec"] == [30, 40] + assert r3["vec"] == [50, 60] + + # LargeList -> Array + assert r1["labels"] == ["a", "b"] + assert r2["labels"] == [] + assert r3["labels"] == ["c"] + + # FixedSizeBinary -> Binary + assert bytes(r1["digest"]) == b"\x01\x02\x03\x04" + assert bytes(r3["digest"]) == b"\xff\xfe\xfd\xfc" + + # Date64 -> Date32 (day-aligned) + assert r1["day"] == datetime.date(2020, 9, 13) + assert r2["day"] == datetime.date(2021, 1, 7) + assert r3["day"] == datetime.date(2021, 5, 3) + + # nested List<Struct<key, val>> passes through + assert [(x["key"], x["val"]) for x in r1["attrs"]] == [("a", "1")] + assert r2["attrs"] == [] + assert [(x["key"], x["val"]) for x in r3["attrs"]] == [("b", "2"), ("c", "3")] + + # --- filter on a cast column: excluded from pushdown, evaluated by Spark on the cast value - + channel_gt = sorted(r["id"] for r in dft.filter(dft.channel > 100).collect()) + print("filter channel>100 ids:", channel_gt) + assert channel_gt == [2, 3], channel_gt + + print("\nPYSPARK E2E OK (example: multi-partition + pushdown; types: casts + nested + filter)") +finally: + spark.stop() diff --git a/examples/adbc-datafusion-driver/src/lib.rs b/examples/adbc-datafusion-driver/src/lib.rs new file mode 100644 index 0000000..e2c89f7 --- /dev/null +++ b/examples/adbc-datafusion-driver/src/lib.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An end-to-end ADBC driver for a specific, compiled-in DataFusion provider. +//! +//! This is "Part A" of the ADBC route: rather than fork [`adbc-driver-datafusion`] +//! or load providers over `datafusion-ffi`, we depend on it as a library and use +//! its [`DataFusionDriver::new_with_context_init`] hook to register our own +//! [`ExampleTableProvider`] into every database's `SessionContext`. The provider +//! is statically linked; ADBC clients (the arrow-adbc Java driver manager, the +//! Spark connector, Python `adbc_driver_manager`, ...) then query it by SQL or +//! Substrait. +//! +//! [`export_driver!`] emits the C `AdbcDriverInit` entrypoint a driver manager +//! dlopen's. Build the cdylib (`cargo build --release`) and point an ADBC client +//! at the resulting shared library. + +mod provider; + +use std::sync::Arc; + +use adbc_core::error::Result; +use adbc_core::options::{OptionDatabase, OptionValue}; +use adbc_core::Driver; +use adbc_driver_datafusion::{ContextInit, DataFusionDatabase, DataFusionDriver}; +use datafusion::prelude::SessionContext; + +pub use provider::{ExampleTableProvider, TypesTableProvider}; + +/// An ADBC driver that registers [`ExampleTableProvider`] into each session. +/// +/// A thin wrapper over [`DataFusionDriver`]: it exists only so that `Default` +/// (which [`export_driver!`] calls) constructs the inner driver with our +/// `ContextInit` hook. All `Driver` behavior is delegated. 
+pub struct ExampleDriver(DataFusionDriver); + +impl Default for ExampleDriver { + fn default() -> Self { + let init: ContextInit = Arc::new(|ctx: &mut SessionContext, _opts| { + // Register our provider. A real driver would read connection + // options out of `_opts` (DatabaseOpts) to decide what to expose. + ctx.register_table( + ExampleTableProvider::TABLE_NAME, + Arc::new(ExampleTableProvider::new()), + )?; + // A second table whose schema spans the Arrow types the Spark connector must cast + // or map (unsigned, ns timestamp, binary, nested list/struct); the E2E test scans + // it to check schema conversion and source-side arrow_cast with known values. + ctx.register_table( + TypesTableProvider::TABLE_NAME, + Arc::new(TypesTableProvider::new()), + )?; + Ok(()) + }); + ExampleDriver(DataFusionDriver::new_with_context_init(None, init)) + } +} + +impl Driver for ExampleDriver { + type DatabaseType = DataFusionDatabase; + + fn new_database(&mut self) -> Result<Self::DatabaseType> { + self.0.new_database() + } + + fn new_database_with_opts( + &mut self, + opts: impl IntoIterator<Item = (OptionDatabase, OptionValue)>, + ) -> Result<Self::DatabaseType> { + self.0.new_database_with_opts(opts) + } +} + +// Export the C ADBC entrypoint. The arrow-adbc driver manager resolves this +// symbol after dlopen'ing the cdylib. +adbc_ffi::export_driver!(AdbcDatafusionExampleInit, ExampleDriver); diff --git a/examples/adbc-datafusion-driver/src/provider.rs b/examples/adbc-datafusion-driver/src/provider.rs new file mode 100644 index 0000000..2c18fe7 --- /dev/null +++ b/examples/adbc-datafusion-driver/src/provider.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The "custom" `TableProvider`s this example exposes. +//! +//! In a real deployment this is *your* provider crate -- reading your store, +//! your format, your catalog. Here they are fixed tables so the example is +//! self-contained; the only thing that matters for the integration is that they +//! are ordinary [`datafusion::catalog::TableProvider`]s. Execution is delegated +//! to in-memory [`MemTable`]s; swap the `scan` for your real source and the rest +//! of the pipeline (ADBC driver, Spark connector) is unchanged. +//! +//! Two tables: +//! +//! - [`ExampleTableProvider`] (`example`) -- two Spark-native columns across +//! three partitions. Keeps the partitioned-execution system test simple and +//! lets the connector use its preferred Substrait wire. +//! - [`TypesTableProvider`] (`types`) -- a schema spanning both categories the +//! Spark connector's `SchemaConverter` must handle, so the end-to-end test +//! covers them with known values and no external server: +//! - *directly representable* (pass through to `ArrowColumnVector`): +//! `Int64`, `Utf8`, `Binary`, `List<Struct<key, val>>`; +//! - *cast-required* (source-side `arrow_cast` to a Spark-native layout): +//! unsigned `UInt16`/`UInt64`, nanosecond `Timestamp`, `Float16`, nested +//! `List<UInt16>`. +//! Each row is built as its own self-contained single-row batch (rather than +//! slicing one batch), so every partition's arrays start at offset 0 -- a +//! sliced array carries a non-zero offset into shared buffers, which the Arrow +//!
C-data export at the ADBC boundary does not rebase, corrupting +//! variable-width (binary/list) columns on the Spark side. + +use std::any::Any; +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::arrow::array::{ + ArrayRef, BinaryArray, Date64Array, FixedSizeBinaryBuilder, FixedSizeListBuilder, Float16Array, + Int64Array, LargeListBuilder, ListBuilder, StringArray, StringBuilder, StructBuilder, + TimestampNanosecondArray, UInt16Array, UInt16Builder, UInt64Array, +}; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::{Session, TableProvider}; +use datafusion::datasource::MemTable; +use datafusion::error::Result; +use datafusion::logical_expr::TableType; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::Expr; +use half::f16; + +/// A custom provider exposing a single fixed table named [`Self::TABLE_NAME`]. +#[derive(Debug)] +pub struct ExampleTableProvider { + inner: MemTable, +} + +impl ExampleTableProvider { + /// The name the provider is registered under; the ADBC client scans this. + pub const TABLE_NAME: &'static str = "example"; + + pub fn new() -> Self { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + // Three partitions (one row each) so the table has a non-trivial output + // partitioning -- this is what makes ADBC execute_partitions / + // read_partition return more than one partition, exercising distributed + // reads. 
+ let rows = [(1_i64, "alice"), (2, "bob"), (3, "carol")]; + let partitions: Vec<Vec<RecordBatch>> = rows + .iter() + .map(|(id, name)| { + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int64Array::from(vec![*id])), + Arc::new(StringArray::from(vec![*name])), + ], + ) + .expect("example record batch"); + vec![batch] + }) + .collect(); + let inner = MemTable::try_new(schema, partitions).expect("example in-memory table"); + Self { inner } + } +} + +impl Default for ExampleTableProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TableProvider for ExampleTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.inner.schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec<usize>>, + filters: &[Expr], + limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + // A real provider builds its own ExecutionPlan here, honoring the + // pushed projection / filters / limit. We delegate to MemTable. + self.inner.scan(state, projection, filters, limit).await + } +} + +/// One row of the [`TypesTableProvider`] table, as plain Rust values. +struct Row { + id: i64, + name: &'static str, + payload: &'static [u8], + attrs: &'static [(&'static str, &'static str)], + channel: u16, + big: u64, + /// Nanoseconds since the Unix epoch (no timezone). + event_time: i64, + score: f32, + tags: &'static [u16], + /// A fixed-size list of two unsigned ints -- Spark cannot read FixedSizeList and must have it + /// cast to a variable List (with the element widened). Always length 2. + vec: [u16; 2], + /// A LargeList of strings -- maps to ArrayType but has no accessor, so it casts to List. + labels: &'static [&'static str], + /// FixedSizeBinary(4) -- maps to BinaryType but has no accessor, so it casts to Binary. + digest: [u8; 4], + /// Date64 (ms since epoch, day-aligned) -- no accessor, casts to Date32.
+ day_ms: i64, +} + +// Values are chosen to exercise the widening edges: channel spans past i16::MAX, big includes +// u64 values past i64::MAX (must stay lossless via Decimal(20,0)), and event_time is nanoseconds +// that must rescale to microseconds rather than be relabeled. +const TYPE_ROWS: [Row; 3] = [ + Row { + id: 1, + name: "alice", + payload: &[0x01, 0x02], + attrs: &[("a", "1")], + channel: 100, + big: u64::MAX, + event_time: 1_600_000_000_000_000_000, // 2020-09-13 + score: 1.5, + tags: &[1, 2], + vec: [10, 20], + labels: &["a", "b"], + digest: [0x01, 0x02, 0x03, 0x04], + day_ms: 1_599_955_200_000, // 2020-09-13 (18518 days) + }, + Row { + id: 2, + name: "bob", + payload: &[], + attrs: &[], + channel: 40_000, + big: 0, + event_time: 1_610_000_000_000_000_000, // 2021-01-07 + score: 2.5, + tags: &[], + vec: [30, 40], + labels: &[], + digest: [0x00, 0x00, 0x00, 0x00], + day_ms: 1_609_977_600_000, // 2021-01-07 (18634 days) + }, + Row { + id: 3, + name: "carol", + payload: &[0xff, 0xfe], + attrs: &[("b", "2"), ("c", "3")], + channel: 65_535, + big: 9_223_372_036_854_775_808, // 2^63, past i64::MAX + event_time: 1_620_000_000_000_000_000, // 2021-05-03 + score: 3.5, + tags: &[3], + vec: [50, 60], + labels: &["c"], + digest: [0xff, 0xfe, 0xfd, 0xfc], + day_ms: 1_620_000_000_000, // 2021-05-03 (18750 days) + }, +]; + +/// A provider whose schema exercises every Arrow type the Spark connector's converter handles. +#[derive(Debug)] +pub struct TypesTableProvider { + inner: MemTable, +} + +impl TypesTableProvider { + pub const TABLE_NAME: &'static str = "types"; + + pub fn new() -> Self { + // Build the schema once from single-row prototype arrays so the declared types (in + // particular the nested list/struct types) always match the produced data exactly. 
+ let proto = row_columns(&TYPE_ROWS[0]); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + Field::new("payload", proto[2].data_type().clone(), true), + Field::new("attrs", proto[3].data_type().clone(), true), + Field::new("channel", proto[4].data_type().clone(), true), + Field::new("big", proto[5].data_type().clone(), true), + Field::new("event_time", proto[6].data_type().clone(), true), + Field::new("score", proto[7].data_type().clone(), true), + Field::new("tags", proto[8].data_type().clone(), true), + Field::new("vec", proto[9].data_type().clone(), true), + Field::new("labels", proto[10].data_type().clone(), true), + Field::new("digest", proto[11].data_type().clone(), true), + Field::new("day", proto[12].data_type().clone(), true), + ])); + + // One self-contained single-row batch per partition (see the module docs on why we do + // not slice a shared batch). + let partitions: Vec<Vec<RecordBatch>> = TYPE_ROWS + .iter() + .map(|row| { + let batch = RecordBatch::try_new(schema.clone(), row_columns(row)) + .expect("types record batch"); + vec![batch] + }) + .collect(); + let inner = MemTable::try_new(schema, partitions).expect("types in-memory table"); + Self { inner } + } +} + +impl Default for TypesTableProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl TableProvider for TypesTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.inner.schema() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec<usize>>, + filters: &[Expr], + limit: Option<usize>, + ) -> Result<Arc<dyn ExecutionPlan>> { + self.inner.scan(state, projection, filters, limit).await + } +} + +/// Build the single-row column arrays for one [`Row`], in schema order.
+fn row_columns(row: &Row) -> Vec<ArrayRef> { + let mut tags = ListBuilder::new(UInt16Builder::new()); + for t in row.tags { + tags.values().append_value(*t); + } + tags.append(true); + + let mut vec = FixedSizeListBuilder::new(UInt16Builder::new(), row.vec.len() as i32); + for v in row.vec { + vec.values().append_value(v); + } + vec.append(true); + + let mut labels = LargeListBuilder::new(StringBuilder::new()); + for l in row.labels { + labels.values().append_value(l); + } + labels.append(true); + + let mut digest = FixedSizeBinaryBuilder::new(row.digest.len() as i32); + digest.append_value(row.digest).expect("digest bytes"); + + let attr_fields = vec![ + Field::new("key", DataType::Utf8, true), + Field::new("val", DataType::Utf8, true), + ]; + let mut attrs = ListBuilder::new(StructBuilder::from_fields(attr_fields, 0)); + for (key, val) in row.attrs { + let s = attrs.values(); + s.field_builder::<StringBuilder>(0) + .unwrap() + .append_value(key); + s.field_builder::<StringBuilder>(1) + .unwrap() + .append_value(val); + s.append(true); + } + attrs.append(true); + + vec![ + Arc::new(Int64Array::from(vec![row.id])), + Arc::new(StringArray::from(vec![row.name])), + Arc::new(BinaryArray::from_iter_values([row.payload])), + Arc::new(attrs.finish()) as ArrayRef, + Arc::new(UInt16Array::from(vec![row.channel])), + Arc::new(UInt64Array::from(vec![row.big])), + Arc::new(TimestampNanosecondArray::from(vec![row.event_time])), + Arc::new(Float16Array::from(vec![f16::from_f32(row.score)])), + Arc::new(tags.finish()) as ArrayRef, + Arc::new(vec.finish()) as ArrayRef, + Arc::new(labels.finish()) as ArrayRef, + Arc::new(digest.finish()) as ArrayRef, + Arc::new(Date64Array::from(vec![row.day_ms])), + ] +} diff --git a/examples/adbc-datafusion-driver/tests/partitions.rs b/examples/adbc-datafusion-driver/tests/partitions.rs new file mode 100644 index 0000000..c76c2e6 --- /dev/null +++ b/examples/adbc-datafusion-driver/tests/partitions.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +//
or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Full system test of ADBC partitioned execution against the example driver. +//! +//! Exercises the upstream `execute_partitions` / `read_partition` (PR +//! adbc-drivers/datafusion#32, wired in via the [patch] in Cargo.toml) end to +//! end with our custom provider: plan a query, get one descriptor per output +//! partition, then read each partition on its own and assert the union is the +//! whole table -- the same contract the Spark connector relies on. + +use adbc_core::{Connection, Database, Driver, Statement}; +use adbc_datafusion_example_driver::{ExampleDriver, ExampleTableProvider}; +use datafusion::arrow::array::{Array, Int64Array}; + +#[test] +fn execute_partitions_then_read_each_partition() { + let mut driver = ExampleDriver::default(); + let db = driver.new_database().expect("new_database"); + let mut conn = db.new_connection().expect("new_connection"); + + // Plan the scan; the provider has three partitions. 
+ let mut stmt = conn.new_statement().expect("new_statement"); + stmt.set_sql_query(format!( + "SELECT id, name FROM {}", + ExampleTableProvider::TABLE_NAME + )) + .expect("set_sql_query"); + let result = stmt.execute_partitions().expect("execute_partitions"); + + assert!( + result.partitions.len() >= 2, + "expected multiple partitions, got {}", + result.partitions.len() + ); + + // Read each partition independently (as a separate executor would) and + // collect the ids; the union must be exactly the whole table, once each. + let mut ids = Vec::new(); + for descriptor in &result.partitions { + let reader = conn.read_partition(descriptor).expect("read_partition"); + for batch in reader { + let batch = batch.expect("batch"); + let column = batch + .column(0) + .as_any() + .downcast_ref::<Int64Array>() + .expect("id column is Int64"); + for i in 0..column.len() { + ids.push(column.value(i)); + } + } + } + + ids.sort(); + assert_eq!( + ids, + vec![1, 2, 3], + "every row read exactly once across partitions" + ); +} diff --git a/pom.xml b/pom.xml index 7ceec07..a48be6c 100644 --- a/pom.xml +++ b/pom.xml @@ -33,6 +33,7 @@ under the License. core examples + spark diff --git a/spark/pom.xml b/spark/pom.xml new file mode 100644 index 0000000..cf6e3dc --- /dev/null +++ b/spark/pom.xml @@ -0,0 +1,182 @@ + + + + 4.0.0 + + + org.apache.datafusion + datafusion-java-parent + 0.2.0-SNAPSHOT + + + datafusion-spark + DataFusion Spark DataSource + A Spark DataSourceV2 backed by a DataFusion TableProvider, over either the plain-C scan ABI or an ADBC driver.
+ + + 4.0.0 + 2.13 + + 18.1.0 + + + 0.23.0 + + 0.57.0 + + + + + + org.apache.datafusion + datafusion-java + ${project.version} + + + org.apache.arrow + * + + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark.version} + provided + + + + + org.apache.arrow + arrow-c-data + ${spark.arrow.version} + provided + + + + + org.apache.arrow.adbc + adbc-core + ${adbc.version} + + + org.apache.arrow + * + + + + + org.apache.arrow.adbc + adbc-driver-jni + ${adbc.version} + + + org.apache.arrow + * + + + + + + + io.substrait + core + ${substrait.version} + + + + + org.junit.jupiter + junit-jupiter + test + + + + org.mockito + mockito-core + 5.14.2 + test + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.lang.invoke=ALL-UNNAMED + --add-opens=java.base/java.io=ALL-UNNAMED + --add-opens=java.base/java.net=ALL-UNNAMED + --add-opens=java.base/java.nio=ALL-UNNAMED + --add-opens=java.base/java.util=ALL-UNNAMED + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED + --add-opens=java.base/sun.nio.ch=ALL-UNNAMED + --add-opens=java.base/sun.security.action=ALL-UNNAMED + + + + + + + + + + adbc-snapshot + + 0.24.0-SNAPSHOT + + + + diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java new file mode 100644 index 0000000..530c47c --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcColumnarPartitionReader.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * Reads one ADBC scan partition as Spark {@link ColumnarBatch}es, zero-copy. + * + *

<p>On the executor, this reader borrows a cached database from {@link AdbcConnectionPool} (cached per executor + * so the DataFusion cdylib is loaded and its providers registered once per executor, not once per + * task) and opens its own per-task connection off it, then obtains an {@link ArrowReader} one of + * three ways depending on how the scan was planned (see {@link AdbcInputPartition}): + * + *
<ul> + *
  <li>partition descriptor -> {@code AdbcConnection.readPartition(descriptor)} (multi-partition); + *
  <li>Substrait plan -> {@code setSubstraitPlan} + {@code executeQuery} (single-partition + * fallback); + *
  <li>SQL text -> {@code setSqlQuery} + {@code executeQuery}. + *
</ul> + * + *

<p>The imported Arrow vectors are wrapped directly in Spark {@link ArrowColumnVector}s -- no + * per-cell copy -- which requires the executor JVM to have a single arrow-java (the cluster's Spark + * Arrow), shared by the ADBC driver manager and Spark. + * + *

<p>Lifecycle: the Arrow vectors are owned by the {@link ArrowReader}. We do not close the {@link + * ColumnarBatch} (which would double-free the vectors); {@link #close()} closes the per-task ADBC + * handles and releases the {@link AdbcConnectionPool.Lease}. The cached database and its root + * allocator are owned by the pool and outlive this reader. + */ +final class AdbcColumnarPartitionReader implements PartitionReader<ColumnarBatch> { + + private final AdbcConnectionPool.Lease lease; + // Non-null only on the executeQuery (single-partition fallback) path. + private final AdbcStatement statement; + private final AdbcStatement.QueryResult queryResult; + private final ArrowReader reader; + private final VectorSchemaRoot root; + private final ColumnarBatch batch; + + AdbcColumnarPartitionReader(AdbcInputPartition partition) { + AdbcConnectionPool.Lease ls = null; + AdbcStatement stmt = null; + AdbcStatement.QueryResult result = null; + try { + ls = AdbcConnectionPool.acquire(partition.options); + AdbcConnection conn = ls.connection(); + + ArrowReader r; + switch (partition.kind) { + case DESCRIPTOR -> { + // Multi-partition path: read one opaque partition descriptor.
+ r = conn.readPartition(directBuffer(partition.payload)); + } + case SUBSTRAIT -> { + stmt = conn.createStatement(); + stmt.setSubstraitPlan(directBuffer(partition.payload)); + result = stmt.executeQuery(); + r = result.getReader(); + } + case SQL -> { + stmt = conn.createStatement(); + stmt.setSqlQuery(new String(partition.payload, java.nio.charset.StandardCharsets.UTF_8)); + result = stmt.executeQuery(); + r = result.getReader(); + } + default -> throw new IllegalStateException("unknown partition kind " + partition.kind); + } + + this.lease = ls; + this.statement = stmt; + this.queryResult = result; + this.reader = r; + this.root = reader.getVectorSchemaRoot(); + this.batch = new ColumnarBatch(wrap(root)); + } catch (Exception e) { + try { + // Close only task-owned handles plus the lease; the lease leaves the cached + // database/connection alone. + AutoCloseables.close(result, stmt, ls); + } catch (Exception suppressed) { + e.addSuppressed(suppressed); + } + throw new RuntimeException("failed to open ADBC scan partition", e); + } + } + + /** Wrap each Arrow vector of the (reused) root as a Spark column vector, once. */ + private static ColumnVector[] wrap(VectorSchemaRoot root) { + ColumnVector[] columns = new ColumnVector[root.getFieldVectors().size()]; + int i = 0; + for (FieldVector vector : root.getFieldVectors()) { + columns[i++] = new ArrowColumnVector(vector); + } + return columns; + } + + private static ByteBuffer directBuffer(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length); + buffer.put(bytes).flip(); + return buffer; + } + + @Override + public boolean next() throws IOException { + // The root's vectors are reloaded in place each batch; skip empty batches. 
+ while (reader.loadNextBatch()) { + int rows = root.getRowCount(); + if (rows > 0) { + batch.setNumRows(rows); + return true; + } + } + return false; + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + try { + // On the executeQuery path the QueryResult owns the reader, so close it instead + // of the reader; on the readPartition path close the reader directly. The lease + // releases the per-task child allocator and the per-task connection; the cached + // database and root allocator are owned by the pool. + AutoCloseable readerHandle = queryResult != null ? queryResult : reader; + AutoCloseables.close(readerHandle, statement, lease); + } catch (Exception e) { + throw new IOException("failed to close ADBC scan partition", e); + } + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java new file mode 100644 index 0000000..4574cfe --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcConnectionPool.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion.spark; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; + +/** + * Per-executor-JVM cache of the expensive native ADBC objects, so the task slots on one executor do + * not each repeat driver load, provider registration, and session setup against the same cdylib. + * + *

<p>A Spark executor runs many tasks (one per slot) in one long-lived JVM. Each task's {@link + * AdbcColumnarPartitionReader} needs an {@link AdbcConnection}; opening one per task re-runs the + * native session construction and -- the costly part -- the driver's {@code ContextInit} provider + * registration, which the DataFusion ADBC driver runs once per database. This holder keeps + * one {@link CachedDatabase} per {@link Key} (the inputs that determine the native database; see + * {@link AdbcOptions#cacheKey()}) for the life of the JVM, and each task opens its own short-lived + * connection off that shared database. + * + *

Connections are not shared across tasks. The arrow-adbc Rust FFI exporter + * reaches the inner connection on every C call via a {@code &mut} to the shared object with no + * lock, so two task threads calling into one connection concurrently would alias {@code &mut} -- + * undefined behavior, regardless of the driver being internally {@code Sync}. Caching the database + * (where {@code ContextInit} runs) already removes the per-task cost, so the connection stays + * per-task. + * + *

Cached handles outlive every task and are closed only by a JVM shutdown hook (executors are + * long-lived). A per-task {@link Lease} hands out a task-owned connection plus a child allocator + * and, on {@link Lease#close()}, releases only those task-owned handles -- never the cached + * database or its root allocator. + */ +final class AdbcConnectionPool { + + private static final ConcurrentHashMap<Key, CachedDatabase> CACHE = new ConcurrentHashMap<>(); + private static final AtomicBoolean HOOK_INSTALLED = new AtomicBoolean(false); + + // Observability counters: how many native databases were built and how many per-task connections + // were opened over the JVM's lifetime. The point of the pool is that the first stays at + // one-per-executor (provider registration runs once); the second scales with tasks. Read by the + // driver-gated E2E benchmark. + private static final AtomicInteger NATIVE_DATABASES = new AtomicInteger(); + private static final AtomicInteger TASK_CONNECTIONS = new AtomicInteger(); + + // Indirection so unit/concurrency tests can supply cheap doubles in place of native handles. + private static volatile DatabaseBuilder builder = AdbcConnectionPool::buildNative; + + private AdbcConnectionPool() {} + + /** + * Identity of a native ADBC database: the driver path/name plus the database-option map, compared + * order-insensitively (map order does not change the native database). Derived from {@link + * AdbcOptions} on the executor with no new wire fields -- both inputs already ride along inside + * the serializable {@link AdbcOptions}. 
+ */ + static final class Key { + private final String driver; + private final List<Map.Entry<String, String>> sortedOptions; + + Key(String driver, Map<String, String> databaseOptions) { + this.driver = Objects.requireNonNull(driver, "driver"); + List<Map.Entry<String, String>> entries = new ArrayList<>(databaseOptions.size()); + for (Map.Entry<String, String> e : databaseOptions.entrySet()) { + entries.add(new AbstractMap.SimpleImmutableEntry<>(e.getKey(), e.getValue())); + } + entries.sort(Map.Entry.comparingByKey()); + this.sortedOptions = List.copyOf(entries); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Key)) { + return false; + } + Key k = (Key) o; + return driver.equals(k.driver) && sortedOptions.equals(k.sortedOptions); + } + + @Override + public int hashCode() { + return Objects.hash(driver, sortedOptions); + } + + @Override + public String toString() { + return "AdbcConnectionPool.Key[driver=" + driver + ", options=" + sortedOptions + "]"; + } + } + + /** One cached native database plus its root allocator. */ + static final class CachedDatabase implements AutoCloseable { + final BufferAllocator rootAllocator; + final AdbcDatabase database; + final AtomicInteger activeLeases = new AtomicInteger(); + volatile long lastAccessNanos; + + CachedDatabase(BufferAllocator rootAllocator, AdbcDatabase database) { + this.rootAllocator = rootAllocator; + this.database = database; + } + + @Override + public void close() throws Exception { + // Close order: database -> allocator (the allocator must outlive the handles opened against + // it). + AutoCloseables.close(database, rootAllocator); + } + } + + /** Builds the native objects for a key; replaceable in tests. */ + @FunctionalInterface + interface DatabaseBuilder { + CachedDatabase build(Key key, AdbcOptions options) throws AdbcException; + } + + /** + * A task's borrowed handle on a {@link CachedDatabase}. Carries the task-owned connection the + * reader uses and a per-task child allocator. 
{@link #close()} releases only what the task owns; + * the cached database and root allocator are not touched. + */ + static final class Lease implements AutoCloseable { + private final CachedDatabase shared; + private final BufferAllocator taskAllocator; + private final AdbcConnection connection; + + Lease(CachedDatabase shared, BufferAllocator taskAllocator, AdbcConnection connection) { + this.shared = shared; + this.taskAllocator = taskAllocator; + this.connection = connection; + } + + AdbcConnection connection() { + return connection; + } + + BufferAllocator allocator() { + return taskAllocator; + } + + @Override + public void close() throws Exception { + try { + // Close the task-owned connection and allocator. The cached database and root allocator are + // never closed here -- only by the JVM shutdown hook. + AutoCloseables.close(connection, taskAllocator); + } finally { + shared.activeLeases.decrementAndGet(); + } + } + } + + /** Fetch-or-create the cached database for these options and return a per-task lease. */ + static Lease acquire(AdbcOptions options) throws AdbcException { + installShutdownHookOnce(); + Key key = options.cacheKey(); + CachedDatabase shared = getOrCreate(key, options); + shared.activeLeases.incrementAndGet(); + shared.lastAccessNanos = System.nanoTime(); + + BufferAllocator taskAllocator = null; + try { + // A child of the shared root: per-task accounting/isolation; closed when the lease closes. 
+ taskAllocator = shared.rootAllocator.newChildAllocator("adbc-task", 0, Long.MAX_VALUE); + AdbcConnection conn = shared.database.connect(); + TASK_CONNECTIONS.incrementAndGet(); + return new Lease(shared, taskAllocator, conn); + } catch (Exception e) { + shared.activeLeases.decrementAndGet(); + if (taskAllocator != null) { + try { + taskAllocator.close(); + } catch (Exception suppressed) { + e.addSuppressed(suppressed); + } + } + if (e instanceof AdbcException) { + throw (AdbcException) e; + } + throw new RuntimeException("failed to acquire ADBC connection lease", e); + } + } + + private static CachedDatabase getOrCreate(Key key, AdbcOptions options) throws AdbcException { + try { + // computeIfAbsent runs the mapping function at most once per absent key; concurrent callers + // on the same key block until it completes, so exactly one database is built and ContextInit + // runs once. AdbcException is checked and cannot escape the mapping Function, so wrap it and + // unwrap below; on failure nothing is inserted, so the next caller retries cleanly. 
+ return CACHE.computeIfAbsent( + key, + k -> { + try { + return builder.build(k, options); + } catch (AdbcException e) { + throw new UncheckedAdbcException(e); + } + }); + } catch (UncheckedAdbcException e) { + throw e.cause(); + } + } + + private static CachedDatabase buildNative(Key key, AdbcOptions options) throws AdbcException { + BufferAllocator root = new RootAllocator(); + AdbcDatabase db = null; + try { + db = new JniDriver(root).open(options.driverParameters()); + NATIVE_DATABASES.incrementAndGet(); + return new CachedDatabase(root, db); + } catch (Exception e) { + try { + AutoCloseables.close(db, root); + } catch (Exception suppressed) { + e.addSuppressed(suppressed); + } + if (e instanceof AdbcException) { + throw (AdbcException) e; + } + throw new RuntimeException("failed to open ADBC database", e); + } + } + + private static void installShutdownHookOnce() { + if (HOOK_INSTALLED.compareAndSet(false, true)) { + Runtime.getRuntime() + .addShutdownHook(new Thread(AdbcConnectionPool::closeAll, "adbc-pool-shutdown")); + } + } + + /** Remove and close every cached database. Run by the shutdown hook; best-effort. */ + static void closeAll() { + for (Key key : new ArrayList<>(CACHE.keySet())) { + CachedDatabase d = CACHE.remove(key); + if (d != null) { + try { + d.close(); + } catch (Exception e) { + // Best effort: the JVM is going away. Nothing useful to do with the exception here. + } + } + } + } + + /** Unchecked carrier so a checked {@link AdbcException} can cross the computeIfAbsent lambda. 
*/ + private static final class UncheckedAdbcException extends RuntimeException { + private static final long serialVersionUID = 1L; + + UncheckedAdbcException(AdbcException cause) { + super(cause); + } + + AdbcException cause() { + return (AdbcException) getCause(); + } + } + + // ---- Test hooks (package-private) ---- + + static void setBuilderForTesting(DatabaseBuilder b) { + builder = b; + } + + static int cacheSizeForTesting() { + return CACHE.size(); + } + + /** Native databases built over the JVM lifetime; one per executor when the pool works. */ + static int databasesBuiltForTesting() { + return NATIVE_DATABASES.get(); + } + + /** Per-task connections opened over the JVM lifetime. */ + static int taskConnectionsOpenedForTesting() { + return TASK_CONNECTIONS.get(); + } + + /** Close and clear all cached state and restore production defaults. For test isolation. */ + static void resetForTesting() { + closeAll(); + builder = AdbcConnectionPool::buildNative; + NATIVE_DATABASES.set(0); + TASK_CONNECTIONS.set(0); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java new file mode 100644 index 0000000..bec7db8 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcDatafusionTableProvider.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.Map; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableProvider; +import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.DataSourceRegister; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Entry point for the {@code adbc-datafusion} Spark data source. + * + *

The native boundary is a standard ADBC driver: the arrow-adbc Java driver manager loading a + * DataFusion ADBC cdylib. {@code spark.read.format("adbc-datafusion").option("driver", + * ...).option("table", ...).load()} resolves here; the schema is probed once, on the driver, via + * {@link AdbcConnection#getTableSchema}. + */ +public final class AdbcDatafusionTableProvider implements TableProvider, DataSourceRegister { + + @Override + public String shortName() { + return "adbc-datafusion"; + } + + @Override + public StructType inferSchema(CaseInsensitiveStringMap options) { + AdbcOptions opts = AdbcOptions.fromOptions(options); + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase db = new JniDriver(allocator).open(opts.driverParameters()); + AdbcConnection conn = db.connect()) { + Schema arrow = conn.getTableSchema(null, null, opts.table()); + return SchemaConverter.toSparkSchema(arrow); + } catch (Exception e) { + throw new RuntimeException("failed to probe ADBC schema for table " + opts.table(), e); + } + } + + @Override + public Table getTable( + StructType schema, Transform[] partitioning, Map properties) { + AdbcOptions opts = AdbcOptions.fromOptions(new CaseInsensitiveStringMap(properties)); + return new AdbcTable(schema, opts); + } + + @Override + public boolean supportsExternalMetadata() { + return false; + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java new file mode 100644 index 0000000..ed880dd --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcInputPartition.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import org.apache.spark.sql.connector.read.InputPartition; + +/** + * A serializable slice of an ADBC scan shipped to an executor. Carries only the connector options + * and an opaque payload -- never a native handle or connection, which are meaningless in another + * process. Each executor reopens its own ADBC database/connection from {@code options}. + * + *

The {@code kind} says how the executor turns {@code payload} into an {@code ArrowReader}: + * + *

    + *
  • {@link Kind#DESCRIPTOR}: {@code payload} is an ADBC partition descriptor from {@code + * executePartitioned()}; the executor runs {@code AdbcConnection.readPartition(payload)} (the + * multi-partition path, one partition per descriptor). + *
  • {@link Kind#SUBSTRAIT}: {@code payload} is a serialized Substrait plan; the executor runs + * {@code setSubstraitPlan} + {@code executeQuery} (single partition). + *
  • {@link Kind#SQL}: {@code payload} is UTF-8 SQL; the executor runs {@code setSqlQuery} + + * {@code executeQuery} (single partition; used when the driver/JNI lacks Substrait support). + *
+ */ +final class AdbcInputPartition implements InputPartition { + + private static final long serialVersionUID = 1L; + + enum Kind { + DESCRIPTOR, + SUBSTRAIT, + SQL + } + + final AdbcOptions options; + final byte[] payload; + final Kind kind; + + AdbcInputPartition(AdbcOptions options, byte[] payload, Kind kind) { + this.options = options; + this.payload = payload; + this.kind = kind; + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java new file mode 100644 index 0000000..8251b03 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcOptions.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.io.Serializable; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.OptionalInt; + +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * Connector options for the {@code adbc-datafusion} data source, decoded from the Spark options + * map. + * + *

Serializable so it can ride along inside an {@link AdbcInputPartition} to each executor, which + * reopens its own ADBC database/connection. Three keys are connector-control ({@code driver}, + * {@code table}, and {@code target_partitions}); everything else is forwarded verbatim as native + * ADBC database options (so provider-specific options pass straight through to the registered + * {@code TableProvider}). + * + *

    + *
  • {@code driver} (required) -- path to, or name of, the native ADBC driver shared library + * (our DataFusion ADBC cdylib). Resolved by the C driver manager. + *
  • {@code manifest.path} (optional) -- extra search path for ADBC driver manifests, when the + * driver is given by manifest name rather than an absolute path. + *
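  • {@code table} (required) -- the table name the registered provider exposes; becomes the + * Substrait {@code NamedScan} target. + *
  • {@code target_partitions} (optional) -- a positive integer setting DataFusion's {@code + * target_partitions} for the scan; when unset, the connector defaults to the cluster + * parallelism. + *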
+ */ +final class AdbcOptions implements Serializable { + + private static final long serialVersionUID = 1L; + + static final String DRIVER = "driver"; + static final String TABLE = "table"; + static final String TARGET_PARTITIONS = "target_partitions"; + + private final String driver; + private final String table; + // DataFusion target_partitions, or null to default to the cluster parallelism. + // Stored as a (Serializable) Integer rather than OptionalInt -- this object rides + // to executors inside AdbcInputPartition, and OptionalInt is not Serializable. + private final Integer targetPartitions; + // Provider-specific / passthrough ADBC database options. + private final LinkedHashMap<String, String> databaseOptions; + + private AdbcOptions( + String driver, + String table, + Integer targetPartitions, + LinkedHashMap<String, String> databaseOptions) { + this.driver = driver; + this.table = table; + this.targetPartitions = targetPartitions; + this.databaseOptions = databaseOptions; + } + + static AdbcOptions fromOptions(CaseInsensitiveStringMap options) { + String driver = require(options, DRIVER); + String table = require(options, TABLE); + // Parse once; parsePositiveInt validates and throws on a non-numeric or non-positive value. + OptionalInt parsedTarget = parsePositiveInt(options, TARGET_PARTITIONS); + Integer targetPartitions = parsedTarget.isPresent() ? parsedTarget.getAsInt() : null; + + LinkedHashMap<String, String> passthrough = new LinkedHashMap<>(); + for (Map.Entry<String, String> entry : options.entrySet()) { + String key = entry.getKey(); + if (key.equalsIgnoreCase(DRIVER) + || key.equalsIgnoreCase(TABLE) + || key.equalsIgnoreCase(TARGET_PARTITIONS)) { + continue; + } + passthrough.put(key, entry.getValue()); + } + return new AdbcOptions(driver, table, targetPartitions, passthrough); + } + + String table() { + return table; + } + + /** Explicit DataFusion target_partitions, or empty to use the cluster parallelism. */ + OptionalInt targetPartitions() { + return targetPartitions == null ? 
OptionalInt.empty() : OptionalInt.of(targetPartitions); + } + + /** + * Build the parameter map for {@link JniDriver#open}. All values must be strings; the {@code + * jni.driver} key (set via {@link JniDriver#PARAM_DRIVER}) is consumed by the driver manager to + * locate the shared library; the rest become native database options. + */ + Map<String, Object> driverParameters() { + Map<String, Object> params = new LinkedHashMap<>(); + JniDriver.PARAM_DRIVER.set(params, driver); + params.putAll(databaseOptions); + return params; + } + + /** + * Identity of the native ADBC database these options open. Two option sets that produce the same + * {@link #driverParameters()} share a key, so {@link AdbcConnectionPool} caches one native + * database per executor JVM for them (connections stay per-task). Derived from exactly the + * inputs {@code driverParameters()} consumes -- the {@code driver} path/name plus the passthrough + * database options -- and must stay adjacent to it so the two cannot drift. {@code table} and + * {@code target_partitions} are connector-control (not forwarded to the native database) and so + * are deliberately excluded. 
+ */ + AdbcConnectionPool.Key cacheKey() { + return new AdbcConnectionPool.Key(driver, databaseOptions); + } + + private static String require(CaseInsensitiveStringMap options, String key) { + String value = options.getOrDefault(key, null); + if (value == null || value.isEmpty()) { + throw new IllegalArgumentException( + "the 'adbc-datafusion' data source requires the '" + key + "' option"); + } + return value; + } + + private static OptionalInt parsePositiveInt(CaseInsensitiveStringMap options, String key) { + String value = options.getOrDefault(key, null); + if (value == null || value.isEmpty()) { + return OptionalInt.empty(); + } + int parsed; + try { + parsed = Integer.parseInt(value.trim()); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("'" + key + "' must be an integer, got: " + value, e); + } + if (parsed < 1) { + throw new IllegalArgumentException("'" + key + "' must be >= 1, got: " + parsed); + } + return OptionalInt.of(parsed); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java new file mode 100644 index 0000000..eb1bf67 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcPartitionReaderFactory.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReader; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * Creates a columnar ADBC reader per partition. Serialized to executors, so it holds no state. + * + *

Reads are columnar: {@link #supportColumnarReads} returns true, so Spark consumes Arrow + * buffers directly via {@link AdbcColumnarPartitionReader}. The row reader is unsupported. + */ +final class AdbcPartitionReaderFactory implements PartitionReaderFactory { + + private static final long serialVersionUID = 1L; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) { + return new AdbcColumnarPartitionReader((AdbcInputPartition) partition); + } + + @Override + public PartitionReader<InternalRow> createReader(InputPartition partition) { + throw new UnsupportedOperationException("adbc-datafusion source reads are columnar"); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java new file mode 100644 index 0000000..b18fa80 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanBuilder.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion.spark; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.OptionalLong; +import java.util.Set; + +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsPushDownFilters; +import org.apache.spark.sql.connector.read.SupportsPushDownLimit; +import org.apache.spark.sql.connector.read.SupportsPushDownRequiredColumns; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Builds an {@link AdbcScanImpl}, pushing projection / filter / limit down as Substrait. + * + *

Pushdown decision: a filter is pushed only if {@link SubstraitPlan#canPush} maps it + * to a standard-catalog Substrait predicate over a known column; the rest are returned to Spark. + * Projection becomes a Substrait emit. Limit is pushed only when no filters were left to Spark -- + * otherwise Spark must still filter after the scan, so a scan-side limit would be wrong. + */ +final class AdbcScanBuilder + implements ScanBuilder, + SupportsPushDownRequiredColumns, + SupportsPushDownFilters, + SupportsPushDownLimit { + + private final StructType fullSchema; + private final AdbcOptions options; + + private StructType requiredSchema; + private List<Filter> pushedFilters = new ArrayList<>(); + private boolean hasResidualFilters = false; + private OptionalLong limit = OptionalLong.empty(); + + AdbcScanBuilder(StructType schema, AdbcOptions options) { + this.fullSchema = schema; + this.requiredSchema = schema; + this.options = options; + } + + @Override + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public Filter[] pushFilters(Filter[] filters) { + // Push filters only against Spark-native (non-cast) columns: a cast column is cast in the + // scan projection, but a pushed predicate runs against the pre-cast source column, so its + // Spark-domain literal would not match. Those filters are left to Spark (see SchemaConverter). 
+ Set<String> columns = new LinkedHashSet<>(Arrays.asList(fullSchema.fieldNames())); + for (StructField field : fullSchema.fields()) { + if (field.metadata().contains(SchemaConverter.CAST_METADATA_KEY)) { + columns.remove(field.name()); + } + } + List<Filter> pushable = new ArrayList<>(); + List<Filter> residual = new ArrayList<>(); + for (Filter f : filters) { + if (SubstraitPlan.canPush(f, columns)) { + pushable.add(f); + } else { + residual.add(f); + } + } + this.pushedFilters = pushable; + this.hasResidualFilters = !residual.isEmpty(); + return residual.toArray(new Filter[0]); + } + + @Override + public Filter[] pushedFilters() { + return pushedFilters.toArray(new Filter[0]); + } + + @Override + public boolean pushLimit(int limit) { + // Only safe to push when Spark has no filters left to apply after the scan; + // otherwise the limit would be applied before that filtering. + if (hasResidualFilters) { + return false; + } + this.limit = OptionalLong.of(limit); + return true; + } + + @Override + public Scan build() { + List<String> projection = + Arrays.equals(requiredSchema.fieldNames(), fullSchema.fieldNames()) + ? null + : Arrays.asList(requiredSchema.fieldNames()); + return new AdbcScanImpl(requiredSchema, options, projection, pushedFilters, limit); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java new file mode 100644 index 0000000..e2f589a --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcScanImpl.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.OptionalLong; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.adbc.core.AdbcStatement; +import org.apache.arrow.adbc.core.AdbcStatusCode; +import org.apache.arrow.adbc.core.PartitionDescriptor; +import org.apache.arrow.adbc.driver.jni.JniDriver; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.datafusion.spark.AdbcInputPartition.Kind; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.connector.read.Batch; +import org.apache.spark.sql.connector.read.InputPartition; +import org.apache.spark.sql.connector.read.PartitionReaderFactory; +import org.apache.spark.sql.connector.read.Scan; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.types.StructType; + +/** + * A planned ADBC-backed scan as a Spark {@link Scan}/{@link Batch}. + * + *

{@link #planInputPartitions()} runs on the driver and probes the driver's capabilities, since + * the arrow-adbc JNI manager and the underlying driver implement different subsets of ADBC: + * + *

    + *
  1. Pick the query wire: prefer a typed Substrait plan; if the driver reports {@code + * NOT_IMPLEMENTED} for {@code setSubstraitPlan}, fall back to a SQL string. + *
  2. Pick the partitioning: try {@code executePartitioned()} for one descriptor per output + * partition; if {@code NOT_IMPLEMENTED}, emit a single partition carrying the chosen query. + *
+ * + * No native handle ever crosses the wire -- each {@link AdbcInputPartition} carries only opaque + * bytes, and each executor reopens its own ADBC connection. + */ +final class AdbcScanImpl implements Scan, Batch { + + private final StructType readSchema; + private final AdbcOptions options; + // null projection means all columns; column names in output order otherwise. + private final List projection; + private final List pushedFilters; + private final OptionalLong limit; + + AdbcScanImpl( + StructType readSchema, + AdbcOptions options, + List projection, + List pushedFilters, + OptionalLong limit) { + this.readSchema = readSchema; + this.options = options; + this.projection = projection; + this.pushedFilters = pushedFilters; + this.limit = limit; + } + + @Override + public StructType readSchema() { + return readSchema; + } + + @Override + public Batch toBatch() { + return this; + } + + @Override + public InputPartition[] planInputPartitions() { + try (BufferAllocator allocator = new RootAllocator(); + AdbcDatabase db = new JniDriver(allocator).open(options.driverParameters()); + AdbcConnection conn = db.connect()) { + // Set DataFusion's target_partitions on the planning session. execute_partitions + // pins it into each descriptor, so executors re-plan into the same partitioning. + // Defaults to the cluster's parallelism (total executor cores). Repartition-aware + // providers (e.g. file scans) use this to choose N; fixed-partition providers + // (e.g. an in-memory table) keep their intrinsic partition count. + int targetPartitions = options.targetPartitions().orElseGet(AdbcScanImpl::clusterParallelism); + applyTargetPartitions(conn, targetPartitions); + + Schema arrow = conn.getTableSchema(null, null, options.table()); + + // The casts (unsigned, Float16, non-µs timestamps, time) live only in the SQL projection, so + // any schema needing one must use the SQL wire. 
The gate is the full schema, not just the + // projection: the Substrait NamedScan declares every field's type, so an unprojected cast + // column would still misdeclare the base schema. Substrait also can't encode several other + // Spark-native Arrow types (binary, nested, decimal, ...); build() throws for those, which we + // likewise treat as "not Substrait-representable" and fall back. + boolean schemaNeedsCast = arrow.getFields().stream().anyMatch(SchemaConverter::needsCast); + + List columns = + SchemaConverter.projectionColumns(arrow, projection); + // count() (and other column-less reads) prunes the projection to empty. A bare SELECT * + // would then return the raw, uncast schema and the reader would fail on a non-Spark-native + // column, so when the table has any cast column, emit a single readable probe column + // instead -- the row count is all such a scan needs. + if (columns.isEmpty() && schemaNeedsCast) { + columns = List.of(SchemaConverter.probeColumn(arrow)); + } + boolean anyCast = columns.stream().anyMatch(c -> c.castType() != null); + // SELECT * only when no columns are projected away and none need a cast; otherwise the + // columns must be listed so the casts can be injected. + List sqlColumns = + (projection == null && !anyCast) ? 
null : columns; + String sql = SqlQuery.build(options.table(), sqlColumns, pushedFilters, limit); + + byte[] substrait = null; + if (!schemaNeedsCast) { + try { + substrait = SubstraitPlan.build(options.table(), arrow, projection, pushedFilters, limit); + } catch (RuntimeException e) { + substrait = null; + } + } + return plan(conn, substrait, sql); + } catch (Exception e) { + throw new RuntimeException("failed to plan ADBC scan for table " + options.table(), e); + } + } + + private static void applyTargetPartitions(AdbcConnection conn, int targetPartitions) + throws Exception { + try (AdbcStatement stmt = conn.createStatement()) { + stmt.setSqlQuery("SET datafusion.execution.target_partitions = " + targetPartitions); + stmt.executeUpdate(); + } + } + + /** Total executor cores via the active SparkSession; falls back to local cores. */ + private static int clusterParallelism() { + try { + return SparkSession.active().sparkContext().defaultParallelism(); + } catch (Throwable t) { + return Runtime.getRuntime().availableProcessors(); + } + } + + /** Probe the query wire and partitioning, returning the input partitions. */ + private InputPartition[] plan(AdbcConnection conn, byte[] substrait, String sql) + throws Exception { + Kind singleKind; + byte[] singlePayload; + + // Force the SQL wire when Substrait can't encode this scan (null plan), or via the escape + // hatch (e.g. when the Substrait round-trip plans to fewer partitions than SQL). Defaults to + // preferring Substrait. 
+ boolean forceSql = + substrait == null || "sql".equalsIgnoreCase(System.getProperty("adbc.wire", "")); + + AdbcStatement stmt = conn.createStatement(); + try { + if (forceSql) { + stmt.setSqlQuery(sql); + singleKind = Kind.SQL; + singlePayload = sql.getBytes(StandardCharsets.UTF_8); + } else { + try { + stmt.setSubstraitPlan(directBuffer(substrait)); + singleKind = Kind.SUBSTRAIT; + singlePayload = substrait; + } catch (AdbcException e) { + if (!notImplemented(e)) { + throw e; + } + // Driver/JNI without Substrait support -> SQL. + AutoCloseables.close(stmt); + stmt = conn.createStatement(); + stmt.setSqlQuery(sql); + singleKind = Kind.SQL; + singlePayload = sql.getBytes(StandardCharsets.UTF_8); + } + } + + List descriptors = tryPartition(stmt); + if (descriptors != null) { + return descriptors.toArray(new InputPartition[0]); + } + } finally { + AutoCloseables.close(stmt); + } + // Single-partition fallback carrying the chosen query. + return new InputPartition[] {new AdbcInputPartition(options, singlePayload, singleKind)}; + } + + /** + * Attempt ADBC partitioned execution on the (already-query-set) statement. Returns one partition + * per descriptor, or {@code null} if the driver does not implement partitioning. 
+ */ + private List tryPartition(AdbcStatement stmt) throws Exception { + try { + AdbcStatement.PartitionResult result = stmt.executePartitioned(); + List descriptors = result.getPartitionDescriptors(); + if (descriptors.isEmpty()) { + return null; + } + List partitions = new ArrayList<>(descriptors.size()); + for (PartitionDescriptor descriptor : descriptors) { + partitions.add( + new AdbcInputPartition(options, toBytes(descriptor.getDescriptor()), Kind.DESCRIPTOR)); + } + return partitions; + } catch (AdbcException e) { + if (notImplemented(e)) { + return null; + } + throw e; + } + } + + private static boolean notImplemented(AdbcException e) { + return e.getStatus() == AdbcStatusCode.NOT_IMPLEMENTED + || e.getStatus() == AdbcStatusCode.NOT_FOUND; + } + + private static ByteBuffer directBuffer(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length); + buffer.put(bytes).flip(); + return buffer; + } + + private static byte[] toBytes(ByteBuffer buffer) { + ByteBuffer dup = buffer.duplicate(); + byte[] bytes = new byte[dup.remaining()]; + dup.get(bytes); + return bytes; + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new AdbcPartitionReaderFactory(); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java b/spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java new file mode 100644 index 0000000..6467bc8 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/AdbcTable.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.EnumSet; +import java.util.Set; + +import org.apache.spark.sql.connector.catalog.SupportsRead; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A readable table over a DataFusion provider exposed via ADBC; produces {@link AdbcScanBuilder}s. + */ +final class AdbcTable implements SupportsRead { + + private final StructType schema; + private final AdbcOptions options; + + AdbcTable(StructType schema, AdbcOptions options) { + this.schema = schema; + this.options = options; + } + + @Override + public String name() { + return "adbc-datafusion"; + } + + @Override + public StructType schema() { + return schema; + } + + @Override + public Set capabilities() { + return EnumSet.of(TableCapability.BATCH_READ); + } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap scanOptions) { + return new AdbcScanBuilder(schema, options); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java new file mode 100644 index 0000000..4ef537c --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/SchemaConverter.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.MetadataBuilder; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +/** + * Converts an Arrow schema (produced by the ADBC scan) into a Spark {@link StructType}, and plans + * the source-side casts that make the scan emit Arrow that Spark's vectorized reader can consume. + * + *

An Arrow column must clear two independent gates: + * + *

<ol> + *
<li>Schema mapping -- an Arrow type maps to some Spark {@code DataType} (used at {@code + * inferSchema}). + *
<li>Vectorized read -- Spark's {@code ArrowColumnVector} has an accessor for the vector. + * This gate is narrower than the first and is not a public API. + *
</ol> + * + *

Enumerating the two gates by hand lets them drift: a type can map to a Spark type yet have no + * accessor (e.g. {@code FixedSizeList}/{@code LargeList} both map to {@code ArrayType} but only a + * variable {@code ListVector} is readable), producing an {@code UNSUPPORTED_ARROWTYPE} at task time + * rather than plan time. To avoid that, everything derives from one authority: + * + *

<ul> + *
<li>{@link #sparkConsumable} -- the read gate: {@code true} iff {@code ArrowColumnVector} has + * an accessor for the type. Mirrors the accessors in Spark 4.0's {@code + * ArrowColumnVector.initAccessor} (deliberately the narrower gate). + *
<li>{@link #sparkTarget} -- the nearest consumable Arrow type for a field (identity if already + * consumable), recursing into children. Non-consumable leaves are cast to the nearest + * readable type (unsigned -> signed, Float16 -> Float32, non-µs timestamp -> µs, Date64 -> + * Date32, Time -> int, FixedSizeBinary -> Binary) and non-consumable containers convert + * ({@code FixedSizeList}/{@code LargeList} -> {@code List}). + *
</ul> + * + *

Then the reported Spark type, the cast decision, and the cast target all come from {@code + * sparkTarget}, so the two gates cannot drift and adding a type is one case, not several. The cast + * is pushed into the scan (see {@link SqlQuery}); {@link #toSparkSchema} asserts at plan time that + * every target is consumable, so an unsupported type fails in planning with a clear message rather + * than as an opaque executor crash. + */ +final class SchemaConverter { + + /** + * Metadata flag set on a top-level Spark field whose source column needs a cast. Read by {@link + * AdbcScanBuilder} to keep filter pushdown off these columns: a pushed predicate runs against the + * pre-cast source column, so its literal would be in the wrong (source, not Spark) domain. + */ + static final String CAST_METADATA_KEY = "org.apache.datafusion.spark.adbc.cast"; + + private SchemaConverter() {} + + /** A projected output column: its name and, when a cast is required, the Arrow target type. */ + record ProjectionColumn(String name, String castType) {} + + static StructType toSparkSchema(Schema arrowSchema) { + StructType struct = new StructType(); + for (Field field : arrowSchema.getFields()) { + Field target = sparkTarget(field); + // Plan-time guard: if a type cannot be normalized to something Spark can read, fail here + // (naming the column) instead of crashing on an executor with UNSUPPORTED_ARROWTYPE. + ensureConsumable(field.getName(), target); + Metadata metadata = + target == field + ? Metadata.empty() + : new MetadataBuilder().putBoolean(CAST_METADATA_KEY, true).build(); + struct = struct.add(field.getName(), mapConsumable(target), field.isNullable(), metadata); + } + return struct; + } + + /** + * Plan the projected columns for the pushed scan: the kept columns in output order, each with the + * Arrow cast target (or {@code null} to pass through). 
+ * + * @param schema the full source Arrow schema + * @param projection kept column names in output order, or {@code null} for all columns + */ + static List<ProjectionColumn> projectionColumns(Schema schema, List<String> projection) { + List<Field> fields = new ArrayList<>(); + if (projection == null) { + fields.addAll(schema.getFields()); + } else { + for (String name : projection) { + fields.add(schema.findField(name)); + } + } + List<ProjectionColumn> columns = new ArrayList<>(fields.size()); + for (Field field : fields) { + columns.add(new ProjectionColumn(field.getName(), castTargetString(field))); + } + return columns; + } + + /** Whether the column, or any nested child, is something Spark's reader cannot consume as-is. */ + static boolean needsCast(Field field) { + return sparkTarget(field) != field; + } + + /** The Arrow type string for {@code arrow_cast}, or {@code null} if the column needs no cast. */ + static String castTargetString(Field field) { + Field target = sparkTarget(field); + return target == field ? null : render(target); + } + + /** + * A single Spark-readable column for a column-less scan (e.g. {@code count()}, whose projection + * Catalyst prunes to empty). Such a scan only needs a row count, but the emitted stream must + * still be Spark-native -- a bare {@code SELECT *} would return the raw, uncast schema and the + * reader would fail on the first non-Spark-native column. Prefers a column that needs no cast + * (cheapest); falls back to the first column with its cast applied when every column needs one. + */ + static ProjectionColumn probeColumn(Schema schema) { + for (Field field : schema.getFields()) { + if (!needsCast(field)) { + return new ProjectionColumn(field.getName(), null); + } + } + Field first = schema.getFields().get(0); + return new ProjectionColumn(first.getName(), castTargetString(first)); + } + + // --- The two-gate authority ----------------------------------------------- + + /** + * Whether Spark's vectorized {@code ArrowColumnVector} has an accessor for this Arrow type. 
+ * Mirrors {@code ArrowColumnVector.initAccessor} in Spark 4.0 -- deliberately the narrower gate, + * so a type that maps to a Spark {@code DataType} but has no accessor (unsigned, Float16, non-µs + * timestamp, Date64, {@code Time*}, {@code FixedSizeList}/{@code LargeList}, {@code + * FixedSizeBinary}, {@code Interval}, {@code Dictionary}, ...) is reported as non-consumable. + */ + static boolean sparkConsumable(ArrowType type) { + if (type instanceof ArrowType.Bool) { + return true; + } + if (type instanceof ArrowType.Int i) { + return i.getIsSigned() + && (i.getBitWidth() == 8 + || i.getBitWidth() == 16 + || i.getBitWidth() == 32 + || i.getBitWidth() == 64); + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.SINGLE + || fp.getPrecision() == FloatingPointPrecision.DOUBLE; + } + if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { + return true; + } + if (type instanceof ArrowType.Binary || type instanceof ArrowType.LargeBinary) { + return true; // FixedSizeBinary is NOT readable. + } + if (type instanceof ArrowType.Decimal d) { + return d.getBitWidth() == 128; // Spark's DecimalVector is 128-bit; Decimal256 is not read. + } + if (type instanceof ArrowType.Date d) { + return d.getUnit() == DateUnit.DAY; // Date64 (MILLISECOND) is not readable. + } + if (type instanceof ArrowType.Timestamp ts) { + return ts.getUnit() == TimeUnit.MICROSECOND; // only µs, any timezone. + } + if (type instanceof ArrowType.Duration) { + return true; + } + if (type instanceof ArrowType.Null) { + return true; + } + // Only the variable-offset containers are readable (not FixedSizeList / LargeList). + return type instanceof ArrowType.List + || type instanceof ArrowType.Struct + || type instanceof ArrowType.Map; + } + + /** + * The nearest Spark-consumable field for {@code field}, recursing into children. 
Returns the same + * instance when nothing changes (so {@code sparkTarget(f) == f} means "no cast needed"). + */ + static Field sparkTarget(Field field) { + ArrowType type = field.getType(); + ArrowType targetType = type; + + if (type instanceof ArrowType.Int i && !i.getIsSigned()) { + targetType = + switch (i.getBitWidth()) { + case 8 -> new ArrowType.Int(16, true); + case 16 -> new ArrowType.Int(32, true); + case 32 -> new ArrowType.Int(64, true); + case 64 -> new ArrowType.Decimal(20, 0, 128); // no lossless signed 64-bit target + default -> type; + }; + } else if (type instanceof ArrowType.FloatingPoint fp + && fp.getPrecision() == FloatingPointPrecision.HALF) { + targetType = new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE); + } else if (type instanceof ArrowType.Timestamp ts && ts.getUnit() != TimeUnit.MICROSECOND) { + targetType = new ArrowType.Timestamp(TimeUnit.MICROSECOND, ts.getTimezone()); + } else if (type instanceof ArrowType.Date d && d.getUnit() != DateUnit.DAY) { + targetType = new ArrowType.Date(DateUnit.DAY); + } else if (type instanceof ArrowType.Time t) { + // Spark has no time-of-day accessor; carry the raw ticks as the matching-width signed int. + targetType = new ArrowType.Int(t.getBitWidth() == 32 ? 32 : 64, true); + } else if (type instanceof ArrowType.FixedSizeBinary) { + targetType = new ArrowType.Binary(); + } else if (type instanceof ArrowType.FixedSizeList || type instanceof ArrowType.LargeList) { + // Spark reads ArrayType only from a variable ListVector. + targetType = new ArrowType.List(); + } + + // Recurse into children (list element, struct fields, map entries), widening each. 
+ List<Field> children = field.getChildren(); + List<Field> targetChildren = new ArrayList<>(children.size()); + boolean childChanged = false; + for (Field child : children) { + Field targetChild = sparkTarget(child); + childChanged |= targetChild != child; + targetChildren.add(targetChild); + } + + if (targetType == type && !childChanged) { + return field; + } + FieldType ft = new FieldType(field.isNullable(), targetType, field.getDictionary()); + return new Field(field.getName(), ft, targetChildren); + } + + /** Assert every node of a {@code sparkTarget} result is consumable, else fail naming the column. */ + private static void ensureConsumable(String column, Field target) { + if (!sparkConsumable(target.getType())) { + throw new IllegalArgumentException( + "column '" + + column + + "': Arrow type " + + target.getType() + + " has no Spark reader and " + + "no supported cast; the connector cannot expose it to Spark"); + } + for (Field child : target.getChildren()) { + ensureConsumable(column, child); + } + } + + // --- Arrow (consumable) type -> Spark type --------------------------------- + + /** + * Map any field to its Spark {@link DataType}: normalize it with {@link #sparkTarget}, assert the + * result is {@link #sparkConsumable}, then map the normalized field. + */ + static DataType toSparkType(Field field) { + Field target = sparkTarget(field); + ensureConsumable(field.getName(), target); + return mapConsumable(target); + } + + private static DataType mapConsumable(Field field) { + ArrowType type = field.getType(); + if (type instanceof ArrowType.Bool) { + return DataTypes.BooleanType; + } + if (type instanceof ArrowType.Int i) { + return switch (i.getBitWidth()) { + case 8 -> DataTypes.ByteType; + case 16 -> DataTypes.ShortType; + case 32 -> DataTypes.IntegerType; + case 64 -> DataTypes.LongType; + default -> throw unsupported(field); + }; + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.DOUBLE + ? 
DataTypes.DoubleType + : DataTypes.FloatType; + } + if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { + return DataTypes.StringType; + } + if (type instanceof ArrowType.Binary || type instanceof ArrowType.LargeBinary) { + return DataTypes.BinaryType; + } + if (type instanceof ArrowType.Date) { + return DataTypes.DateType; + } + if (type instanceof ArrowType.Timestamp ts) { + return ts.getTimezone() == null ? DataTypes.TimestampNTZType : DataTypes.TimestampType; + } + if (type instanceof ArrowType.Decimal d) { + return DataTypes.createDecimalType(d.getPrecision(), d.getScale()); + } + if (type instanceof ArrowType.Null) { + return DataTypes.NullType; + } + if (type instanceof ArrowType.Duration) { + return DataTypes.createDayTimeIntervalType(); + } + if (type instanceof ArrowType.List) { + Field element = field.getChildren().get(0); + return DataTypes.createArrayType(mapConsumable(element), element.isNullable()); + } + if (type instanceof ArrowType.Struct) { + List children = new ArrayList<>(); + for (Field child : field.getChildren()) { + children.add( + DataTypes.createStructField(child.getName(), mapConsumable(child), child.isNullable())); + } + return DataTypes.createStructType(children); + } + if (type instanceof ArrowType.Map) { + Field entries = field.getChildren().get(0); + Field key = entries.getChildren().get(0); + Field value = entries.getChildren().get(1); + return DataTypes.createMapType(mapConsumable(key), mapConsumable(value), value.isNullable()); + } + throw unsupported(field); + } + + // --- Arrow (consumable) type -> arrow_cast type string --------------------- + + /** + * Render an already-{@link #sparkConsumable} field as an {@code arrow_cast} type string (the + * reversible {@code arrow::datatypes::DataType} display form DataFusion's {@code arrow_cast} + * parses). Called only on {@link #sparkTarget} output, which is consumable by construction. 
+ */ + private static String render(Field field) { + ArrowType type = field.getType(); + if (type instanceof ArrowType.Bool) { + return "Boolean"; + } + if (type instanceof ArrowType.Int i) { + return "Int" + i.getBitWidth(); + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? "Float64" : "Float32"; + } + if (type instanceof ArrowType.Utf8) { + return "Utf8"; + } + if (type instanceof ArrowType.LargeUtf8) { + return "LargeUtf8"; + } + if (type instanceof ArrowType.Binary) { + return "Binary"; + } + if (type instanceof ArrowType.LargeBinary) { + return "LargeBinary"; + } + if (type instanceof ArrowType.Date) { + return "Date32"; + } + if (type instanceof ArrowType.Timestamp ts) { + return ts.getTimezone() == null + ? "Timestamp(Microsecond)" + : "Timestamp(Microsecond, \"" + ts.getTimezone() + "\")"; + } + if (type instanceof ArrowType.Decimal d) { + return "Decimal128(" + d.getPrecision() + ", " + d.getScale() + ")"; + } + if (type instanceof ArrowType.Duration dur) { + return "Duration(" + timeUnitName(dur.getUnit()) + ")"; + } + if (type instanceof ArrowType.Null) { + return "Null"; + } + if (type instanceof ArrowType.List) { + return "List(" + listChild(field.getChildren().get(0)) + ")"; + } + if (type instanceof ArrowType.Struct) { + StringBuilder sb = new StringBuilder("Struct("); + List children = field.getChildren(); + for (int i = 0; i < children.size(); i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(structField(children.get(i))); + } + return sb.append(")").toString(); + } + if (type instanceof ArrowType.Map m) { + Field entries = field.getChildren().get(0); + return "Map(" + + structField(entries) + + ", " + + (m.getKeysSorted() ? "sorted" : "unsorted") + + ")"; + } + throw unsupported(field); + } + + /** {@code [, field: 'name']} -- the list child form. 
*/ + private static String listChild(Field field) { + String rendered = nullability(field) + render(field); + return "item".equals(field.getName()) + ? rendered + : rendered + ", field: '" + field.getName() + "'"; + } + + /** {@code "name": } -- the struct/map field form. */ + private static String structField(Field field) { + return debugQuote(field.getName()) + ": " + nullability(field) + render(field); + } + + private static String nullability(Field field) { + return field.isNullable() ? "" : "non-null "; + } + + private static String timeUnitName(TimeUnit unit) { + return switch (unit) { + case SECOND -> "Second"; + case MILLISECOND -> "Millisecond"; + case MICROSECOND -> "Microsecond"; + case NANOSECOND -> "Nanosecond"; + }; + } + + /** Reproduce Rust's {@code {:?}} string quoting used by the Arrow display form. */ + private static String debugQuote(String s) { + StringBuilder sb = new StringBuilder("\""); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"' -> sb.append("\\\""); + case '\\' -> sb.append("\\\\"); + case '\n' -> sb.append("\\n"); + case '\r' -> sb.append("\\r"); + case '\t' -> sb.append("\\t"); + default -> sb.append(c); + } + } + return sb.append("\"").toString(); + } + + private static IllegalArgumentException unsupported(Field field) { + return new IllegalArgumentException( + "unsupported Arrow type for column '" + field.getName() + "': " + field.getType()); + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java b/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java new file mode 100644 index 0000000..92879c5 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/SqlQuery.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import java.util.List; +import java.util.OptionalLong; +import java.util.stream.Collectors; + +import org.apache.datafusion.spark.SchemaConverter.ProjectionColumn; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; + +/** + * Renders the pushed-down scan as an ANSI SQL {@code SELECT}. + * + *

This is the fallback wire for when the ADBC path does not accept Substrait. The arrow-adbc JNI + * driver manager (0.23) forwards {@code SetSqlQuery} but not {@code SetSubstraitPlan}, so the + * connector tries Substrait first and falls back to this SQL form when that call is rejected as + * not implemented. The pushable predicate set is identical to {@link SubstraitPlan}, which keeps + * the {@link AdbcScanBuilder} pushdown decision encoding-independent; that is also why {@link + * SubstraitPlan#canPush} excludes non-finite floats: they have no SQL literal. + * + *

Identifiers are double-quoted (ANSI, DataFusion's default) and string literals single-quoted, + * both with doubling-based escaping. + * + *

Columns whose source Arrow type is not Spark-native (see {@link SchemaConverter#needsCast}) + * are wrapped in {@code arrow_cast(col, '')} and re-aliased to their original name, so + * the scan emits Spark-native Arrow and the output column names still match the reported schema. + */ +final class SqlQuery { + + private SqlQuery() {} + + static String build( + String table, List<ProjectionColumn> columns, List<Filter> filters, OptionalLong limit) { + StringBuilder sql = new StringBuilder("SELECT "); + if (columns == null || columns.isEmpty()) { + sql.append("*"); + } else { + sql.append(columns.stream().map(SqlQuery::column).collect(Collectors.joining(", "))); + } + sql.append(" FROM ").append(quoteId(table)); + if (!filters.isEmpty()) { + sql.append(" WHERE ") + .append(filters.stream().map(SqlQuery::predicate).collect(Collectors.joining(" AND "))); + } + if (limit.isPresent()) { + sql.append(" LIMIT ").append(limit.getAsLong()); + } + return sql.toString(); + } + + private static String column(ProjectionColumn c) { + if (c.castType() == null) { + return quoteId(c.name()); + } + // arrow_cast renames its output, so alias back to the source name. 
+ // The type string can itself contain single quotes (a named list child renders as + // field: 'name'), so double them before embedding it in the SQL string literal. + String type = c.castType().replace("'", "''"); + return "arrow_cast(" + quoteId(c.name()) + ", '" + type + "') AS " + quoteId(c.name()); + } + + private static String predicate(Filter f) { + if (f instanceof EqualTo e) { + return binary(e.attribute(), "=", e.value()); + } else if (f instanceof GreaterThan e) { + return binary(e.attribute(), ">", e.value()); + } else if (f instanceof GreaterThanOrEqual e) { + return binary(e.attribute(), ">=", e.value()); + } else if (f instanceof LessThan e) { + return binary(e.attribute(), "<", e.value()); + } else if (f instanceof LessThanOrEqual e) { + return binary(e.attribute(), "<=", e.value()); + } else if (f instanceof IsNull e) { + return "(" + quoteId(e.attribute()) + " IS NULL)"; + } else if (f instanceof IsNotNull e) { + return "(" + quoteId(e.attribute()) + " IS NOT NULL)"; + } else if (f instanceof And e) { + return "(" + predicate(e.left()) + " AND " + predicate(e.right()) + ")"; + } else if (f instanceof Or e) { + return "(" + predicate(e.left()) + " OR " + predicate(e.right()) + ")"; + } else if (f instanceof Not e) { + return "(NOT " + predicate(e.child()) + ")"; + } + throw new IllegalArgumentException("filter is not pushable to SQL: " + f); + } + + private static String binary(String attribute, String op, Object value) { + return "(" + quoteId(attribute) + " " + op + " " + literal(value) + ")"; + } + + private static String literal(Object v) { + if (v instanceof Boolean b) { + return b ? "TRUE" : "FALSE"; + } else if (v instanceof String s) { + return "'" + s.replace("'", "''") + "'"; + } else if (v instanceof Integer + || v instanceof Long + || v instanceof Short + || v instanceof Byte) { + return v.toString(); + } else if (v instanceof Double || v instanceof Float) { + // SubstraitPlan.canPush has already rejected NaN/Infinity, so toString is a + // valid SQL numeric literal here. 
+ return v.toString(); + } + throw new IllegalArgumentException("unsupported SQL literal type: " + v.getClass()); + } + + private static String quoteId(String identifier) { + return "\"" + identifier.replace("\"", "\"\"") + "\""; + } +} diff --git a/spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java b/spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java new file mode 100644 index 0000000..83a24c0 --- /dev/null +++ b/spark/src/main/java/org/apache/datafusion/spark/SubstraitPlan.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion.spark; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.OptionalLong; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.sources.And; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.Filter; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.GreaterThanOrEqual; +import org.apache.spark.sql.sources.IsNotNull; +import org.apache.spark.sql.sources.IsNull; +import org.apache.spark.sql.sources.LessThan; +import org.apache.spark.sql.sources.LessThanOrEqual; +import org.apache.spark.sql.sources.Not; +import org.apache.spark.sql.sources.Or; + +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.plan.Plan; +import io.substrait.plan.PlanProtoConverter; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; + +/** + * Builds the Substrait plan pushed to the ADBC statement ({@code setSubstraitPlan}). + * + *
<p>
Shape: {@code NamedScan -> [Filter] -> [Project(emit)] -> [Fetch]}. The plan is built with + * substrait-java's {@link SubstraitBuilder}, and every pushed predicate resolves against the + * standard Substrait function catalog ({@link + * DefaultExtensionCatalog#FUNCTIONS_COMPARISON} / {@link + * DefaultExtensionCatalog#FUNCTIONS_BOOLEAN}). That is the key to interop: those are the function + * URIs/signatures DataFusion's Substrait consumer ({@code from_substrait_plan}) recognizes, so we + * never emit a custom extension the two sides could disagree on. Anything outside the whitelist in + * {@link #canPush} is left to Spark. + * + *
<p>
Pushdown is carried as typed Substrait rather than a SQL string, so literal fidelity + * (float/decimal/NaN/null-safe) is not lost in a text round-trip. + */ +final class SubstraitPlan { + + /** Load the standard extension declarations once; immutable and shareable. */ + private static final SimpleExtension.ExtensionCollection EXTENSIONS = + SimpleExtension.loadDefaults(); + + /** Predicates produce a (nullable, three-valued) boolean. */ + private static final Type BOOL = TypeCreator.NULLABLE.BOOLEAN; + + private SubstraitPlan() {} + + /** + * Build the scan plan with the given pushdown. + * + * @param table the table name the provider exposes + * @param schema the table's full Arrow schema + * @param projection kept column names (in output order), or {@code null} for all columns + * @param filters the pushed predicates (must all satisfy {@link #canPush}); ANDed together + * @param limit an optional row limit applied after filtering + */ + static byte[] build( + String table, + Schema schema, + List<String> projection, + List<Filter> filters, + OptionalLong limit) { + SubstraitBuilder b = new SubstraitBuilder(EXTENSIONS); + + List<String> names = new ArrayList<>(schema.getFields().size()); + List<Type> types = new ArrayList<>(schema.getFields().size()); + Map<String, Integer> columnIndex = new java.util.HashMap<>(); + int i = 0; + for (Field field : schema.getFields()) { + names.add(field.getName()); + types.add(toSubstraitType(field)); + columnIndex.put(field.getName(), i++); + } + + NamedScan scan = b.namedScan(List.of(table), names, types); + Rel rel = scan; + + if (!filters.isEmpty()) { + rel = b.filter(input -> conjunction(b, input, columnIndex, filters), rel); + } + + List<String> outputNames = names; + if (projection != null && !projection.equals(names)) { + List<Integer> keep = projection.stream().map(columnIndex::get).collect(Collectors.toList()); + int inputWidth = names.size(); + rel = + b.project( + input -> + keep.stream() + .map(k -> (Expression) b.fieldReference(input, k)) + .collect(Collectors.toList()), + //
Project appends expressions after the input columns; emit only the + // appended ones so the output is exactly the kept columns. + Rel.Remap.offset(inputWidth, keep.size()), + rel); + outputNames = projection; + } + + if (limit.isPresent()) { + rel = b.limit(limit.getAsLong(), rel); + } + + Plan plan = b.plan(Plan.Root.builder().input(rel).addAllNames(outputNames).build()); + return new PlanProtoConverter().toProto(plan).toByteArray(); + } + + // --- Spark Filter -> Substrait Expression --------------------------------- + + /** Whether {@code f} maps to a standard-catalog Substrait predicate over known columns. */ + static boolean canPush(Filter f, Set<String> columns) { + if (f instanceof EqualTo e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof GreaterThan e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof GreaterThanOrEqual e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof LessThan e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof LessThanOrEqual e) { + return columns.contains(e.attribute()) && isPushableLiteral(e.value()); + } else if (f instanceof IsNull e) { + return columns.contains(e.attribute()); + } else if (f instanceof IsNotNull e) { + return columns.contains(e.attribute()); + } else if (f instanceof And e) { + return canPush(e.left(), columns) && canPush(e.right(), columns); + } else if (f instanceof Or e) { + return canPush(e.left(), columns) && canPush(e.right(), columns); + } else if (f instanceof Not e) { + return canPush(e.child(), columns); + } + return false; + } + + private static Expression conjunction( + SubstraitBuilder b, Rel input, Map<String, Integer> idx, List<Filter> filters) { + List<Expression> terms = + filters.stream().map(f -> translate(b, input, idx, f)).collect(Collectors.toList()); + if (terms.size() == 1) { + return terms.get(0); + } + return b.scalarFn( +
DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "and:bool", + BOOL, + terms.toArray(new Expression[0])); + } + + private static Expression translate( + SubstraitBuilder b, Rel input, Map<String, Integer> idx, Filter f) { + if (f instanceof EqualTo e) { + return cmp(b, input, idx, "equal:any_any", e.attribute(), e.value()); + } else if (f instanceof GreaterThan e) { + return cmp(b, input, idx, "gt:any_any", e.attribute(), e.value()); + } else if (f instanceof GreaterThanOrEqual e) { + return cmp(b, input, idx, "gte:any_any", e.attribute(), e.value()); + } else if (f instanceof LessThan e) { + return cmp(b, input, idx, "lt:any_any", e.attribute(), e.value()); + } else if (f instanceof LessThanOrEqual e) { + return cmp(b, input, idx, "lte:any_any", e.attribute(), e.value()); + } else if (f instanceof IsNull e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + "is_null:any", + BOOL, + b.fieldReference(input, idx.get(e.attribute()))); + } else if (f instanceof IsNotNull e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + "is_not_null:any", + BOOL, + b.fieldReference(input, idx.get(e.attribute()))); + } else if (f instanceof And e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "and:bool", + BOOL, + translate(b, input, idx, e.left()), + translate(b, input, idx, e.right())); + } else if (f instanceof Or e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "or:bool", + BOOL, + translate(b, input, idx, e.left()), + translate(b, input, idx, e.right())); + } else if (f instanceof Not e) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_BOOLEAN, + "not:bool", + BOOL, + translate(b, input, idx, e.child())); + } + throw new IllegalArgumentException("filter is not pushable: " + f); + } + + private static Expression cmp( + SubstraitBuilder b, + Rel input, + Map<String, Integer> idx, + String key, + String attribute, + Object value) { + return b.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_COMPARISON, + key, + BOOL, +
b.fieldReference(input, idx.get(attribute)), + literal(value)); + } + + private static boolean isPushableLiteral(Object v) { + // Exclude non-finite floats: they have no SQL literal, and the connector may + // encode this predicate as SQL (when the driver lacks Substrait support), so + // the pushable set must be expressible in both wires. + if (v instanceof Double d) { + return !d.isNaN() && !d.isInfinite(); + } + if (v instanceof Float f) { + return !f.isNaN() && !f.isInfinite(); + } + return v instanceof Integer + || v instanceof Long + || v instanceof Short + || v instanceof Byte + || v instanceof Boolean + || v instanceof String; + } + + private static Expression literal(Object v) { + if (v instanceof Integer n) { + return ExpressionCreator.i32(true, n); + } else if (v instanceof Long n) { + return ExpressionCreator.i64(true, n); + } else if (v instanceof Short n) { + return ExpressionCreator.i16(true, n); + } else if (v instanceof Byte n) { + return ExpressionCreator.i8(true, n); + } else if (v instanceof Double n) { + return ExpressionCreator.fp64(true, n); + } else if (v instanceof Float n) { + return ExpressionCreator.fp32(true, n); + } else if (v instanceof Boolean n) { + return ExpressionCreator.bool(true, n); + } else if (v instanceof String n) { + return ExpressionCreator.string(true, n); + } + throw new IllegalArgumentException("unsupported literal type: " + v.getClass()); + } + + // --- Arrow type -> Substrait type ----------------------------------------- + + private static Type toSubstraitType(Field field) { + ArrowType type = field.getType(); + TypeCreator t = field.isNullable() ? 
TypeCreator.NULLABLE : TypeCreator.REQUIRED; + if (type instanceof ArrowType.Int n) { + if (!n.getIsSigned()) { + throw unsupported(field); + } + return switch (n.getBitWidth()) { + case 8 -> t.I8; + case 16 -> t.I16; + case 32 -> t.I32; + case 64 -> t.I64; + default -> throw unsupported(field); + }; + } + if (type instanceof ArrowType.FloatingPoint fp) { + return fp.getPrecision() == FloatingPointPrecision.DOUBLE ? t.FP64 : t.FP32; + } + if (type instanceof ArrowType.Utf8 || type instanceof ArrowType.LargeUtf8) { + return t.STRING; + } + if (type instanceof ArrowType.Bool) { + return t.BOOLEAN; + } + throw unsupported(field); + } + + private static IllegalArgumentException unsupported(Field field) { + return new IllegalArgumentException( + "unsupported Arrow type for column '" + field.getName() + "': " + field.getType()); + } +} diff --git a/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000..9237b20 --- /dev/null +++ b/spark/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1 @@ +org.apache.datafusion.spark.AdbcDatafusionTableProvider diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java new file mode 100644 index 0000000..75332cc --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcConnectionPoolTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.arrow.adbc.core.AdbcConnection; +import org.apache.arrow.adbc.core.AdbcDatabase; +import org.apache.arrow.adbc.core.AdbcException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +/** Unit + concurrency tests for {@link AdbcConnectionPool}; no native driver required. 
*/ +class AdbcConnectionPoolTest { + + @AfterEach + void tearDown() { + AdbcConnectionPool.resetForTesting(); + } + + private static AdbcOptions options(Map<String, String> raw) { + return AdbcOptions.fromOptions(new CaseInsensitiveStringMap(new LinkedHashMap<>(raw))); + } + + private static LinkedHashMap<String, String> base() { + LinkedHashMap<String, String> m = new LinkedHashMap<>(); + m.put("driver", "/path/to/lib.so"); + m.put("table", "t"); + return m; + } + + // ---- cacheKey() ---- + + @Test + void cacheKeyIgnoresOptionOrder() { + LinkedHashMap<String, String> a = base(); + a.put("foo", "1"); + a.put("bar", "2"); + LinkedHashMap<String, String> b = base(); + b.put("bar", "2"); + b.put("foo", "1"); + + AdbcConnectionPool.Key ka = options(a).cacheKey(); + AdbcConnectionPool.Key kb = options(b).cacheKey(); + assertEquals(ka, kb); + assertEquals(ka.hashCode(), kb.hashCode()); + } + + @Test + void cacheKeyDistinguishesDriverAndOptionValues() { + LinkedHashMap<String, String> a = base(); + a.put("foo", "1"); + LinkedHashMap<String, String> diffValue = base(); + diffValue.put("foo", "2"); + LinkedHashMap<String, String> diffDriver = base(); + diffDriver.put("driver", "/other/lib.so"); + diffDriver.put("foo", "1"); + + assertNotEquals(options(a).cacheKey(), options(diffValue).cacheKey()); + assertNotEquals(options(a).cacheKey(), options(diffDriver).cacheKey()); + } + + @Test + void cacheKeyIgnoresControlKeys() { + // table and target_partitions are connector-control, not forwarded to the native database, + // so they must not affect the cache key.
+ LinkedHashMap<String, String> a = base(); + a.put("target_partitions", "4"); + LinkedHashMap<String, String> b = base(); + b.put("table", "other-table"); + b.put("target_partitions", "16"); + + assertEquals(options(a).cacheKey(), options(b).cacheKey()); + } + + // ---- computeIfAbsent: at-most-once build under concurrency ---- + + @Test + void concurrentAcquireBuildsExactlyOnce() throws Exception { + AtomicInteger builds = new AtomicInteger(); + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> { + builds.incrementAndGet(); + return new AdbcConnectionPool.CachedDatabase( + new RootAllocator(), mock(AdbcDatabase.class)); + }); + + AdbcOptions opts = options(base()); + int threads = 16; + CyclicBarrier barrier = new CyclicBarrier(threads); + ConcurrentLinkedQueue<AdbcConnectionPool.Lease> leases = new ConcurrentLinkedQueue<>(); + ConcurrentLinkedQueue<Throwable> errors = new ConcurrentLinkedQueue<>(); + + runConcurrently( + threads, + () -> { + try { + barrier.await(); + leases.add(AdbcConnectionPool.acquire(opts)); + } catch (Throwable t) { + errors.add(t); + } + }); + + assertTrue(errors.isEmpty(), () -> "unexpected errors: " + errors); + assertEquals(1, builds.get(), "database built exactly once for one key"); + assertEquals(1, AdbcConnectionPool.cacheSizeForTesting()); + assertEquals(threads, leases.size()); + for (AdbcConnectionPool.Lease lease : leases) { + lease.close(); + } + } + + @Test + void distinctKeysBuildDistinctEntries() throws Exception { + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> + new AdbcConnectionPool.CachedDatabase(new RootAllocator(), mock(AdbcDatabase.class))); + + LinkedHashMap<String, String> a = base(); + a.put("foo", "1"); + LinkedHashMap<String, String> b = base(); + b.put("foo", "2"); + + AdbcConnectionPool.Lease la = AdbcConnectionPool.acquire(options(a)); + AdbcConnectionPool.Lease lb = AdbcConnectionPool.acquire(options(b)); + assertEquals(2, AdbcConnectionPool.cacheSizeForTesting()); + la.close(); + lb.close(); + } + + @Test + void buildFailureSurfacesAdbcExceptionAndLeavesCacheEmpty() throws Exception { +
AtomicInteger attempts = new AtomicInteger(); + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> { + attempts.incrementAndGet(); + throw AdbcException.io("boom"); + }); + + AdbcOptions opts = options(base()); + int threads = 8; + CyclicBarrier barrier = new CyclicBarrier(threads); + ConcurrentLinkedQueue<Throwable> errors = new ConcurrentLinkedQueue<>(); + + runConcurrently( + threads, + () -> { + try { + barrier.await(); + AdbcConnectionPool.acquire(opts); + } catch (AdbcException e) { + errors.add(e); // expected + } catch (Throwable t) { + errors.add(t); + } + }); + + assertEquals(threads, errors.size()); + for (Throwable t : errors) { + assertTrue(t instanceof AdbcException, () -> "expected AdbcException, got " + t); + } + assertEquals(0, AdbcConnectionPool.cacheSizeForTesting(), "failed build leaves no entry"); + + // A subsequent good build inserts exactly one entry (retry works). + AdbcConnectionPool.setBuilderForTesting( + (key, o) -> + new AdbcConnectionPool.CachedDatabase(new RootAllocator(), mock(AdbcDatabase.class))); + AdbcConnectionPool.Lease lease = AdbcConnectionPool.acquire(opts); + assertEquals(1, AdbcConnectionPool.cacheSizeForTesting()); + lease.close(); + } + + // ---- Lease close semantics ---- + + @Test + void leaseCloseClosesTaskConnectionButNotDatabase() throws Exception { + try (BufferAllocator root = new RootAllocator()) { + AdbcDatabase db = mock(AdbcDatabase.class); + AdbcConnection taskConn = mock(AdbcConnection.class); + AdbcConnectionPool.CachedDatabase cached = new AdbcConnectionPool.CachedDatabase(root, db); + BufferAllocator child = root.newChildAllocator("task", 0, Long.MAX_VALUE); + + AdbcConnectionPool.Lease lease = new AdbcConnectionPool.Lease(cached, child, taskConn); + lease.close(); + + verify(taskConn, times(1)).close(); + verify(db, never()).close(); + } + } + + // ---- acquire opens a per-task connection off the one cached database ---- + + @Test + void acquireOpensConnectionPerTask() throws Exception { + AdbcDatabase db =
mock(AdbcDatabase.class); + when(db.connect()).thenAnswer(invocation -> mock(AdbcConnection.class)); + AdbcConnectionPool.setBuilderForTesting( + (key, opts) -> new AdbcConnectionPool.CachedDatabase(new RootAllocator(), db)); + + AdbcOptions opts = options(base()); + AdbcConnectionPool.Lease l1 = AdbcConnectionPool.acquire(opts); + AdbcConnectionPool.Lease l2 = AdbcConnectionPool.acquire(opts); + + assertEquals(1, AdbcConnectionPool.cacheSizeForTesting(), "one cached database"); + verify(db, times(2)).connect(); // one connection per task + assertNotEquals(l1.connection(), l2.connection()); + l1.close(); + l2.close(); + } + + private static void runConcurrently(int threads, Runnable body) throws InterruptedException { + Thread[] pool = new Thread[threads]; + for (int i = 0; i < threads; i++) { + pool[i] = new Thread(body, "acquire-" + i); + pool[i].start(); + } + for (Thread t : pool) { + t.join(); + } + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java new file mode 100644 index 0000000..bde17e9 --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcOptionsTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Map; + +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.Test; + +class AdbcOptionsTest { + + private static CaseInsensitiveStringMap opts(Map m) { + return new CaseInsensitiveStringMap(m); + } + + @Test + void parsesTargetPartitions() { + AdbcOptions o = + AdbcOptions.fromOptions( + opts(Map.of("driver", "/lib.so", "table", "t", "target_partitions", "16"))); + assertTrue(o.targetPartitions().isPresent()); + assertEquals(16, o.targetPartitions().getAsInt()); + } + + @Test + void targetPartitionsAbsentByDefault() { + AdbcOptions o = AdbcOptions.fromOptions(opts(Map.of("driver", "/lib.so", "table", "t"))); + assertFalse(o.targetPartitions().isPresent()); + } + + @Test + void rejectsNonPositiveOrNonNumeric() { + assertThrows( + IllegalArgumentException.class, + () -> + AdbcOptions.fromOptions( + opts(Map.of("driver", "/lib.so", "table", "t", "target_partitions", "0")))); + assertThrows( + IllegalArgumentException.class, + () -> + AdbcOptions.fromOptions( + opts(Map.of("driver", "/lib.so", "table", "t", "target_partitions", "abc")))); + } + + @Test + void controlKeysNotForwardedAsDatabaseOptions() { + AdbcOptions o = + AdbcOptions.fromOptions( + opts( + Map.of( + "driver", "/lib.so", + "table", "t", + "target_partitions", "8", + "my.provider.opt", "v"))); + Map params = o.driverParameters(); + // driver -> jni.driver; provider opt passed through; control keys absent. 
+ assertTrue(params.containsKey("my.provider.opt")); + assertFalse(params.containsKey("target_partitions")); + assertFalse(params.containsKey("table")); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java new file mode 100644 index 0000000..fc9105e --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/AdbcSourceTest.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import java.math.BigDecimal; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.adbc.core.TypedKey; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +/** + * End-to-end test of the {@code adbc-datafusion} source against the example driver cdylib (built + * from {@code examples/adbc-datafusion-driver}). It drives the full stack: Spark DataSourceV2 -> + * arrow-adbc JNI driver manager -> the DataFusion ADBC cdylib -> our custom provider. + * + *
<p>
Requires the cdylib path in the {@code adbc.example.driver.path} system property; the test is + * skipped (not failed) when it is absent, so the build is green without the native artifact. To + * run: + * + *
<pre>
+ *   (cd examples/adbc-datafusion-driver && cargo build --release)
+ *   mvn -pl spark -am test -Dtest=AdbcSourceTest \
+ *     -Dadbc.example.driver.path=$PWD/rust-target/release/libadbc_datafusion_example_driver.dylib
+ * </pre>
+ */ +class AdbcSourceTest { + + private static final String ENTRYPOINT = "AdbcDatafusionExampleInit"; + private static final String TABLE = "example"; + // A second table whose schema spans the Arrow types the connector maps or casts. + private static final String TYPES_TABLE = "types"; + + private static SparkSession spark; + private static String driverPath; + + @BeforeAll + static void setUp() { + driverPath = System.getProperty("adbc.example.driver.path"); + assumeTrue( + driverPath != null && Files.exists(Path.of(driverPath)), + "set -Dadbc.example.driver.path to the example driver cdylib to run this test"); + // local[8]: several task slots in one executor JVM, so the per-executor connection cache + // (AdbcConnectionPool) is exercised across concurrent tasks. + // Java 8 date/time API: DateType -> java.time.LocalDate, avoiding Spark's legacy + // java.sql.Date conversion which needs sun.util.calendar opened on JDK 17. + spark = + SparkSession.builder() + .appName("adbc-source-test") + .master("local[8]") + .config("spark.sql.datetime.java8API.enabled", "true") + .getOrCreate(); + } + + @AfterAll + static void tearDown() { + if (spark != null) { + spark.stop(); + } + } + + private Dataset<Row> load() { + return load(TABLE); + } + + private Dataset<Row> load(String table) { + return spark + .read() + .format("adbc-datafusion") + .option("driver", driverPath) + .option("entrypoint", ENTRYPOINT) + .option("table", table) + .load(); + } + + /** + * The {@code types} table exercises the full converter: directly-representable columns pass + * through, cast-required columns (unsigned, ns timestamp, Float16, nested {@code List}) + * are cast to a Spark-native layout by a source-side {@code arrow_cast} pushed into the scan. + * This asserts the reported Spark types and the cast-column metadata flags.
+ */ + @Test + void typesSchemaMapsAndFlagsCasts() { + StructType schema = load(TYPES_TABLE).schema(); + + assertEquals(DataTypes.BinaryType, schema.apply("payload").dataType()); + assertEquals(DataTypes.IntegerType, schema.apply("channel").dataType()); + assertEquals(DataTypes.createDecimalType(20, 0), schema.apply("big").dataType()); + assertEquals(DataTypes.TimestampNTZType, schema.apply("event_time").dataType()); + assertEquals(DataTypes.FloatType, schema.apply("score").dataType()); + assertEquals( + DataTypes.createArrayType(DataTypes.IntegerType, true), schema.apply("tags").dataType()); + // FixedSizeList -> variable Array (fixed layout not readable by Spark). + assertEquals( + DataTypes.createArrayType(DataTypes.IntegerType, true), schema.apply("vec").dataType()); + // LargeList -> Array; FixedSizeBinary -> Binary; Date64 -> Date. + assertEquals( + DataTypes.createArrayType(DataTypes.StringType, true), schema.apply("labels").dataType()); + assertEquals(DataTypes.BinaryType, schema.apply("digest").dataType()); + assertEquals(DataTypes.DateType, schema.apply("day").dataType()); + + // Cast columns are flagged (so filter pushdown stays off them); pass-through columns are not. 
+ assertTrue(schema.apply("channel").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("big").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("event_time").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("score").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("tags").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("vec").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("labels").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("digest").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(schema.apply("day").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertFalse(schema.apply("payload").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertFalse(schema.apply("attrs").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + } + + /** + * The cast-requiring scan still parallelizes. Casts force the SQL wire (Substrait can't encode + * unsigned/Float16), but ADBC {@code executePartitioned} is wire-independent -- it partitions on + * the physical plan's output partitioning, and the source-side {@code arrow_cast} is a + * partition-preserving projection -- so the {@code types} scan gets one Spark partition per + * driver partition, exactly like the Substrait-wire {@code example} scan. + */ + @Test + void typesScanIsMultiPartitionDespiteCasts() { + assertTrue( + load(TYPES_TABLE).rdd().getNumPartitions() >= 2, + "expected multiple Spark partitions for the cast (SQL-wire) scan, got " + + load(TYPES_TABLE).rdd().getNumPartitions()); + } + + /** + * count() over a cast-requiring table. Catalyst prunes the projection to empty for a count, and a + * bare {@code SELECT *} would return the raw (uncast) schema -- the reader would then reject a + * non-Spark-native column (e.g. 
{@code Timestamp(NANOSECOND)}). The connector emits a single + * readable probe column instead, so the row count still comes back. + */ + @Test + void typesCountWorksWithoutMaterializingCastColumns() { + assertEquals(3L, load(TYPES_TABLE).count()); + } + + /** Every column decodes to the expected Spark-native value -- casts are value-correct. */ + @Test + void typesValuesRoundTripThroughCasts() { + Map<Long, Row> byId = + load(TYPES_TABLE).collectAsList().stream() + .collect(Collectors.toMap(r -> r.getLong(r.fieldIndex("id")), r -> r)); + assertEquals(Set.of(1L, 2L, 3L), byId.keySet()); + Row r1 = byId.get(1L); + Row r2 = byId.get(2L); + Row r3 = byId.get(3L); + + // unsigned UInt16 -> Integer, widened past i16::MAX. + assertEquals(100, r1.getInt(r1.fieldIndex("channel"))); + assertEquals(40_000, r2.getInt(r2.fieldIndex("channel"))); + assertEquals(65_535, r3.getInt(r3.fieldIndex("channel"))); + + // unsigned UInt64 -> Decimal(20,0): lossless for values past i64::MAX (a Long would overflow). + assertEquals(0, new BigDecimal("18446744073709551615").compareTo(bigValue(r1))); + assertEquals(0, BigDecimal.ZERO.compareTo(bigValue(r2))); + assertEquals(0, new BigDecimal("9223372036854775808").compareTo(bigValue(r3))); + + // nanosecond Timestamp -> microsecond TimestampNTZ, rescaled (a relabel would land near 1970). + assertEquals(LocalDateTime.of(2020, 9, 13, 12, 26, 40), r1.getAs("event_time")); + assertEquals(LocalDateTime.of(2021, 1, 7, 6, 13, 20), r2.getAs("event_time")); + assertEquals(LocalDateTime.of(2021, 5, 3, 0, 0, 0), r3.getAs("event_time")); + + // Float16 -> Float. + assertEquals(1.5f, r1.getFloat(r1.fieldIndex("score"))); + assertEquals(2.5f, r2.getFloat(r2.fieldIndex("score"))); + + // Binary passes through.
+ assertArrayEquals(new byte[] {0x01, 0x02}, (byte[]) r1.getAs("payload")); + assertArrayEquals(new byte[] {}, (byte[]) r2.getAs("payload")); + assertArrayEquals(new byte[] {(byte) 0xff, (byte) 0xfe}, (byte[]) r3.getAs("payload")); + + // nested List -> Array (recursive cast). + assertEquals(List.of(1, 2), r1.getList(r1.fieldIndex("tags"))); + assertEquals(List.of(), r2.getList(r2.fieldIndex("tags"))); + assertEquals(List.of(3), r3.getList(r3.fieldIndex("tags"))); + + // FixedSizeList -> Array (fixed->variable + element widening). + assertEquals(List.of(10, 20), r1.getList(r1.fieldIndex("vec"))); + assertEquals(List.of(30, 40), r2.getList(r2.fieldIndex("vec"))); + assertEquals(List.of(50, 60), r3.getList(r3.fieldIndex("vec"))); + + // LargeList -> Array (large list not readable -> cast to variable list). + assertEquals(List.of("a", "b"), r1.getList(r1.fieldIndex("labels"))); + assertEquals(List.of(), r2.getList(r2.fieldIndex("labels"))); + assertEquals(List.of("c"), r3.getList(r3.fieldIndex("labels"))); + + // FixedSizeBinary -> Binary. + assertArrayEquals(new byte[] {0x01, 0x02, 0x03, 0x04}, (byte[]) r1.getAs("digest")); + assertArrayEquals( + new byte[] {(byte) 0xff, (byte) 0xfe, (byte) 0xfd, (byte) 0xfc}, + (byte[]) r3.getAs("digest")); + + // Date64 -> Date32 (day-aligned). + assertEquals(LocalDate.ofEpochDay(18518), toLocalDate(r1.getAs("day"))); + assertEquals(LocalDate.ofEpochDay(18634), toLocalDate(r2.getAs("day"))); + assertEquals(LocalDate.ofEpochDay(18750), toLocalDate(r3.getAs("day"))); + + // nested List> passes through. 
+ List<Row> attrs3 = r3.getList(r3.fieldIndex("attrs")); + assertEquals(2, attrs3.size()); + assertEquals("b", attrs3.get(0).getAs("key")); + assertEquals("2", attrs3.get(0).getAs("val")); + assertEquals("c", attrs3.get(1).getAs("key")); + assertEquals(List.of(), r2.getList(r2.fieldIndex("attrs"))); + } + + private static BigDecimal bigValue(Row row) { + return row.getDecimal(row.fieldIndex("big")); + } + + /** Spark returns DateType as java.sql.Date, or java.time.LocalDate under the Java 8 date API. */ + private static LocalDate toLocalDate(Object value) { + return value instanceof LocalDate d ? d : ((java.sql.Date) value).toLocalDate(); + } + + @Test + void fullScanAcrossPartitions() { + Dataset<Row> df = load(); + + // Schema inferred via ADBC getTableSchema. + assertEquals("id", df.schema().fields()[0].name()); + assertEquals("name", df.schema().fields()[1].name()); + + // The provider has three partitions; with arrow-adbc >= 0.24 (executePartitioned + // forwarded by the JNI bridge) they surface as multiple Spark input partitions. 
+ assertTrue( + df.rdd().getNumPartitions() >= 2, + "expected multiple Spark partitions, got " + df.rdd().getNumPartitions()); + + List<Long> ids = + df.collectAsList().stream().map(r -> r.getLong(0)).sorted().collect(Collectors.toList()); + assertEquals(List.of(1L, 2L, 3L), ids); + } + + @Test + void projectionPushdown() { + List<String> columns = + load().select("name").collectAsList().stream() + .map(r -> r.getString(0)) + .sorted() + .collect(Collectors.toList()); + assertEquals(List.of("alice", "bob", "carol"), columns); + } + + @Test + void filterPushdown() { + List<Long> ids = + load().filter("id > 1").collectAsList().stream() + .map(r -> r.getLong(0)) + .sorted() + .collect(Collectors.toList()); + assertEquals(List.of(2L, 3L), ids); + } + + /** + * Regression for the per-executor cache: across a multi-partition scan over many task slots, the + * pool builds the native database (and so runs the driver's {@code ContextInit} provider + * registration) exactly once per executor JVM -- not once per task. Each task still opens its own + * connection off that one cached database. + */ + @Test + void providerRegisteredOncePerExecutor() { + // Force reader tasks to run across the executor's slots. + long rows = load().collectAsList().size(); + assertTrue(rows > 0, "scan returned rows"); + + // All tasks share one driver+options key, so the pool builds exactly one native database for + // this executor JVM regardless of partition/task count. + assertEquals( + 1, + AdbcConnectionPool.databasesBuiltForTesting(), + "expected one native database per executor JVM, got " + + AdbcConnectionPool.databasesBuiltForTesting()); + + // One cached database, but each task opens its own connection off it. + assertTrue( + AdbcConnectionPool.taskConnectionsOpenedForTesting() > 0, + "expected a per-task connection off the one cached database"); + } + + /** + * Proves the driver's database-scoped plan cache is actually used. 
Every partition descriptor of + * one query carries the same serialized physical plan; the driver deserializes it once and caches + * it across all connections opened from the (per-executor) cached database. So the {@code N} + * per-task connections of one scan must trigger exactly one deserialize, not {@code N}. + * + * <p>
The driver exposes its deserialize count as the read-only int option {@code + * adbc.datafusion.plan_deserialize_count}. The count is JVM-lifetime and shared across tests, so + * we assert on the delta around a single scan, and use a distinct predicate so the plan bytes are + * fresh (not already cached by another test) -- making the expected delta exactly 1. + */ + @Test + void planDeserializedOncePerExecutorNotPerTask() throws Exception { + TypedKey<Long> deserializeCount = + new TypedKey<>("adbc.datafusion.plan_deserialize_count", Long.class); + + // Distinct predicate -> distinct physical plan -> plan bytes not cached by another test. + Dataset<Row> df = load().filter("id >= 1"); + int partitions = df.rdd().getNumPartitions(); + assertTrue(partitions >= 2, "need multiple partitions to test the collapse, got " + partitions); + + // Same options the scan uses -> same pool cache key -> same cached database (and counter). + LinkedHashMap<String, String> raw = new LinkedHashMap<>(); + raw.put("driver", driverPath); + raw.put("entrypoint", ENTRYPOINT); + raw.put("table", TABLE); + AdbcOptions opts = AdbcOptions.fromOptions(new CaseInsensitiveStringMap(raw)); + + long before; + try (AdbcConnectionPool.Lease lease = AdbcConnectionPool.acquire(opts)) { + before = lease.connection().getOption(deserializeCount); + } + + df.collectAsList(); // runs read_partition on `partitions` separate task connections + + long after; + try (AdbcConnectionPool.Lease lease = AdbcConnectionPool.acquire(opts)) { + after = lease.connection().getOption(deserializeCount); + } + + assertEquals( + 1L, + after - before, + "the " + + partitions + + " partitions of one query must share ONE plan deserialize (database-scoped cache); " + + "a per-connection cache would deserialize once per task = " + + partitions); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java new file mode 100644 index 0000000..fe535f0 --- 
/dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/SchemaConverterTest.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; + +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.datafusion.spark.SchemaConverter.ProjectionColumn; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for the Arrow -> Spark schema mapping and the {@code arrow_cast} target 
planning: one + * case per representable Arrow type, plus the cast-requiring types (unsigned, Float16, non-µs + * timestamps) and nested forms. + */ +class SchemaConverterTest { + + private static Field nullable(String name, ArrowType type) { + return Field.nullable(name, type); + } + + private static Field nullable(String name, ArrowType type, Field... children) { + return new Field(name, FieldType.nullable(type), List.of(children)); + } + + private static DataType spark(Field field) { + return SchemaConverter.toSparkType(field); + } + + // --- directly representable (A): mapped, no cast -------------------------- + + @Test + void primitivesMapWithoutCast() { + assertEquals(DataTypes.ByteType, spark(nullable("c", new ArrowType.Int(8, true)))); + assertEquals(DataTypes.ShortType, spark(nullable("c", new ArrowType.Int(16, true)))); + assertEquals(DataTypes.IntegerType, spark(nullable("c", new ArrowType.Int(32, true)))); + assertEquals(DataTypes.LongType, spark(nullable("c", new ArrowType.Int(64, true)))); + assertEquals( + DataTypes.DoubleType, + spark(nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)))); + assertEquals(DataTypes.StringType, spark(nullable("c", ArrowType.Utf8.INSTANCE))); + assertEquals(DataTypes.BooleanType, spark(nullable("c", ArrowType.Bool.INSTANCE))); + + assertNull(SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(32, true)))); + } + + @Test + void binaryMapsToBinary() { + assertEquals(DataTypes.BinaryType, spark(nullable("c", ArrowType.Binary.INSTANCE))); + assertEquals(DataTypes.BinaryType, spark(nullable("c", ArrowType.LargeBinary.INSTANCE))); + assertEquals(DataTypes.BinaryType, spark(nullable("c", new ArrowType.FixedSizeBinary(16)))); + assertNull(SchemaConverter.castTargetString(nullable("c", ArrowType.Binary.INSTANCE))); + } + + @Test + void dateDecimalNull() { + assertEquals(DataTypes.DateType, spark(nullable("c", new ArrowType.Date(DateUnit.DAY)))); + assertEquals( + 
DataTypes.createDecimalType(10, 2), + spark(nullable("c", new ArrowType.Decimal(10, 2, 128)))); + assertEquals(DataTypes.NullType, spark(nullable("c", ArrowType.Null.INSTANCE))); + } + + @Test + void microsecondTimestampNeedsNoCast() { + Field ntz = nullable("c", new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)); + assertEquals(DataTypes.TimestampNTZType, spark(ntz)); + assertNull(SchemaConverter.castTargetString(ntz)); + + Field zoned = nullable("c", new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")); + assertEquals(DataTypes.TimestampType, spark(zoned)); + assertNull(SchemaConverter.castTargetString(zoned)); + } + + @Test + void nestedListStructMap() { + Field struct = + nullable( + "s", + ArrowType.Struct.INSTANCE, + nullable("first", ArrowType.Utf8.INSTANCE), + nullable("second", ArrowType.Utf8.INSTANCE)); + Field listOfStruct = nullable("metadata", new ArrowType.List(), struct); + + DataType mapped = spark(listOfStruct); + assertTrue(mapped instanceof org.apache.spark.sql.types.ArrayType); + // A pure-(A) nested column needs no cast. 
+ assertNull(SchemaConverter.castTargetString(listOfStruct)); + } + + // --- cast required (B): mapped to widened type + arrow_cast target -------- + + @Test + void unsignedIntsWidenAndCast() { + assertEquals(DataTypes.ShortType, spark(nullable("c", new ArrowType.Int(8, false)))); + assertEquals(DataTypes.IntegerType, spark(nullable("c", new ArrowType.Int(16, false)))); + assertEquals(DataTypes.LongType, spark(nullable("c", new ArrowType.Int(32, false)))); + assertEquals( + DataTypes.createDecimalType(20, 0), spark(nullable("c", new ArrowType.Int(64, false)))); + + assertEquals( + "Int16", SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(8, false)))); + assertEquals( + "Int32", SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(16, false)))); + assertEquals( + "Int64", SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(32, false)))); + assertEquals( + "Decimal128(20, 0)", + SchemaConverter.castTargetString(nullable("c", new ArrowType.Int(64, false)))); + } + + @Test + void float16WidensToFloat() { + Field f = nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)); + assertEquals(DataTypes.FloatType, spark(f)); + assertEquals("Float32", SchemaConverter.castTargetString(f)); + } + + @Test + void nonMicrosecondTimestampRescales() { + Field ns = nullable("t", new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)); + assertEquals(DataTypes.TimestampNTZType, spark(ns)); + assertEquals("Timestamp(Microsecond)", SchemaConverter.castTargetString(ns)); + + Field zoned = nullable("t", new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")); + assertEquals(DataTypes.TimestampType, spark(zoned)); + assertEquals("Timestamp(Microsecond, \"UTC\")", SchemaConverter.castTargetString(zoned)); + } + + @Test + void nestedListOfUnsignedCastsRecursively() { + Field list = + nullable("ids", new ArrowType.List(), nullable("item", new ArrowType.Int(16, false))); + DataType mapped = spark(list); + 
assertEquals(DataTypes.createArrayType(DataTypes.IntegerType, true), mapped); + assertEquals("List(Int32)", SchemaConverter.castTargetString(list)); + } + + @Test + void fixedSizeListCastsToVariableList() { + // Spark reads ArrayType only from a variable ListVector, so a fixed-size list is always cast + // to a variable list -- even when its element needs no cast. + Field intFixed = + new Field( + "v", + FieldType.nullable(new ArrowType.FixedSizeList(2)), + List.of(nullable("item", new ArrowType.Int(32, true)))); + assertEquals(DataTypes.createArrayType(DataTypes.IntegerType, true), spark(intFixed)); + assertEquals("List(Int32)", SchemaConverter.castTargetString(intFixed)); + + // The element is rendered cast-aware: FixedSizeList -> List(Float32). + Field halfFixed = + new Field( + "v", + FieldType.nullable(new ArrowType.FixedSizeList(3)), + List.of(nullable("item", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)))); + assertEquals(DataTypes.createArrayType(DataTypes.FloatType, true), spark(halfFixed)); + assertEquals("List(Float32)", SchemaConverter.castTargetString(halfFixed)); + } + + @Test + void largeListCastsToVariableList() { + // LargeList maps to ArrayType but has no ArrowColumnVector accessor -> must cast to List. + Field large = + new Field( + "v", + FieldType.nullable(new ArrowType.LargeList()), + List.of(nullable("item", new ArrowType.Int(32, true)))); + assertEquals(DataTypes.createArrayType(DataTypes.IntegerType, true), spark(large)); + assertEquals("List(Int32)", SchemaConverter.castTargetString(large)); + } + + @Test + void fixedSizeBinaryCastsToBinary() { + // FixedSizeBinary maps to BinaryType but has no accessor -> must cast to variable Binary. 
+ Field fsb = nullable("c", new ArrowType.FixedSizeBinary(16)); + assertEquals(DataTypes.BinaryType, spark(fsb)); + assertEquals("Binary", SchemaConverter.castTargetString(fsb)); + } + + @Test + void date64CastsToDate32() { + Field d64 = nullable("c", new ArrowType.Date(DateUnit.MILLISECOND)); + assertEquals(DataTypes.DateType, spark(d64)); + assertEquals("Date32", SchemaConverter.castTargetString(d64)); + + // Date32 is consumable -> no cast. + assertNull(SchemaConverter.castTargetString(nullable("c", new ArrowType.Date(DateUnit.DAY)))); + } + + @Test + void sparkTargetIsAlwaysConsumable() { + // The core invariant behind the plan-time guard: every widened target is Spark-consumable. + List<Field> samples = + List.of( + nullable("a", new ArrowType.Int(8, false)), + nullable("b", new ArrowType.Int(64, false)), + nullable("c", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)), + nullable("d", new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")), + nullable("e", new ArrowType.Date(DateUnit.MILLISECOND)), + nullable("f", new ArrowType.Time(TimeUnit.MICROSECOND, 64)), + nullable("g", new ArrowType.FixedSizeBinary(4)), + new Field( + "h", + FieldType.nullable(new ArrowType.FixedSizeList(2)), + List.of( + nullable("item", new ArrowType.FloatingPoint(FloatingPointPrecision.HALF)))), + new Field( + "i", + FieldType.nullable(new ArrowType.LargeList()), + List.of(nullable("item", new ArrowType.Int(16, false)))), + new Field( + "j", + FieldType.nullable(ArrowType.Struct.INSTANCE), + List.of(nullable("x", new ArrowType.Int(32, false))))); + for (Field f : samples) { + assertConsumable(SchemaConverter.sparkTarget(f)); + } + } + + private static void assertConsumable(Field field) { + assertTrue( + SchemaConverter.sparkConsumable(field.getType()), + "sparkTarget produced a non-consumable type: " + field.getType()); + field.getChildren().forEach(SchemaConverterTest::assertConsumable); + } + + @Test + void unsupportedTypeFailsAtSchemaBuild() { + // An Interval has no accessor and 
no supported cast -> fail fast at inferSchema, naming it. + Schema schema = + new Schema(List.of(Field.nullable("dur", new ArrowType.Interval(IntervalUnit.DAY_TIME)))); + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> SchemaConverter.toSparkSchema(schema)); + assertTrue(e.getMessage().contains("dur"), e.getMessage()); + } + + // --- schema-level: metadata flag + projection planning ------------------- + + @Test + void castColumnsTaggedInSchemaMetadata() { + Schema schema = + new Schema( + List.of( + Field.nullable("id", new ArrowType.Int(64, true)), + Field.nullable("channel", new ArrowType.Int(16, false)), + Field.nullable("ts", new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)))); + StructType spark = SchemaConverter.toSparkSchema(schema); + + assertTrue(spark.apply("id").metadata().isEmpty()); + assertTrue(spark.apply("channel").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + assertTrue(spark.apply("ts").metadata().contains(SchemaConverter.CAST_METADATA_KEY)); + } + + @Test + void probeColumnPrefersCastlessThenFallsBackToCast() { + // A castless column is preferred (cheapest, no cast) for a column-less scan. + Schema mixed = + new Schema( + List.of( + Field.nullable("channel", new ArrowType.Int(16, false)), + Field.nullable("id", new ArrowType.Int(64, true)))); + ProjectionColumn probe = SchemaConverter.probeColumn(mixed); + assertEquals("id", probe.name()); + assertNull(probe.castType()); + + // When every column needs a cast, the first column is used with its cast applied. 
+ Schema allCast = + new Schema( + List.of( + Field.nullable("channel", new ArrowType.Int(16, false)), + Field.nullable("ts", new ArrowType.Timestamp(TimeUnit.NANOSECOND, null)))); + ProjectionColumn casted = SchemaConverter.probeColumn(allCast); + assertEquals("channel", casted.name()); + assertEquals("Int32", casted.castType()); + } + + @Test + void projectionColumnsCarryCastTargets() { + Schema schema = + new Schema( + List.of( + Field.nullable("id", new ArrowType.Int(64, true)), + Field.nullable("channel", new ArrowType.Int(16, false)))); + + // All columns. + List<ProjectionColumn> all = SchemaConverter.projectionColumns(schema, null); + assertEquals(2, all.size()); + assertEquals("id", all.get(0).name()); + assertNull(all.get(0).castType()); + assertEquals("channel", all.get(1).name()); + assertEquals("Int32", all.get(1).castType()); + + // Projected subset, order preserved. + List<ProjectionColumn> pruned = SchemaConverter.projectionColumns(schema, List.of("channel")); + assertEquals(1, pruned.size()); + assertEquals("channel", pruned.get(0).name()); + assertEquals("Int32", pruned.get(0).castType()); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java b/spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java new file mode 100644 index 0000000..e9b6aa9 --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/SqlQueryTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; +import java.util.OptionalLong; + +import org.apache.datafusion.spark.SchemaConverter.ProjectionColumn; +import org.apache.spark.sql.sources.Filter; +import org.junit.jupiter.api.Test; + +/** Verifies the SQL fallback wire, including {@code arrow_cast} injection for cast columns. */ +class SqlQueryTest { + + private static final List<Filter> NO_FILTERS = List.of(); + + @Test + void selectStarWhenNoColumns() { + assertEquals( + "SELECT * FROM \"t\"", SqlQuery.build("t", null, NO_FILTERS, OptionalLong.empty())); + } + + @Test + void castColumnsWrappedInArrowCastAndAliased() { + List<ProjectionColumn> columns = + List.of( + new ProjectionColumn("id", null), + new ProjectionColumn("channel", "Int32"), + new ProjectionColumn("ts", "Timestamp(Microsecond)")); + assertEquals( + "SELECT \"id\", arrow_cast(\"channel\", 'Int32') AS \"channel\", " + + "arrow_cast(\"ts\", 'Timestamp(Microsecond)') AS \"ts\" FROM \"t\"", + SqlQuery.build("t", columns, NO_FILTERS, OptionalLong.empty())); + } +} diff --git a/spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java b/spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java new file mode 100644 index 0000000..5c8b44f --- /dev/null +++ b/spark/src/test/java/org/apache/datafusion/spark/SubstraitPlanTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion.spark; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import java.util.OptionalLong; +import java.util.Set; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.spark.sql.sources.EqualTo; +import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.StringContains; +import org.junit.jupiter.api.Test; + +import io.substrait.relation.Fetch; +import io.substrait.relation.Filter; +import io.substrait.relation.NamedScan; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; + +/** + * Verifies the Substrait pushdown encoding without needing the native driver: build a plan, then + * re-parse the proto bytes and assert the relation tree and the standard function extensions. 
+ */ +class SubstraitPlanTest { + + private static Schema schema() { + return new Schema( + List.of( + Field.nullable("id", new ArrowType.Int(64, true)), + Field.nullable("name", ArrowType.Utf8.INSTANCE))); + } + + @Test + void buildsScanFilterProjectFetch() throws Exception { + List<org.apache.spark.sql.sources.Filter> filters = + List.of(new GreaterThan("id", 1L), new EqualTo("name", "bob")); + byte[] bytes = SubstraitPlan.build("t", schema(), List.of("name"), filters, OptionalLong.of(2)); + + // Re-parse the proto and walk it back into the substrait-java model. + io.substrait.proto.Plan proto = io.substrait.proto.Plan.parseFrom(bytes); + io.substrait.plan.Plan plan = new io.substrait.plan.ProtoPlanConverter().from(proto); + + Rel top = plan.getRoots().get(0).getInput(); + // Fetch(limit) -> Project(projection) -> Filter(predicates) -> NamedScan(table). + Fetch fetch = (Fetch) top; + assertEquals(OptionalLong.of(2), fetch.getCount()); + Project project = (Project) fetch.getInput(); + assertEquals(1, project.getExpressions().size()); // only the kept column "name" + Filter filter = (Filter) project.getInput(); + NamedScan scan = (NamedScan) filter.getInput(); + assertEquals(List.of("t"), scan.getNames()); + + // The interop guarantee: predicates resolve against the STANDARD catalog, the + // same function URIs DataFusion's Substrait consumer recognizes. 
+ String text = proto.toString(); + assertTrue(text.contains("functions_comparison"), "uses standard comparison functions"); + assertTrue(text.contains("functions_boolean"), "uses standard boolean functions"); + } + + @Test + void wholeTableScanWhenNoPushdown() throws Exception { + byte[] bytes = SubstraitPlan.build("t", schema(), null, List.of(), OptionalLong.empty()); + io.substrait.proto.Plan proto = io.substrait.proto.Plan.parseFrom(bytes); + io.substrait.plan.Plan plan = new io.substrait.plan.ProtoPlanConverter().from(proto); + assertTrue(plan.getRoots().get(0).getInput() instanceof NamedScan); + } + + @Test + void canPushWhitelist() { + Set columns = Set.of("id", "name"); + assertTrue(SubstraitPlan.canPush(new GreaterThan("id", 1L), columns)); + assertTrue(SubstraitPlan.canPush(new EqualTo("name", "bob"), columns)); + // Unknown column -> not pushable. + assertFalse(SubstraitPlan.canPush(new GreaterThan("missing", 1L), columns)); + // Unsupported predicate shape -> not pushable. + assertFalse(SubstraitPlan.canPush(new StringContains("name", "b"), columns)); + } +}