From 62487d3dfc9b6b6705ac74b393cd362c749c5e22 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 15:10:20 +0000 Subject: [PATCH 01/16] feat(backend/kernel): route use_sea=True through the Rust kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 2 of the PySQL × kernel integration plan (databricks-sql-kernel/docs/designs/pysql-kernel-integration.md). Wires `use_sea=True` to a new `backend/kernel/` module that delegates to the Rust kernel via the `databricks_sql_kernel` PyO3 extension (kernel PR #13). New module: `src/databricks/sql/backend/kernel/` - `client.py` — `KernelDatabricksClient(DatabricksClient)`. Lazy- imports `databricks_sql_kernel` so a connector install without the kernel wheel doesn't `ImportError` at startup; only `use_sea=True` surfaces the missing-extra message. Implements open/close_session, sync + async execute_command (async_op=True goes through `Statement.submit()` and stashes the handle in a dict keyed on `CommandId`), cancel/close_command, get_query_state, get_execution_result, and the metadata calls (catalogs / schemas / tables / columns) via `Session.metadata().list_*`. Real server-issued session and statement IDs flow through (no synthetic UUIDs). - `auth_bridge.py` — translate the connector's `AuthProvider` into kernel `Session` kwargs. PAT (including federation-wrapped PAT — `get_python_sql_connector_auth_provider` always wraps the base in `TokenFederationProvider`, so a naive isinstance check never matches) routes through `auth_type="pat"`. Everything else routes through `auth_type="external"` with a callback that delegates to `auth_provider.add_headers({})`. (External today is rejected by the kernel at `build_auth_provider`; the separate kernel-side enablement PR will flip it on.) - `result_set.py` — `KernelResultSet(ResultSet)`. 
Duck-typed over `databricks_sql_kernel.ExecutedStatement` (sync execute) and `ResultStream` (metadata + async await_result) since both expose `arrow_schema()` / `fetch_next_batch()` / `fetch_all_arrow()` / `close()`. Same FIFO batch buffer the prior ADBC POC used, so `fetchmany(n)` for n smaller than the kernel's natural batch size doesn't re-fetch. - `type_mapping.py` — Arrow → PEP 249 description-string mapper. Lifted from the prior ADBC POC; centralised here so future kernel-result wrappers reuse the same mapping. Kernel errors → PEP 249 exceptions: `KernelError.code` is mapped in a single table to `ProgrammingError` / `OperationalError` / `DatabaseError`. The structured fields (`sql_state`, `error_code`, `query_id`, …) are copied onto the re-raised exception so callers can branch on them without reaching through `__cause__`. Routing: `Session._create_backend` flips the `use_sea=True` branch to instantiate `KernelDatabricksClient` instead of the native `SeaDatabricksClient`. The native `backend/sea/` module is left in place (no users on `use_sea=True` after this PR; its long- term fate is out of scope here). Packaging: `[tool.poetry.extras] kernel = ["databricks-sql-kernel"]`. `pip install 'databricks-sql-connector[kernel]'` pulls in the kernel wheel; `use_sea=True` without the extra raises a pointed ImportError telling the user how to install it. Known gaps (acknowledged, will be follow-ups): - Parameter binding (`execute_command(parameters=[...])`) raises NotSupportedError — PyO3 `Statement.bind_param` lands in a follow-up. - Statement-level `query_tags` raises NotSupportedError. - `get_tables(table_types=[...])` returns unfiltered rows (the native SEA backend's filter is keyed on `SeaResultSet`; needs a small port to operate on `KernelResultSet`). - External-auth end-to-end blocked on the kernel-side `AuthConfig::External` enablement PR. - Volume PUT/GET (staging operations): kernel has no Volume API. 
Test plan: - Unit: 37 new tests across `tests/unit/test_kernel_auth_bridge.py` (auth provider → kwargs mapping, including federation-wrapped PAT and the External trampoline call-counter check), `tests/unit/test_kernel_type_mapping.py` (Arrow type mapping + description shape), and `tests/unit/test_kernel_result_set.py` (buffer semantics, fetchmany across batch boundaries, idempotent close, close() swallowing handle-close failures). All pass. - Full unit suite: 600 pre-existing tests still pass; one pre-existing failure (`test_useragent_header` — agent detection adds `agent/claude-code` in this env) was already failing on main, unrelated to this change. - Live e2e against dogfood with `use_sea=True`: SELECT 1, `range(10000)`, `fetchmany` pacing, `fetchall_arrow`, all four metadata calls (returned 75 catalogs / 144 schemas in main / 47 tables in `system.information_schema` / 15 columns), `session_configuration={'ANSI_MODE': 'false'}` round-trips, bad SQL surfaces as DatabaseError with `code='SqlError'` and `sql_state='42P01'` on the exception. All checks pass. 
Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- pyproject.toml | 6 + src/databricks/sql/backend/kernel/__init__.py | 15 + .../sql/backend/kernel/auth_bridge.py | 131 +++++ src/databricks/sql/backend/kernel/client.py | 503 ++++++++++++++++++ .../sql/backend/kernel/result_set.py | 220 ++++++++ .../sql/backend/kernel/type_mapping.py | 71 +++ src/databricks/sql/session.py | 34 +- tests/unit/test_kernel_auth_bridge.py | 116 ++++ tests/unit/test_kernel_result_set.py | 165 ++++++ tests/unit/test_kernel_type_mapping.py | 68 +++ 10 files changed, 1321 insertions(+), 8 deletions(-) create mode 100644 src/databricks/sql/backend/kernel/__init__.py create mode 100644 src/databricks/sql/backend/kernel/auth_bridge.py create mode 100644 src/databricks/sql/backend/kernel/client.py create mode 100644 src/databricks/sql/backend/kernel/result_set.py create mode 100644 src/databricks/sql/backend/kernel/type_mapping.py create mode 100644 tests/unit/test_kernel_auth_bridge.py create mode 100644 tests/unit/test_kernel_result_set.py create mode 100644 tests/unit/test_kernel_type_mapping.py diff --git a/pyproject.toml b/pyproject.toml index 5e9f7f0ca..a436132c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,10 +32,16 @@ pyarrow = [ pyjwt = "^2.0.0" pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} +# Optional kernel backend: `pip install 'databricks-sql-connector[kernel]'` +# unlocks use_sea=True, which routes through the Rust kernel via PyO3. +# Without it, use_sea=True raises a pointed ImportError. The kernel +# wheel itself ships from the databricks-sql-kernel repo. 
+databricks-sql-kernel = {version = "^0.1.0", optional = true} [tool.poetry.extras] pyarrow = ["pyarrow"] +kernel = ["databricks-sql-kernel"] [tool.poetry.group.dev.dependencies] pytest = "^7.1.2" diff --git a/src/databricks/sql/backend/kernel/__init__.py b/src/databricks/sql/backend/kernel/__init__.py new file mode 100644 index 000000000..a0de1861c --- /dev/null +++ b/src/databricks/sql/backend/kernel/__init__.py @@ -0,0 +1,15 @@ +"""Backend that delegates to the Databricks SQL Kernel (Rust) via PyO3. + +Routed when ``use_sea=True`` is passed to ``databricks.sql.connect``. +The module's identity is "delegates to the kernel" — not the wire +protocol the kernel happens to use today (SEA REST). The kernel may +switch its default transport (SEA REST → SEA gRPC → …) without +renaming this module. + +See ``docs/designs/pysql-kernel-integration.md`` in +``databricks-sql-kernel`` for the full integration design. +""" + +from databricks.sql.backend.kernel.client import KernelDatabricksClient + +__all__ = ["KernelDatabricksClient"] diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py new file mode 100644 index 000000000..1f14b8a5e --- /dev/null +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -0,0 +1,131 @@ +"""Translate the connector's ``AuthProvider`` into ``databricks_sql_kernel`` +``Session`` auth kwargs. + +The connector already implements every auth flow it supports (PAT, +OAuth M2M, OAuth U2M, external token providers, federation). The +kernel must not re-implement them. Decision D9 in the integration +design: PAT goes through the kernel's PAT path; everything else +delegates back to the connector via the kernel's ``External`` +trampoline, with a Python callback that returns a fresh bearer +token. 
+ +Token extraction goes through ``AuthProvider.add_headers({})`` +rather than touching auth-provider-specific attributes, so the +bridge works for every subclass — including custom providers a +caller may have wired in. + +End-to-end limitation: the kernel's +``build_auth_provider`` currently rejects ``AuthConfig::External`` +("reserved; v0 wires PAT + OAuthM2M + OAuthU2M only"). Until the +kernel-side follow-up PR lands, non-PAT auth surfaces a clear +``KernelError(code='InvalidArgument', message='AuthConfig::External +is reserved...')`` from ``Session.open_session``. PAT works today. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider +from databricks.sql.auth.token_federation import TokenFederationProvider + +logger = logging.getLogger(__name__) + + +_BEARER_PREFIX = "Bearer " + + +def _is_pat(auth_provider: AuthProvider) -> bool: + """Return True iff this provider ultimately wraps an + ``AccessTokenAuthProvider``. + + ``get_python_sql_connector_auth_provider`` always wraps the + base provider in a ``TokenFederationProvider``, so an + ``isinstance`` check against ``AccessTokenAuthProvider`` alone + never matches in practice. We peek through the federation + wrapper to find the real type. + """ + if isinstance(auth_provider, AccessTokenAuthProvider): + return True + if isinstance(auth_provider, TokenFederationProvider) and isinstance( + auth_provider.external_provider, AccessTokenAuthProvider + ): + return True + return False + + +def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: + """Pull the current bearer token out of an ``AuthProvider``. + + The connector's ``AuthProvider.add_headers`` mutates a header + dict and writes the ``Authorization: Bearer <token>`` value. + Going through that public surface keeps us insulated from + provider-specific internals. 
+ + Returns ``None`` if the provider did not write an Authorization + header or wrote a non-Bearer scheme — neither shape is + representable in the kernel's auth surface today. + """ + headers: Dict[str, str] = {} + auth_provider.add_headers(headers) + auth = headers.get("Authorization") + if not auth: + return None + if not auth.startswith(_BEARER_PREFIX): + return None + return auth[len(_BEARER_PREFIX) :] + + +def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: + """Build the kwargs passed to ``databricks_sql_kernel.Session(...)``. + + Two routing decisions: + + 1. ``AccessTokenAuthProvider`` → ``auth_type='pat'`` with the + static token. Kernel uses it verbatim for every request. + 2. Anything else → ``auth_type='external'`` with a callback that + calls ``auth_provider.add_headers({})`` and returns the + fresh bearer token. The connector keeps owning the OAuth / + MSAL / federation flow; the kernel asks for a token whenever + it needs one. + + The PAT special-case exists because it's the only path the + kernel actually serves end-to-end today. Once the kernel-side + External enablement lands, PAT could collapse into the + External path too (one callback that returns the static token); + but keeping the explicit ``pat`` route means the kernel does + not pay the GIL-reacquire cost on every HTTP request for PAT + users. + """ + if _is_pat(auth_provider): + # PAT case: pull the static token out and feed the kernel's + # PAT path. We go through ``add_headers`` regardless of + # whether the provider was wrapped in TokenFederation or + # not — both shapes write the same Authorization header. + token = _extract_bearer_token(auth_provider) + if not token: + raise ValueError( + "PAT auth provider did not produce a Bearer Authorization " + "header; cannot route through the kernel's PAT path" + ) + return {"auth_type": "pat", "access_token": token} + + # Every other provider: trampoline a callback. 
The callback is + # invoked once per HTTP request that needs auth (the kernel does + # not cache the returned token), so the auth_provider's own + # caching is what keeps this fast. + def token_callback() -> str: + token = _extract_bearer_token(auth_provider) + if not token: + raise RuntimeError( + f"{type(auth_provider).__name__}.add_headers did not produce " + "a Bearer Authorization header; cannot supply a token to the kernel" + ) + return token + + logger.debug( + "Routing %s through kernel External trampoline", + type(auth_provider).__name__, + ) + return {"auth_type": "external", "token_callback": token_callback} diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py new file mode 100644 index 000000000..67b6a2cda --- /dev/null +++ b/src/databricks/sql/backend/kernel/client.py @@ -0,0 +1,503 @@ +"""``DatabricksClient`` backed by the Rust kernel via PyO3. + +Routed when ``use_sea=True``. Constructor takes the connector's +already-built ``auth_provider`` and forwards everything else to the +kernel's ``Session``. Every kernel call goes through this thin +wrapper; this module is the single seam between the connector's +``DatabricksClient`` contract and the kernel's Python surface. + +Errors map cleanly: ``KernelError`` from the kernel is inspected +for its ``code`` attribute and re-raised as the appropriate PEP +249 exception (``DatabaseError``, ``OperationalError``, +``ProgrammingError``, etc.). Connector callers see standard +exception types, never the underlying kernel error. + +Phase 1 gaps documented in the integration design: + +- Parameter binding (``parameters=[TSparkParameter, ...]``) is not + yet supported — the PyO3 ``Statement`` doesn't expose + ``bind_param``. ``execute_command(parameters=[...])`` raises + ``NotSupportedError``. +- ``query_tags`` on execute is not supported (kernel exposes + ``statement_conf`` but PyO3 doesn't surface it). 
+- ``get_tables`` with a non-empty ``table_types`` filter does not yet + apply the filter client-side; today the kernel returns the full + ``SHOW TABLES`` shape unchanged. The connector's existing + ``ResultSetFilter.filter_tables_by_type`` is keyed on + ``SeaResultSet`` not ``KernelResultSet``, so we punt and let + the caller see all rows — documented as a known gap in the + design doc. +- Volume PUT/GET (staging operations): kernel has no Volume API + yet, so staging operations are available only on the Thrift path. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union + +from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.kernel.auth_bridge import kernel_auth_kwargs +from databricks.sql.backend.kernel.result_set import KernelResultSet +from databricks.sql.backend.types import ( + BackendType, + CommandId, + CommandState, + SessionId, +) +from databricks.sql.exc import ( + DatabaseError, + Error, + InterfaceError, + NotSupportedError, + OperationalError, + ProgrammingError, +) +from databricks.sql.thrift_api.TCLIService import ttypes + +if TYPE_CHECKING: + from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet + +logger = logging.getLogger(__name__) + + +try: + import databricks_sql_kernel as _kernel # type: ignore[import-not-found] +except ImportError as exc: # pragma: no cover - import-time error surfaces clearly + raise ImportError( + "use_sea=True requires the databricks-sql-kernel package. Install it with:\n" + " pip install 'databricks-sql-connector[kernel]'\n" + "or for local development from the kernel repo:\n" + " cd databricks-sql-kernel/pyo3 && maturin develop --release" + ) from exc + + +# ─── Error mapping ────────────────────────────────────────────────────────── + + +# Map a kernel `code` slug to the PEP 249 exception class that best +# captures it. 
The match isn't a perfect 1:1 — PEP 249 has a +# narrower taxonomy than the kernel — so several kernel codes +# collapse onto the same Python exception. This table is the only +# place that mapping lives. +_CODE_TO_EXCEPTION = { + "InvalidArgument": ProgrammingError, + "Unauthenticated": OperationalError, + "PermissionDenied": OperationalError, + "NotFound": ProgrammingError, + "ResourceExhausted": OperationalError, + "Unavailable": OperationalError, + "Timeout": OperationalError, + "Cancelled": OperationalError, + "DataLoss": DatabaseError, + "Internal": DatabaseError, + "InvalidStatementHandle": ProgrammingError, + "NetworkError": OperationalError, + "SqlError": DatabaseError, + "Unknown": DatabaseError, +} + + +def _reraise_kernel_error(exc: BaseException) -> "Error": + """Convert a ``databricks_sql_kernel.KernelError`` to a PEP 249 + exception. Other exception types fall through unchanged. + + Kernel errors carry their structured attrs (``code``, + ``message``, ``sql_state``, ``error_code``, ``query_id`` …) as + plain attributes — we copy them onto the re-raised exception so + callers can branch on them without reaching back through + ``__cause__``. + """ + if not isinstance(exc, _kernel.KernelError): + return exc # type: ignore[return-value] + code = getattr(exc, "code", "Unknown") + cls = _CODE_TO_EXCEPTION.get(code, DatabaseError) + new = cls(getattr(exc, "message", str(exc))) + # Forward the structured fields so connector users can read + # err.sql_state / err.query_id / etc. without a type-switch. 
+ for attr in ( + "code", + "sql_state", + "error_code", + "vendor_code", + "http_status", + "retryable", + "query_id", + ): + try: + setattr(new, attr, getattr(exc, attr)) + except (AttributeError, TypeError): # pragma: no cover - defensive + pass + new.__cause__ = exc + return new + + +# ─── Client ───────────────────────────────────────────────────────────────── + + +class KernelDatabricksClient(DatabricksClient): + """``DatabricksClient`` that delegates to the Rust kernel. + + Owns one ``databricks_sql_kernel.Session`` per ``open_session`` + call. Async-execute handles (from ``submit()``) live in a dict + keyed on ``CommandId`` so the connector's polling APIs + (``get_query_state`` / ``get_execution_result`` / + ``cancel_command`` / ``close_command``) can find them again. + """ + + def __init__( + self, + server_hostname: str, + http_path: str, + auth_provider, + ssl_options, + catalog: Optional[str] = None, + schema: Optional[str] = None, + http_headers=None, + http_client=None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ): + # The connector hands us several fields the kernel doesn't + # consume directly (ssl_options, http_headers, http_client, + # port, _use_arrow_native_complex_types). Kernel manages + # its own HTTP stack so we accept-and-ignore. + self._server_hostname = server_hostname + self._http_path = http_path + self._auth_provider = auth_provider + self._catalog = catalog + self._schema = schema + self._auth_kwargs = kernel_auth_kwargs(auth_provider) + # Open ``databricks_sql_kernel.Session`` lazily in + # ``open_session`` so the Session lifecycle gates the + # underlying connection setup — same shape as Thrift's + # ``TOpenSession``. + self._kernel_session: Optional[Any] = None + self._session_id: Optional[SessionId] = None + # Async-exec handles keyed by CommandId.guid. Populated by + # ``execute_command(async_op=True)``; drained by ``close_command``. 
+ self._async_handles: Dict[str, Any] = {} + + # ── Session lifecycle ────────────────────────────────────────── + + def open_session( + self, + session_configuration: Optional[Dict[str, Any]], + catalog: Optional[str], + schema: Optional[str], + ) -> SessionId: + if self._kernel_session is not None: + raise InterfaceError("KernelDatabricksClient already has an open session.") + # ``session_configuration`` flows through to the kernel's + # ``session_conf`` map verbatim; the SEA endpoint enforces + # its own allow-list and rejects unknown keys. + session_conf: Optional[Dict[str, str]] = None + if session_configuration: + session_conf = {k: str(v) for k, v in session_configuration.items()} + try: + self._kernel_session = _kernel.Session( + host=self._server_hostname, + http_path=self._http_path, + catalog=catalog or self._catalog, + schema=schema or self._schema, + session_conf=session_conf, + **self._auth_kwargs, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + + # Use the kernel's real server-issued session id, not a + # synthetic UUID. Matches what the native SEA backend does. + self._session_id = SessionId.from_sea_session_id( + self._kernel_session.session_id + ) + logger.info("Opened kernel-backed session %s", self._session_id) + return self._session_id + + def close_session(self, session_id: SessionId) -> None: + if self._kernel_session is None: + return + # Close any tracked async handles first so they fire their + # server-side CloseStatement before the session goes away. + for handle in list(self._async_handles.values()): + try: + handle.close() + except _kernel.KernelError as exc: + logger.warning("Error closing async handle during session close: %s", exc) + self._async_handles.clear() + try: + self._kernel_session.close() + except _kernel.KernelError as exc: + # Surface as a non-fatal warning — the kernel's Drop + # impl will retry the close fire-and-forget. PEP 249 + # discourages raising from connection.close(). 
+ logger.warning("Error closing kernel session: %s", exc) + self._kernel_session = None + self._session_id = None + + # ── Query execution ──────────────────────────────────────────── + + def execute_command( + self, + operation: str, + session_id: SessionId, + max_rows: int, + max_bytes: int, + lz4_compression: bool, + cursor: "Cursor", + use_cloud_fetch: bool, + parameters: List[ttypes.TSparkParameter], + async_op: bool, + enforce_embedded_schema_correctness: bool, + row_limit: Optional[int] = None, + query_tags: Optional[Dict[str, Optional[str]]] = None, + ) -> Union["ResultSet", None]: + if self._kernel_session is None: + raise InterfaceError("Cannot execute_command without an open session.") + if parameters: + raise NotSupportedError( + "Parameter binding is not yet supported on the kernel backend " + "(PyO3 Statement.bind_param lands in a follow-up PR)." + ) + if query_tags: + raise NotSupportedError( + "Statement-level query_tags are not yet supported on the kernel backend." + ) + + stmt = self._kernel_session.statement() + try: + stmt.set_sql(operation) + if async_op: + async_exec = stmt.submit() + command_id = CommandId.from_sea_statement_id(async_exec.statement_id) + cursor.active_command_id = command_id + self._async_handles[command_id.guid] = async_exec + return None + executed = stmt.execute() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + finally: + # ``Statement`` is a lifecycle owner separate from the + # executed handle it produces. Drop it here so the + # parent doesn't keep the handle alive longer than the + # caller expects. 
+ try: + stmt.close() + except _kernel.KernelError: + pass + + command_id = CommandId.from_sea_statement_id(executed.statement_id) + cursor.active_command_id = command_id + return KernelResultSet( + connection=cursor.connection, + backend=self, + kernel_handle=executed, + command_id=command_id, + arraysize=cursor.arraysize, + buffer_size_bytes=cursor.buffer_size_bytes, + ) + + def cancel_command(self, command_id: CommandId) -> None: + handle = self._async_handles.get(command_id.guid) + if handle is None: + # Sync-execute paths fully materialise the result before + # ``execute_command`` returns, so by the time + # cancel_command can fire there's nothing in flight. + # Match the Thrift backend's tolerant behaviour. + logger.debug("cancel_command: no in-flight async handle for %s", command_id) + return + try: + handle.cancel() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + + def close_command(self, command_id: CommandId) -> None: + handle = self._async_handles.pop(command_id.guid, None) + if handle is None: + logger.debug("close_command: no tracked handle for %s", command_id) + return + try: + handle.close() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + + def get_query_state(self, command_id: CommandId) -> CommandState: + handle = self._async_handles.get(command_id.guid) + if handle is None: + # No tracked async handle means execute_command ran + # sync and the result was materialised before returning; + # the command is terminal by construction. + return CommandState.SUCCEEDED + try: + state, failure = handle.status() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + if state == "Failed" and failure is not None: + # Surface server-reported failure as a database error so + # the cursor's polling loop terminates with the right + # exception class — matches the Thrift backend's + # behaviour on TOperationState::ERROR_STATE. 
+ raise _reraise_kernel_error(failure) + return _STATE_TO_COMMAND_STATE.get(state, CommandState.FAILED) + + def get_execution_result( + self, + command_id: CommandId, + cursor: "Cursor", + ) -> "ResultSet": + handle = self._async_handles.get(command_id.guid) + if handle is None: + raise ProgrammingError( + "get_execution_result called for an unknown command_id; " + "the kernel backend only tracks async-submitted statements." + ) + try: + stream = handle.await_result() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return KernelResultSet( + connection=cursor.connection, + backend=self, + kernel_handle=stream, + command_id=command_id, + arraysize=cursor.arraysize, + buffer_size_bytes=cursor.buffer_size_bytes, + ) + + # ── Metadata ─────────────────────────────────────────────────── + + def _metadata_result(self, stream, cursor, command_id): + return KernelResultSet( + connection=cursor.connection, + backend=self, + kernel_handle=stream, + command_id=command_id, + arraysize=cursor.arraysize, + buffer_size_bytes=cursor.buffer_size_bytes, + ) + + def _synthetic_command_id(self) -> CommandId: + """Metadata calls don't produce a server statement id; mint + a synthetic one so the ``ResultSet`` still has a stable + identifier the cursor can attribute logs to.""" + return CommandId.from_sea_statement_id(f"metadata-{uuid.uuid4()}") + + def get_catalogs( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_catalogs requires an open session.") + try: + stream = self._kernel_session.metadata().list_catalogs() + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + def get_schemas( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, 
+ ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_schemas requires an open session.") + try: + stream = self._kernel_session.metadata().list_schemas( + catalog=catalog_name, + schema_pattern=schema_name, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + def get_tables( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + table_types: Optional[List[str]] = None, + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_tables requires an open session.") + if table_types: + # Documented gap: native SEA backend filters here, but + # its filter is keyed on SeaResultSet. Day-1 we surface + # the unfiltered result; a small follow-up ports the + # filter to operate on KernelResultSet. + logger.warning( + "get_tables: client-side table_types filter not yet implemented " + "on the kernel backend; returning unfiltered rows for %r", + table_types, + ) + try: + stream = self._kernel_session.metadata().list_tables( + catalog=catalog_name, + schema_pattern=schema_name, + table_pattern=table_name, + table_types=table_types, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + def get_columns( + self, + session_id: SessionId, + max_rows: int, + max_bytes: int, + cursor: "Cursor", + catalog_name: Optional[str] = None, + schema_name: Optional[str] = None, + table_name: Optional[str] = None, + column_name: Optional[str] = None, + ) -> "ResultSet": + if self._kernel_session is None: + raise InterfaceError("get_columns requires an open session.") + if not catalog_name: + # Kernel's list_columns requires a catalog (SEA `SHOW + # COLUMNS` cannot span catalogs). 
Surface the constraint + # explicitly rather than letting the kernel error. + raise ProgrammingError("get_columns requires catalog_name on the kernel backend.") + try: + stream = self._kernel_session.metadata().list_columns( + catalog=catalog_name, + schema_pattern=schema_name, + table_pattern=table_name, + column_pattern=column_name, + ) + except _kernel.KernelError as exc: + raise _reraise_kernel_error(exc) + return self._metadata_result(stream, cursor, self._synthetic_command_id()) + + # ── Misc ─────────────────────────────────────────────────────── + + @property + def max_download_threads(self) -> int: + # CloudFetch parallelism lives kernel-side. This property is + # consulted by Thrift code paths that don't run for + # use_sea=True; return a non-zero default so anything that + # peeks at it does not divide by zero. + return 10 + + +_STATE_TO_COMMAND_STATE: Dict[str, CommandState] = { + "Pending": CommandState.PENDING, + "Running": CommandState.RUNNING, + "Succeeded": CommandState.SUCCEEDED, + "Failed": CommandState.FAILED, + "Cancelled": CommandState.CANCELLED, + "Closed": CommandState.CLOSED, +} diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py new file mode 100644 index 000000000..d6a0e8588 --- /dev/null +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -0,0 +1,220 @@ +"""Streaming ``ResultSet`` over a kernel ``ExecutedStatement`` or +``ResultStream``. + +The kernel surfaces two flavours of result-bearing handle: + +- ``ExecutedStatement`` — returned by ``Statement.execute()``. Has a + ``statement_id`` and a ``cancel()`` method. +- ``ResultStream`` — returned by ``Session.metadata().list_*`` and by + ``ExecutedAsyncStatement.await_result()``. No statement id; no + cancel. + +Both implement the same four methods this class actually calls: +``arrow_schema() / fetch_next_batch() / fetch_all_arrow() / close()``. 
+``KernelResultSet`` takes either via the ``kernel_handle`` parameter +and treats them uniformly — the connector's ``ResultSet`` contract +doesn't need to distinguish them. + +Buffer shape mirrors the prior ADBC POC's ``AdbcResultSet``: a FIFO +of pyarrow ``RecordBatch``es, fed one batch at a time from the +kernel as the connector calls ``fetch*``. ``fetchmany(n)`` slices +within a batch when ``n`` is smaller than the kernel's natural +batch size; ``fetchall`` drains the whole stream. +""" + +from __future__ import annotations + +import logging +from collections import deque +from typing import Any, Deque, List, Optional, TYPE_CHECKING + +import pyarrow + +from databricks.sql.backend.kernel.type_mapping import description_from_arrow_schema +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.result_set import ResultSet +from databricks.sql.types import Row + +if TYPE_CHECKING: + from databricks.sql.client import Connection + from databricks.sql.backend.kernel.client import KernelDatabricksClient + +logger = logging.getLogger(__name__) + + +class KernelResultSet(ResultSet): + """Streaming ``ResultSet`` over a kernel handle. + + The ``kernel_handle`` is duck-typed: it must implement + ``arrow_schema() -> pyarrow.Schema``, ``fetch_next_batch() -> + Optional[pyarrow.RecordBatch]``, and ``close() -> None``. + Both ``databricks_sql_kernel.ExecutedStatement`` and + ``databricks_sql_kernel.ResultStream`` satisfy that contract. 
+ """ + + def __init__( + self, + connection: "Connection", + backend: "KernelDatabricksClient", + kernel_handle: Any, + command_id: CommandId, + arraysize: int, + buffer_size_bytes: int, + ): + schema = kernel_handle.arrow_schema() + super().__init__( + connection=connection, + backend=backend, + arraysize=arraysize, + buffer_size_bytes=buffer_size_bytes, + command_id=command_id, + status=CommandState.RUNNING, + has_been_closed_server_side=False, + has_more_rows=True, + results_queue=None, + description=description_from_arrow_schema(schema), + is_staging_operation=False, + lz4_compressed=False, + arrow_schema_bytes=None, + ) + self._kernel_handle = kernel_handle + self._schema: pyarrow.Schema = schema + # FIFO of record batches plus a per-head row offset, so + # partial fetches (fetchmany(n) for n < batch_size) don't + # re-fetch from the kernel. + self._buffer: Deque[pyarrow.RecordBatch] = deque() + self._buffer_offset: int = 0 + self._exhausted: bool = False + + # ----- internal helpers ----- + + def _pull_one_batch(self) -> bool: + """Pull the next batch from the kernel into the local buffer. + Returns True if a batch was added; False if the kernel side + is exhausted.""" + if self._exhausted: + return False + batch = self._kernel_handle.fetch_next_batch() + if batch is None: + self._exhausted = True + self.has_more_rows = False + self.status = CommandState.SUCCEEDED + return False + if batch.num_rows > 0: + self._buffer.append(batch) + return True + + def _ensure_buffered(self, n_rows: int) -> int: + """Pull batches until ``n_rows`` are buffered or the kernel + is exhausted. 
Returns total rows currently buffered.""" + while self._buffered_rows() < n_rows: + if not self._pull_one_batch(): + break + return self._buffered_rows() + + def _buffered_rows(self) -> int: + if not self._buffer: + return 0 + first = self._buffer[0].num_rows - self._buffer_offset + rest = sum(b.num_rows for b in list(self._buffer)[1:]) + return first + rest + + def _take_buffered(self, n: int) -> pyarrow.Table: + """Slice up to ``n`` rows out of the buffer; advances state.""" + slices: List[pyarrow.RecordBatch] = [] + remaining = n + while remaining > 0 and self._buffer: + head = self._buffer[0] + avail = head.num_rows - self._buffer_offset + take = min(avail, remaining) + slices.append(head.slice(self._buffer_offset, take)) + self._buffer_offset += take + remaining -= take + if self._buffer_offset >= head.num_rows: + self._buffer.popleft() + self._buffer_offset = 0 + self._next_row_index += n - remaining + if not slices: + return pyarrow.Table.from_batches([], schema=self._schema) + return pyarrow.Table.from_batches(slices, schema=self._schema) + + def _drain(self) -> pyarrow.Table: + """Consume everything left in the buffer + kernel stream + and return as a single Table.""" + chunks: List[pyarrow.RecordBatch] = [] + if self._buffer and self._buffer_offset > 0: + head = self._buffer.popleft() + chunks.append(head.slice(self._buffer_offset, head.num_rows - self._buffer_offset)) + self._buffer_offset = 0 + while self._buffer: + chunks.append(self._buffer.popleft()) + if not self._exhausted: + while True: + batch = self._kernel_handle.fetch_next_batch() + if batch is None: + self._exhausted = True + self.has_more_rows = False + self.status = CommandState.SUCCEEDED + break + if batch.num_rows > 0: + chunks.append(batch) + rows = sum(c.num_rows for c in chunks) + self._next_row_index += rows + if not chunks: + return pyarrow.Table.from_batches([], schema=self._schema) + return pyarrow.Table.from_batches(chunks, schema=self._schema) + + # ----- Arrow fetches ----- + + 
def fetchall_arrow(self) -> pyarrow.Table: + return self._drain() + + def fetchmany_arrow(self, size: int) -> pyarrow.Table: + if size < 0: + raise ValueError(f"fetchmany_arrow size must be >= 0, got {size}") + if size == 0: + return pyarrow.Table.from_batches([], schema=self._schema) + self._ensure_buffered(size) + return self._take_buffered(size) + + # ----- Row fetches ----- + + def fetchone(self) -> Optional[Row]: + self._ensure_buffered(1) + if self._buffered_rows() == 0: + return None + table = self._take_buffered(1) + rows = self._convert_arrow_table(table) + return rows[0] if rows else None + + def fetchmany(self, size: int) -> List[Row]: + if size < 0: + raise ValueError(f"fetchmany size must be >= 0, got {size}") + if size == 0: + return [] + self._ensure_buffered(size) + table = self._take_buffered(size) + return self._convert_arrow_table(table) + + def fetchall(self) -> List[Row]: + return self._convert_arrow_table(self._drain()) + + def close(self) -> None: + """Close the underlying kernel handle. Idempotent — the + kernel's own ``close()`` is idempotent, and we guard against + repeated calls so partially-drained streams don't double- + decrement reference counts.""" + if self._kernel_handle is None: + return + try: + self._kernel_handle.close() + except Exception as exc: + # close() failures are not actionable at the connector + # level; log and swallow so the cursor's __del__ / + # connection close path stays clean. + logger.warning("Error closing kernel handle: %s", exc) + self._buffer.clear() + self._kernel_handle = None + self._exhausted = True + self.has_been_closed_server_side = True + self.status = CommandState.CLOSED diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py new file mode 100644 index 000000000..bc4ffe5d2 --- /dev/null +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -0,0 +1,71 @@ +"""Arrow ↔ PEP 249 type translation for the kernel backend. 
+ +The kernel returns results as pyarrow ``Schema`` / ``RecordBatch``; +PEP 249 ``cursor.description`` is a list of 7-tuples with a +type-name string per column. ``description_from_arrow_schema`` +flattens the conversion so ``KernelResultSet`` and any future +kernel-result wrapper share the same mapping. + +Parameter binding (``TSparkParameter`` → kernel ``TypedValue``) is +not yet implemented — the PyO3 ``Statement`` doesn't expose a +``bind_param`` method on this branch. It'll land in a follow-up +once that PyO3 surface ships. +""" + +from __future__ import annotations + +from typing import List, Tuple + +import pyarrow + + +def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str: + """Map a pyarrow type to the Databricks SQL type name used in + PEP 249 ``description``. Names match what the Thrift backend + produces so consumers can branch on them identically. + """ + if pyarrow.types.is_boolean(arrow_type): + return "boolean" + if pyarrow.types.is_int8(arrow_type): + return "tinyint" + if pyarrow.types.is_int16(arrow_type): + return "smallint" + if pyarrow.types.is_int32(arrow_type): + return "int" + if pyarrow.types.is_int64(arrow_type): + return "bigint" + if pyarrow.types.is_float32(arrow_type): + return "float" + if pyarrow.types.is_float64(arrow_type): + return "double" + if pyarrow.types.is_decimal(arrow_type): + return "decimal" + if pyarrow.types.is_string(arrow_type) or pyarrow.types.is_large_string(arrow_type): + return "string" + if pyarrow.types.is_binary(arrow_type) or pyarrow.types.is_large_binary(arrow_type): + return "binary" + if pyarrow.types.is_date(arrow_type): + return "date" + if pyarrow.types.is_timestamp(arrow_type): + return "timestamp" + if pyarrow.types.is_list(arrow_type) or pyarrow.types.is_large_list(arrow_type): + return "array" + if pyarrow.types.is_struct(arrow_type): + return "struct" + if pyarrow.types.is_map(arrow_type): + return "map" + return str(arrow_type) + + +def description_from_arrow_schema(schema: 
pyarrow.Schema) -> List[Tuple]: + """Build a PEP 249 ``description`` list from a pyarrow Schema. + + Each tuple is ``(name, type_code, display_size, internal_size, + precision, scale, null_ok)``. The kernel does not report the + last five so they're all ``None`` — same shape the existing + ADBC / Thrift result paths produce. + """ + return [ + (field.name, _arrow_type_to_dbapi_string(field.type), None, None, None, None, None) + for field in schema + ] diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 65c0d6aca..be2bdb4c2 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -9,7 +9,6 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient -from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.common.unified_http_client import UnifiedHttpClient @@ -123,14 +122,33 @@ def _create_backend( """Create and return the appropriate backend client.""" self.use_sea = kwargs.get("use_sea", False) - databricks_client_class: Type[DatabricksClient] if self.use_sea: - logger.debug("Creating SEA backend client") - databricks_client_class = SeaDatabricksClient - else: - logger.debug("Creating Thrift backend client") - databricks_client_class = ThriftDatabricksClient + # `use_sea=True` now routes through the Rust kernel via + # PyO3. The native pure-Python SEA backend + # (`backend/sea/`) is no longer reachable through this + # flag; whether it's removed is tracked separately. See + # `docs/designs/pysql-kernel-integration.md` in the + # databricks-sql-kernel repo. + # + # Lazy import so the connector doesn't ImportError at + # startup when the kernel wheel isn't installed — the + # error surfaces only when a caller actually requests + # use_sea=True. 
+ from databricks.sql.backend.kernel.client import KernelDatabricksClient + + logger.debug("Creating kernel-backed client for use_sea=True") + return KernelDatabricksClient( + server_hostname=server_hostname, + http_path=http_path, + http_headers=all_headers, + auth_provider=auth_provider, + ssl_options=self.ssl_options, + http_client=self.http_client, + catalog=kwargs.get("catalog"), + schema=kwargs.get("schema"), + ) + logger.debug("Creating Thrift backend client") common_args = { "server_hostname": server_hostname, "port": self.port, @@ -142,7 +160,7 @@ def _create_backend( "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } - return databricks_client_class(**common_args) + return ThriftDatabricksClient(**common_args) @staticmethod def _extract_spog_headers(http_path, existing_headers): diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py new file mode 100644 index 000000000..920e94202 --- /dev/null +++ b/tests/unit/test_kernel_auth_bridge.py @@ -0,0 +1,116 @@ +"""Unit tests for the kernel backend's auth bridge. + +The bridge translates the connector's ``AuthProvider`` hierarchy +into ``databricks_sql_kernel.Session`` auth kwargs. PAT goes through +the kernel's PAT path; everything else trampolines through the +``External`` path with a Python callback. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider +from databricks.sql.backend.kernel.auth_bridge import ( + _extract_bearer_token, + kernel_auth_kwargs, +) + + +class _FakeOAuthProvider(AuthProvider): + """Stand-in for OAuth/MSAL/federation providers — anything that + isn't ``AccessTokenAuthProvider``. 
Returns a counter-stamped + token so tests can prove the callback is invoked each call.""" + + def __init__(self): + self.calls = 0 + + def add_headers(self, request_headers): + self.calls += 1 + request_headers["Authorization"] = f"Bearer token-{self.calls}" + + +class _MalformedProvider(AuthProvider): + """Provider that returns a non-Bearer Authorization header + (e.g. Basic auth). The bridge should reject this rather than + silently sending the wrong shape to the kernel.""" + + def add_headers(self, request_headers): + request_headers["Authorization"] = "Basic dXNlcjpwYXNz" + + +class _SilentProvider(AuthProvider): + """Provider that writes nothing — represents misconfigured + auth or a placeholder. The bridge must surface this clearly.""" + + def add_headers(self, request_headers): + pass + + +class TestExtractBearerToken: + def test_pat_provider_returns_token(self): + p = AccessTokenAuthProvider("dapi-abc-123") + assert _extract_bearer_token(p) == "dapi-abc-123" + + def test_non_bearer_auth_returns_none(self): + assert _extract_bearer_token(_MalformedProvider()) is None + + def test_silent_provider_returns_none(self): + assert _extract_bearer_token(_SilentProvider()) is None + + +class TestKernelAuthKwargs: + def test_pat_routes_to_kernel_pat(self): + kwargs = kernel_auth_kwargs(AccessTokenAuthProvider("dapi-xyz")) + assert kwargs == {"auth_type": "pat", "access_token": "dapi-xyz"} + + def test_federation_wrapped_pat_routes_to_kernel_pat(self): + """``get_python_sql_connector_auth_provider`` always wraps + the base provider in a ``TokenFederationProvider``, so the + PAT case never reaches us unwrapped in practice. 
The bridge + must look through the federation wrapper to find the + underlying ``AccessTokenAuthProvider``.""" + from databricks.sql.auth.token_federation import TokenFederationProvider + + # TokenFederationProvider needs an http_client; a MagicMock + # is sufficient because we don't trigger any token exchange + # in the test (the cached-token path is never hit). + base = AccessTokenAuthProvider("dapi-abc") + federated = TokenFederationProvider.__new__(TokenFederationProvider) + federated.external_provider = base + # The bridge only touches `add_headers` (delegated to the + # base) and `external_provider`. Other attrs would be set + # by __init__ but aren't exercised here. + federated.add_headers = base.add_headers + kwargs = kernel_auth_kwargs(federated) + assert kwargs == {"auth_type": "pat", "access_token": "dapi-abc"} + + def test_pat_with_silent_provider_raises(self): + """An AccessTokenAuthProvider that produces no Authorization + header is misconfigured; surface that at bridge-build time, + not on the first kernel HTTP request.""" + broken = AccessTokenAuthProvider("dapi-x") + # Force the broken state by monkey-patching add_headers. + broken.add_headers = lambda h: None # type: ignore[method-assign] + with pytest.raises(ValueError, match="Bearer"): + kernel_auth_kwargs(broken) + + def test_oauth_routes_to_external_trampoline(self): + provider = _FakeOAuthProvider() + kwargs = kernel_auth_kwargs(provider) + assert kwargs["auth_type"] == "external" + callback = kwargs["token_callback"] + assert callable(callback) + # First call -> token-1, second call -> token-2. Proves the + # callback delegates to the live auth_provider each time + # rather than caching. 
+ assert callback() == "token-1" + assert callback() == "token-2" + assert provider.calls == 2 + + def test_external_callback_raises_on_missing_header(self): + kwargs = kernel_auth_kwargs(_SilentProvider()) + with pytest.raises(RuntimeError, match="Bearer"): + kwargs["token_callback"]() diff --git a/tests/unit/test_kernel_result_set.py b/tests/unit/test_kernel_result_set.py new file mode 100644 index 000000000..7a4023193 --- /dev/null +++ b/tests/unit/test_kernel_result_set.py @@ -0,0 +1,165 @@ +"""Unit tests for ``KernelResultSet`` — the buffer behavior + +close() semantics. Uses a fake kernel handle so tests run with no +network and no Rust extension dependency.""" + +from __future__ import annotations + +from collections import deque +from typing import Deque +from unittest.mock import MagicMock + +import pyarrow as pa +import pytest + +from databricks.sql.backend.kernel.result_set import KernelResultSet +from databricks.sql.backend.types import CommandId, CommandState + + +class _FakeKernelHandle: + """Stand-in for ``databricks_sql_kernel.ExecutedStatement`` / + ``ResultStream``. Emits a configured list of ``RecordBatch``es + via ``fetch_next_batch`` and then returns ``None``.""" + + def __init__(self, schema: pa.Schema, batches): + self._schema = schema + self._batches: Deque[pa.RecordBatch] = deque(batches) + self.closed = False + + def arrow_schema(self) -> pa.Schema: + return self._schema + + def fetch_next_batch(self): + if self.closed: + raise RuntimeError("fetched after close") + if not self._batches: + return None + return self._batches.popleft() + + def close(self): + self.closed = True + + +def _make_rs(handle) -> KernelResultSet: + # The base ResultSet __init__ takes a `connection` ref it never + # actually dereferences during these buffer tests, so a Mock is + # fine. 
+ connection = MagicMock() + backend = MagicMock() + return KernelResultSet( + connection=connection, + backend=backend, + kernel_handle=handle, + command_id=CommandId.from_sea_statement_id("smoke-test"), + arraysize=100, + buffer_size_bytes=1024, + ) + + +def _batch(schema: pa.Schema, values) -> pa.RecordBatch: + return pa.RecordBatch.from_arrays( + [pa.array(values, type=schema.field(0).type)], schema=schema + ) + + +# Renamed from `schema` -> `int_schema` because the connector's +# top-level conftest.py defines a session-scoped `schema` fixture +# for E2E tests; pytest's fixture-resolution complains about +# scope-mismatch if we shadow it with a function-scoped one here. +@pytest.fixture +def int_schema(): + return pa.schema([("n", pa.int64())]) + + +def test_description_built_from_kernel_schema(int_schema): + handle = _FakeKernelHandle(int_schema, []) + rs = _make_rs(handle) + assert rs.description == [("n", "bigint", None, None, None, None, None)] + + +def test_fetchall_arrow_drains_all_batches(int_schema): + handle = _FakeKernelHandle( + int_schema, [_batch(int_schema, [1, 2]), _batch(int_schema, [3, 4, 5])] + ) + rs = _make_rs(handle) + table = rs.fetchall_arrow() + assert table.num_rows == 5 + assert table.column(0).to_pylist() == [1, 2, 3, 4, 5] + assert rs.status == CommandState.SUCCEEDED + assert rs.has_more_rows is False + + +def test_fetchmany_arrow_slices_within_batch(int_schema): + handle = _FakeKernelHandle(int_schema, [_batch(int_schema, [10, 20, 30, 40])]) + rs = _make_rs(handle) + t1 = rs.fetchmany_arrow(2) + assert t1.num_rows == 2 and t1.column(0).to_pylist() == [10, 20] + t2 = rs.fetchmany_arrow(2) + assert t2.num_rows == 2 and t2.column(0).to_pylist() == [30, 40] + t3 = rs.fetchmany_arrow(2) + assert t3.num_rows == 0 + + +def test_fetchmany_arrow_spans_batch_boundary(int_schema): + handle = _FakeKernelHandle( + int_schema, + [_batch(int_schema, [1, 2]), _batch(int_schema, [3, 4]), _batch(int_schema, [5, 6])], + ) + rs = _make_rs(handle) + t = 
rs.fetchmany_arrow(5) + assert t.num_rows == 5 + assert t.column(0).to_pylist() == [1, 2, 3, 4, 5] + t = rs.fetchmany_arrow(2) + assert t.column(0).to_pylist() == [6] + + +def test_fetchone_returns_row_then_none(int_schema): + handle = _FakeKernelHandle(int_schema, [_batch(int_schema, [42])]) + rs = _make_rs(handle) + row = rs.fetchone() + assert row is not None + assert row[0] == 42 + assert rs.fetchone() is None + + +def test_fetchall_rows(int_schema): + handle = _FakeKernelHandle( + int_schema, [_batch(int_schema, [1, 2]), _batch(int_schema, [3])] + ) + rs = _make_rs(handle) + rows = rs.fetchall() + assert [r[0] for r in rows] == [1, 2, 3] + + +def test_fetchmany_negative_raises(int_schema): + rs = _make_rs(_FakeKernelHandle(int_schema, [])) + with pytest.raises(ValueError): + rs.fetchmany(-1) + with pytest.raises(ValueError): + rs.fetchmany_arrow(-1) + + +def test_close_is_idempotent_and_calls_handle(int_schema): + handle = _FakeKernelHandle(int_schema, [_batch(int_schema, [1])]) + rs = _make_rs(handle) + rs.close() + assert handle.closed is True + assert rs.status == CommandState.CLOSED + rs.close() # second call is a no-op (kernel handle is None) + + +def test_empty_stream(int_schema): + rs = _make_rs(_FakeKernelHandle(int_schema, [])) + assert rs.fetchone() is None + assert rs.fetchall_arrow().num_rows == 0 + assert rs.status == CommandState.SUCCEEDED + + +def test_close_swallows_handle_close_failures(int_schema): + """ResultSet.close() must not raise even if the kernel + handle's close() fails — PEP 249 discourages exceptions from + close paths (cursor/connection teardown depends on it).""" + handle = _FakeKernelHandle(int_schema, []) + handle.close = MagicMock(side_effect=RuntimeError("kernel boom")) + rs = _make_rs(handle) + rs.close() # must not raise + assert rs.status == CommandState.CLOSED diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py new file mode 100644 index 000000000..3c6fe9b15 --- /dev/null +++ 
b/tests/unit/test_kernel_type_mapping.py @@ -0,0 +1,68 @@ +"""Unit tests for Arrow → PEP 249 description-string mapping.""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from databricks.sql.backend.kernel.type_mapping import ( + _arrow_type_to_dbapi_string, + description_from_arrow_schema, +) + + +@pytest.mark.parametrize( + "arrow_type, expected", + [ + (pa.bool_(), "boolean"), + (pa.int8(), "tinyint"), + (pa.int16(), "smallint"), + (pa.int32(), "int"), + (pa.int64(), "bigint"), + (pa.float32(), "float"), + (pa.float64(), "double"), + (pa.decimal128(10, 2), "decimal"), + (pa.string(), "string"), + (pa.large_string(), "string"), + (pa.binary(), "binary"), + (pa.large_binary(), "binary"), + (pa.date32(), "date"), + (pa.timestamp("us"), "timestamp"), + (pa.list_(pa.int32()), "array"), + (pa.large_list(pa.int32()), "array"), + (pa.struct([("a", pa.int32())]), "struct"), + (pa.map_(pa.string(), pa.int32()), "map"), + ], +) +def test_arrow_to_dbapi_known_types(arrow_type, expected): + assert _arrow_type_to_dbapi_string(arrow_type) == expected + + +def test_arrow_to_dbapi_unknown_falls_back_to_str(): + # null type isn't in the explicit list but should fall through + # to the default str() so unknown variants are still printable + # rather than silently misclassified. + assert _arrow_type_to_dbapi_string(pa.null()) == "null" + + +def test_description_from_schema_preserves_field_names_and_order(): + schema = pa.schema( + [ + ("user_id", pa.int64()), + ("name", pa.string()), + ("created_at", pa.timestamp("us")), + ] + ) + desc = description_from_arrow_schema(schema) + assert len(desc) == 3 + assert [(d[0], d[1]) for d in desc] == [ + ("user_id", "bigint"), + ("name", "string"), + ("created_at", "timestamp"), + ] + # PEP 249 says all 7-tuples; the last 5 slots are None for the + # kernel backend (we don't report display_size / precision / + # scale / nullability). 
+ for d in desc: + assert len(d) == 7 + assert d[2:] == (None, None, None, None, None) From 90d1de931c6ea73e8b6cdec56b427eef1ad7bc97 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 15:32:08 +0000 Subject: [PATCH 02/16] refactor(backend/kernel): PAT-only auth, drop External trampoline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier auth_bridge routed OAuth/MSAL/federation through the kernel's External token-provider trampoline (a Python callable the kernel invoked per HTTP request). Removing that for now. Why: routing OAuth into the kernel inherently requires per-request token resolution to keep refresh working during a long-running session. Two viable mechanisms (kernel-native OAuth, or the External callback); both have costs (duplicate OAuth flows vs GIL-per-request). Punting the decision until there's actual demand on use_sea=True. Today: the bridge accepts PAT (including TokenFederationProvider- wrapped PAT, which is how `get_python_sql_connector_auth_provider` always shapes it). Any non-PAT auth_provider raises a clear NotSupportedError pointing the user at use_sea=False (Thrift). This shrinks the auth_bridge to ~50 lines and means the kernel- side External enablement PR is no longer on the connector's critical path — there's no kernel-side prerequisite for shipping use_sea=True for PAT users. Unit tests updated: - TokenFederationProvider-wrapped PAT still routes to PAT (kept). - Generic OAuth provider raises NotSupportedError (new). - ExternalAuthProvider raises NotSupportedError (new). - Silent non-PAT provider raises NotSupportedError (new) — reject the type itself rather than trying to extract a token we already know we can't use. Live e2e against dogfood with use_sea=True (PAT): all checks still pass (SELECT 1, range(10000), fetchmany pacing, four metadata calls, session_configuration round-trip, structured DatabaseError on bad SQL). 
Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- .../sql/backend/kernel/auth_bridge.py | 78 +++++---------- tests/unit/test_kernel_auth_bridge.py | 97 +++++++++++-------- 2 files changed, 76 insertions(+), 99 deletions(-) diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py index 1f14b8a5e..bb94dddf1 100644 --- a/src/databricks/sql/backend/kernel/auth_bridge.py +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -1,25 +1,19 @@ """Translate the connector's ``AuthProvider`` into ``databricks_sql_kernel`` ``Session`` auth kwargs. -The connector already implements every auth flow it supports (PAT, -OAuth M2M, OAuth U2M, external token providers, federation). The -kernel must not re-implement them. Decision D9 in the integration -design: PAT goes through the kernel's PAT path; everything else -delegates back to the connector via the kernel's ``External`` -trampoline, with a Python callback that returns a fresh bearer -token. +This phase ships PAT only. The kernel-side PyO3 binding accepts +``auth_type='pat'``; OAuth / federation / custom credentials +providers are reserved but not yet wired in either layer. Non-PAT +auth raises ``NotSupportedError`` from this bridge so the failure +surfaces at session-open time with a clear message rather than +deep inside the kernel. Token extraction goes through ``AuthProvider.add_headers({})`` rather than touching auth-provider-specific attributes, so the -bridge works for every subclass — including custom providers a -caller may have wired in. - -End-to-end limitation: the kernel's -``build_auth_provider`` currently rejects ``AuthConfig::External`` -("reserved; v0 wires PAT + OAuthM2M + OAuthU2M only"). Until the -kernel-side follow-up PR lands, non-PAT auth surfaces a clear -``KernelError(code='InvalidArgument', message='AuthConfig::External -is reserved...')`` from ``Session.open_session``. PAT works today. 
+bridge works uniformly for every PAT shape — including +``AccessTokenAuthProvider`` wrapped in ``TokenFederationProvider`` +(which ``get_python_sql_connector_auth_provider`` does for every +provider it builds). """ from __future__ import annotations @@ -29,6 +23,7 @@ from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider from databricks.sql.auth.token_federation import TokenFederationProvider +from databricks.sql.exc import NotSupportedError logger = logging.getLogger(__name__) @@ -64,8 +59,8 @@ def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: provider-specific internals. Returns ``None`` if the provider did not write an Authorization - header or wrote a non-Bearer scheme — neither shape is - representable in the kernel's auth surface today. + header or wrote a non-Bearer scheme — neither is representable + in the kernel's PAT auth surface. """ headers: Dict[str, str] = {} auth_provider.add_headers(headers) @@ -80,29 +75,13 @@ def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: """Build the kwargs passed to ``databricks_sql_kernel.Session(...)``. - Two routing decisions: - - 1. ``AccessTokenAuthProvider`` → ``auth_type='pat'`` with the - static token. Kernel uses it verbatim for every request. - 2. Anything else → ``auth_type='external'`` with a callback that - calls ``auth_provider.add_headers({})`` and returns the - fresh bearer token. The connector keeps owning the OAuth / - MSAL / federation flow; the kernel asks for a token whenever - it needs one. - - The PAT special-case exists because it's the only path the - kernel actually serves end-to-end today. 
Once the kernel-side - External enablement lands, PAT could collapse into the - External path too (one callback that returns the static token); - but keeping the explicit ``pat`` route means the kernel does - not pay the GIL-reacquire cost on every HTTP request for PAT - users. + PAT (including ``TokenFederationProvider``-wrapped PAT) routes + through the kernel's PAT path. Anything else raises + ``NotSupportedError`` — the kernel binding doesn't accept OAuth + today, and routing OAuth through PAT would silently break + token refresh during long-running sessions. """ if _is_pat(auth_provider): - # PAT case: pull the static token out and feed the kernel's - # PAT path. We go through ``add_headers`` regardless of - # whether the provider was wrapped in TokenFederation or - # not — both shapes write the same Authorization header. token = _extract_bearer_token(auth_provider) if not token: raise ValueError( @@ -111,21 +90,8 @@ def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: ) return {"auth_type": "pat", "access_token": token} - # Every other provider: trampoline a callback. The callback is - # invoked once per HTTP request that needs auth (the kernel does - # not cache the returned token), so the auth_provider's own - # caching is what keeps this fast. - def token_callback() -> str: - token = _extract_bearer_token(auth_provider) - if not token: - raise RuntimeError( - f"{type(auth_provider).__name__}.add_headers did not produce " - "a Bearer Authorization header; cannot supply a token to the kernel" - ) - return token - - logger.debug( - "Routing %s through kernel External trampoline", - type(auth_provider).__name__, + raise NotSupportedError( + f"The kernel backend (use_sea=True) currently only supports PAT auth, " + f"but got {type(auth_provider).__name__}. Use use_sea=False (Thrift) " + "for OAuth / federation / custom credential providers." 
) - return {"auth_type": "external", "token_callback": token_callback} diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py index 920e94202..4ef85a471 100644 --- a/tests/unit/test_kernel_auth_bridge.py +++ b/tests/unit/test_kernel_auth_bridge.py @@ -1,9 +1,12 @@ """Unit tests for the kernel backend's auth bridge. -The bridge translates the connector's ``AuthProvider`` hierarchy -into ``databricks_sql_kernel.Session`` auth kwargs. PAT goes through -the kernel's PAT path; everything else trampolines through the -``External`` path with a Python callback. +Phase 1 ships PAT only. Tests verify: + - PAT routes through ``auth_type='pat'``. + - ``TokenFederationProvider``-wrapped PAT also routes through + PAT (every provider built by ``get_python_sql_connector_auth_provider`` + is federation-wrapped, so the naive isinstance check has to + look through the wrapper). + - Anything else raises ``NotSupportedError`` with a clear message. """ from __future__ import annotations @@ -12,38 +15,36 @@ import pytest -from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider +from databricks.sql.auth.authenticators import ( + AccessTokenAuthProvider, + AuthProvider, + DatabricksOAuthProvider, + ExternalAuthProvider, +) from databricks.sql.backend.kernel.auth_bridge import ( _extract_bearer_token, kernel_auth_kwargs, ) +from databricks.sql.exc import NotSupportedError class _FakeOAuthProvider(AuthProvider): - """Stand-in for OAuth/MSAL/federation providers — anything that - isn't ``AccessTokenAuthProvider``. Returns a counter-stamped - token so tests can prove the callback is invoked each call.""" - - def __init__(self): - self.calls = 0 + """Stand-in for any non-PAT provider. 
The bridge should reject + these with NotSupportedError.""" def add_headers(self, request_headers): - self.calls += 1 - request_headers["Authorization"] = f"Bearer token-{self.calls}" + request_headers["Authorization"] = "Bearer oauth-token-xyz" class _MalformedProvider(AuthProvider): - """Provider that returns a non-Bearer Authorization header - (e.g. Basic auth). The bridge should reject this rather than - silently sending the wrong shape to the kernel.""" + """Provider that returns a non-Bearer Authorization header.""" def add_headers(self, request_headers): request_headers["Authorization"] = "Basic dXNlcjpwYXNz" class _SilentProvider(AuthProvider): - """Provider that writes nothing — represents misconfigured - auth or a placeholder. The bridge must surface this clearly.""" + """Provider that writes nothing — misconfigured auth.""" def add_headers(self, request_headers): pass @@ -74,43 +75,53 @@ def test_federation_wrapped_pat_routes_to_kernel_pat(self): underlying ``AccessTokenAuthProvider``.""" from databricks.sql.auth.token_federation import TokenFederationProvider - # TokenFederationProvider needs an http_client; a MagicMock - # is sufficient because we don't trigger any token exchange - # in the test (the cached-token path is never hit). base = AccessTokenAuthProvider("dapi-abc") + # TokenFederationProvider's __init__ requires an http_client + # to construct cleanly; for this unit test we only exercise + # the add_headers passthrough + the external_provider + # attribute. Bypass __init__ with __new__ and stash just + # the fields the bridge touches. federated = TokenFederationProvider.__new__(TokenFederationProvider) federated.external_provider = base - # The bridge only touches `add_headers` (delegated to the - # base) and `external_provider`. Other attrs would be set - # by __init__ but aren't exercised here. 
federated.add_headers = base.add_headers kwargs = kernel_auth_kwargs(federated) assert kwargs == {"auth_type": "pat", "access_token": "dapi-abc"} - def test_pat_with_silent_provider_raises(self): + def test_pat_with_silent_provider_raises_value_error(self): """An AccessTokenAuthProvider that produces no Authorization header is misconfigured; surface that at bridge-build time, not on the first kernel HTTP request.""" broken = AccessTokenAuthProvider("dapi-x") - # Force the broken state by monkey-patching add_headers. broken.add_headers = lambda h: None # type: ignore[method-assign] with pytest.raises(ValueError, match="Bearer"): kernel_auth_kwargs(broken) - def test_oauth_routes_to_external_trampoline(self): - provider = _FakeOAuthProvider() - kwargs = kernel_auth_kwargs(provider) - assert kwargs["auth_type"] == "external" - callback = kwargs["token_callback"] - assert callable(callback) - # First call -> token-1, second call -> token-2. Proves the - # callback delegates to the live auth_provider each time - # rather than caching. - assert callback() == "token-1" - assert callback() == "token-2" - assert provider.calls == 2 - - def test_external_callback_raises_on_missing_header(self): - kwargs = kernel_auth_kwargs(_SilentProvider()) - with pytest.raises(RuntimeError, match="Bearer"): - kwargs["token_callback"]() + def test_generic_oauth_provider_raises_not_supported(self): + with pytest.raises(NotSupportedError, match="only supports PAT"): + kernel_auth_kwargs(_FakeOAuthProvider()) + + def test_external_credentials_provider_raises_not_supported(self): + """``ExternalAuthProvider`` wraps user-supplied + credentials_provider — kernel doesn't accept these today, + and the bridge surfaces that explicitly.""" + # ExternalAuthProvider's __init__ calls the credentials + # provider; supply a noop one. 
+ from databricks.sql.auth.authenticators import CredentialsProvider + + class _NoopCreds(CredentialsProvider): + def auth_type(self): + return "noop" + + def __call__(self, *args, **kwargs): + return lambda: {"Authorization": "Bearer noop"} + + ext = ExternalAuthProvider(_NoopCreds()) + with pytest.raises(NotSupportedError, match="only supports PAT"): + kernel_auth_kwargs(ext) + + def test_silent_non_pat_provider_also_raises_not_supported(self): + """Even if a non-PAT provider produces no header, the bridge + rejects the type itself — we don't try to extract a token + from something we already know is unsupported.""" + with pytest.raises(NotSupportedError): + kernel_auth_kwargs(_SilentProvider()) From 4b07e4cbb97d7341abc4a2600882ec49c66348cc Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 15:34:54 +0000 Subject: [PATCH 03/16] test(e2e): live kernel-backend (use_sea=True) suite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moves the previously-ad-hoc /tmp/connector_smoke.py into the repo as a real pytest module under tests/e2e/ — same convention as the rest of the e2e suite. Uses the existing session-scoped `connection_details` fixture from the top-level conftest so it shares the credential surface with every other live test. 11 tests cover: - connect() with use_sea=True opens a session. - SELECT 1: rows + description shape (column name + dbapi type slug). - SELECT * FROM range(10000): multi-batch drain. - fetchmany() pacing across the buffer boundary. - fetchall_arrow() returns a pyarrow Table. - All four metadata methods (catalogs / schemas / tables / columns). - session_configuration={'ANSI_MODE': 'false'} round-trips. - Bad SQL surfaces as DatabaseError with `code='SqlError'` and `sql_state='42P01'` attached as exception attributes. Module-level skips: - `databricks_sql_kernel` not importable → whole module skipped via pytest.importorskip (the wheel hasn't been installed). 
- Live creds missing → fixture-level skip with a pointed message. Run: `pytest tests/e2e/test_kernel_backend.py -v`. All 11 pass against dogfood in ~20s. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- tests/e2e/test_kernel_backend.py | 186 +++++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 tests/e2e/test_kernel_backend.py diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py new file mode 100644 index 000000000..19fa5072f --- /dev/null +++ b/tests/e2e/test_kernel_backend.py @@ -0,0 +1,186 @@ +"""E2E tests for ``use_sea=True`` (routes through the Rust kernel +via the PyO3 ``databricks_sql_kernel`` module). + +PAT auth only. Anything else surfaces as ``NotSupportedError`` +from the auth bridge — covered as a unit test, not exercised here. + +Skipped automatically when: + - The standard ``DATABRICKS_SERVER_HOSTNAME`` / ``HTTP_PATH`` / + ``TOKEN`` creds aren't set (existing connector convention). + - ``databricks_sql_kernel`` isn't importable (the wheel hasn't + been installed; run ``pip install + 'databricks-sql-connector[kernel]'`` or, for local dev, + ``cd databricks-sql-kernel/pyo3 && maturin develop --release`` + into this venv). + +Run from the connector repo root: + + set -a && source ~/.databricks/pecotesting-creds && set +a + .venv/bin/pytest tests/e2e/test_kernel_backend.py -v +""" + +from __future__ import annotations + +import pytest + +import databricks.sql as sql +from databricks.sql.exc import DatabaseError + + +# Skip the whole module unless the kernel wheel is importable. +pytest.importorskip( + "databricks_sql_kernel", + reason="use_sea=True requires the databricks-sql-kernel package", +) + + +@pytest.fixture(scope="module") +def kernel_conn_params(connection_details): + """Live-cred check + connection params for use_sea=True. + + Skips the module if any cred is missing rather than letting + every test fail with a confusing connect-time error. 
+ """ + host = connection_details.get("host") + http_path = connection_details.get("http_path") + token = connection_details.get("access_token") + if not (host and http_path and token): + pytest.skip( + "DATABRICKS_SERVER_HOSTNAME / DATABRICKS_HTTP_PATH / " + "DATABRICKS_TOKEN not set" + ) + return { + "server_hostname": host, + "http_path": http_path, + "access_token": token, + "use_sea": True, + } + + +@pytest.fixture +def conn(kernel_conn_params): + """One-shot connection per test (the simple_test pattern the + existing e2e suite uses for cursor-level tests).""" + c = sql.connect(**kernel_conn_params) + try: + yield c + finally: + c.close() + + +def test_connect_with_use_sea_opens_a_session(conn): + assert conn.open, "connection should report open after connect()" + + +def test_select_one(conn): + with conn.cursor() as cur: + cur.execute("SELECT 1 AS n") + assert cur.description[0][0] == "n" + # description type slug matches what Thrift produces + assert cur.description[0][1] == "int" + rows = cur.fetchall() + assert len(rows) == 1 + assert rows[0][0] == 1 + + +def test_drain_large_range_to_arrow(conn): + """SELECT * FROM range(10000) drains as a pyarrow Table with + 10000 rows. 
Exercises the CloudFetch / multi-batch path on the + kernel side.""" + with conn.cursor() as cur: + cur.execute("SELECT * FROM range(10000)") + rows = cur.fetchall() + assert len(rows) == 10000 + + +def test_fetchmany_pacing(conn): + """fetchmany honours the requested size and stops cleanly at + end-of-stream — covers the buffer-slicing logic in + KernelResultSet.""" + with conn.cursor() as cur: + cur.execute("SELECT * FROM range(50)") + r1 = cur.fetchmany(10) + r2 = cur.fetchmany(20) + r3 = cur.fetchmany(100) # capped at remaining + assert (len(r1), len(r2), len(r3)) == (10, 20, 20) + + +def test_fetchall_arrow(conn): + with conn.cursor() as cur: + cur.execute("SELECT 1 AS a, 'hi' AS b") + table = cur.fetchall_arrow() + assert table.num_rows == 1 + assert table.column_names == ["a", "b"] + + +# ── Metadata ────────────────────────────────────────────────────── + + +def test_metadata_catalogs(conn): + with conn.cursor() as cur: + cur.catalogs() + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_metadata_schemas(conn): + with conn.cursor() as cur: + cur.schemas(catalog_name="main") + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_metadata_tables(conn): + with conn.cursor() as cur: + cur.tables(catalog_name="system", schema_name="information_schema") + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_metadata_columns(conn): + with conn.cursor() as cur: + cur.columns( + catalog_name="system", + schema_name="information_schema", + table_name="tables", + ) + rows = cur.fetchall() + assert len(rows) > 0 + + +# ── Session configuration ───────────────────────────────────────── + + +def test_session_configuration_round_trips(kernel_conn_params): + """`session_configuration` flows through to the kernel's + `session_conf` and is honoured by the server. 
+ + `ANSI_MODE` is the safe choice — it's on the SEA allow-list and + isn't workspace-policy-clamped (unlike `STATEMENT_TIMEOUT`) or + rejected by the warehouse (unlike `TIMEZONE` on dogfood).""" + params = dict(kernel_conn_params) + params["session_configuration"] = {"ANSI_MODE": "false"} + with sql.connect(**params) as c: + with c.cursor() as cur: + cur.execute("SET ANSI_MODE") + rows = cur.fetchall() + kv = {r[0]: r[1] for r in rows} + assert kv.get("ANSI_MODE") == "false", f"got {rows!r}" + + +# ── Error mapping ───────────────────────────────────────────────── + + +def test_bad_sql_surfaces_as_databaseerror(conn): + """Bad SQL should surface as a PEP 249 ``DatabaseError`` with + the kernel's structured fields (`code`, `sql_state`, `query_id`) + attached as attributes — the connector backend re-raises the + kernel's ``SqlError`` to ``DatabaseError`` while preserving the + server-reported state.""" + with conn.cursor() as cur: + with pytest.raises(DatabaseError) as exc_info: + cur.execute("SELECT * FROM definitely_not_a_table_xyz_kernel_e2e") + err = exc_info.value + # Structured fields copied off the kernel exception: + assert getattr(err, "code", None) == "SqlError" + assert getattr(err, "sql_state", None) == "42P01" From aab4fa834537fbd238af7e138077eca9b431d1b3 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 15:38:50 +0000 Subject: [PATCH 04/16] fix(backend/kernel): defer databricks-sql-kernel poetry dep declaration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI is failing across all jobs at \`poetry lock\` time: Because databricks-sql-connector depends on databricks-sql-kernel (^0.1.0) which doesn't match any versions, version solving failed. The kernel wheel isn't yet published to PyPI — we verified the name is available via the Databricks proxy, but the package itself hasn't been built and uploaded yet. 
Declaring it as a poetry dep (even an optional one inside an extra) requires the version to be resolvable, and `poetry lock` runs as the setup step for every CI job: unit tests, linting, type checks, all of them. Fix: drop the `databricks-sql-kernel` dep declaration and the `[kernel]` extra from pyproject.toml until the wheel is on PyPI. The lazy import in `backend/kernel/client.py` still raises a clear ImportError pointing at `pip install databricks-sql-kernel` (or local maturin) when use_sea=True is invoked without the kernel present. When the kernel is published, a small follow-up will add back: databricks-sql-kernel = {version = "^0.1.0", optional = true} [tool.poetry.extras] kernel = ["databricks-sql-kernel"] A pointed comment in pyproject.toml documents the deferred change. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- pyproject.toml | 20 ++++++++++++++------ src/databricks/sql/backend/kernel/client.py | 6 +++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a436132c4..6868919d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,16 +32,24 @@ pyarrow = [ pyjwt = "^2.0.0" pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} -# Optional kernel backend: `pip install 'databricks-sql-connector[kernel]'` -# unlocks use_sea=True, which routes through the Rust kernel via PyO3. -# Without it, use_sea=True raises a pointed ImportError. The kernel -# wheel itself ships from the databricks-sql-kernel repo. -databricks-sql-kernel = {version = "^0.1.0", optional = true} [tool.poetry.extras] pyarrow = ["pyarrow"] -kernel = ["databricks-sql-kernel"] +# `[kernel]` extra is intentionally not declared here yet. +# `databricks-sql-kernel` is built from the databricks-sql-kernel +# repo and not yet published to PyPI; declaring it as a poetry dep +# breaks `poetry lock` for every CI job.
Once the wheel is on PyPI +# the extra will be added back here: +# +# databricks-sql-kernel = {version = "^0.1.0", optional = true} +# [tool.poetry.extras] +# kernel = ["databricks-sql-kernel"] +# +# Until then, install the kernel separately: +# pip install databricks-sql-kernel +# or (local dev): +# cd databricks-sql-kernel/pyo3 && maturin develop --release [tool.poetry.group.dev.dependencies] pytest = "^7.1.2" diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 67b6a2cda..42f4da409 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -66,9 +66,13 @@ try: import databricks_sql_kernel as _kernel # type: ignore[import-not-found] except ImportError as exc: # pragma: no cover - import-time error surfaces clearly + # The `databricks-sql-kernel` wheel is not yet on PyPI, so we + # don't yet declare it as an optional extra in pyproject.toml + # (doing so breaks `poetry lock`). Once published the install + # hint will move to `pip install 'databricks-sql-connector[kernel]'`. raise ImportError( "use_sea=True requires the databricks-sql-kernel package. Install it with:\n" - " pip install 'databricks-sql-connector[kernel]'\n" + " pip install databricks-sql-kernel\n" "or for local development from the kernel repo:\n" " cd databricks-sql-kernel/pyo3 && maturin develop --release" ) from exc From 85663255f178c72fbf24c05150368abeec59f458 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 15:48:50 +0000 Subject: [PATCH 05/16] fix(backend/kernel): unit tests skip without pyarrow, mypy + black MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three CI failures after the poetry-lock fix uncovered three real issues: 1. pyarrow is optional in the connector. The default-deps CI test job installs without it; the +PyArrow job installs with. 
The kernel backend's result_set.py + type_mapping.py import pyarrow eagerly (the kernel always returns pyarrow), and the unit tests import the backend at collection time — which crashes the default-deps job at ModuleNotFoundError. Fix: gate the three kernel unit tests on `pytest.importorskip( "pyarrow")` so they skip on default-deps and run on +PyArrow. Verified locally: 39 pass with pyarrow, 3 skipped without. No change to the backend module itself — nothing imports it until use_sea=True is invoked, and pyarrow is on the kernel wheel's runtime dep list so use_sea=True can't hit this either. 2. mypy: KernelDatabricksClient.open_session returns self._session_id, which mypy types as Optional[SessionId] because the field starts as None. Fix: bind the new id to a local non-Optional variable, assign to the field, return the local. CI's check-types runs cleanly on backend/kernel/ now; pre-existing mypy noise elsewhere isn't mine. 3. black --check: black 22.12.0 (the version CI pins) wants reformatting on result_set.py / type_mapping.py / client.py. Applied. Verified locally with the same black version. All 39 kernel unit tests + 619 pre-existing unit tests pass. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- src/databricks/sql/backend/kernel/client.py | 18 +++++++++++------- .../sql/backend/kernel/result_set.py | 4 +++- .../sql/backend/kernel/type_mapping.py | 10 +++++++++- tests/unit/test_kernel_auth_bridge.py | 8 ++++++++ tests/unit/test_kernel_result_set.py | 6 +++++- tests/unit/test_kernel_type_mapping.py | 7 ++++++- 6 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 42f4da409..6d62e986a 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -214,11 +214,11 @@ def open_session( # Use the kernel's real server-issued session id, not a # synthetic UUID. Matches what the native SEA backend does. 
- self._session_id = SessionId.from_sea_session_id( - self._kernel_session.session_id - ) - logger.info("Opened kernel-backed session %s", self._session_id) - return self._session_id + # Bind to a local first so mypy sees a non-Optional return. + session_id = SessionId.from_sea_session_id(self._kernel_session.session_id) + self._session_id = session_id + logger.info("Opened kernel-backed session %s", session_id) + return session_id def close_session(self, session_id: SessionId) -> None: if self._kernel_session is None: @@ -229,7 +229,9 @@ def close_session(self, session_id: SessionId) -> None: try: handle.close() except _kernel.KernelError as exc: - logger.warning("Error closing async handle during session close: %s", exc) + logger.warning( + "Error closing async handle during session close: %s", exc + ) self._async_handles.clear() try: self._kernel_session.close() @@ -474,7 +476,9 @@ def get_columns( # Kernel's list_columns requires a catalog (SEA `SHOW # COLUMNS` cannot span catalogs). Surface the constraint # explicitly rather than letting the kernel error. - raise ProgrammingError("get_columns requires catalog_name on the kernel backend.") + raise ProgrammingError( + "get_columns requires catalog_name on the kernel backend." 
+ ) try: stream = self._kernel_session.metadata().list_columns( catalog=catalog_name, diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index d6a0e8588..0ee85c2be 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -144,7 +144,9 @@ def _drain(self) -> pyarrow.Table: chunks: List[pyarrow.RecordBatch] = [] if self._buffer and self._buffer_offset > 0: head = self._buffer.popleft() - chunks.append(head.slice(self._buffer_offset, head.num_rows - self._buffer_offset)) + chunks.append( + head.slice(self._buffer_offset, head.num_rows - self._buffer_offset) + ) self._buffer_offset = 0 while self._buffer: chunks.append(self._buffer.popleft()) diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index bc4ffe5d2..a91160d17 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -66,6 +66,14 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: ADBC / Thrift result paths produce. """ return [ - (field.name, _arrow_type_to_dbapi_string(field.type), None, None, None, None, None) + ( + field.name, + _arrow_type_to_dbapi_string(field.type), + None, + None, + None, + None, + None, + ) for field in schema ] diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py index 4ef85a471..01789898a 100644 --- a/tests/unit/test_kernel_auth_bridge.py +++ b/tests/unit/test_kernel_auth_bridge.py @@ -15,6 +15,14 @@ import pytest +# The kernel backend's result_set + type_mapping modules transitively +# import pyarrow; the connector's default-deps test job doesn't +# install pyarrow, so importing the auth_bridge in that environment +# would fail at module-collection time. 
Gate the whole module on +# pyarrow availability — matches the convention the connector uses +# for pyarrow-dependent tests. +pytest.importorskip("pyarrow") + from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, AuthProvider, diff --git a/tests/unit/test_kernel_result_set.py b/tests/unit/test_kernel_result_set.py index 7a4023193..c83bfce94 100644 --- a/tests/unit/test_kernel_result_set.py +++ b/tests/unit/test_kernel_result_set.py @@ -8,9 +8,13 @@ from typing import Deque from unittest.mock import MagicMock -import pyarrow as pa import pytest +# pyarrow is an optional connector dep; the default-deps CI test +# job runs without it. KernelResultSet imports pyarrow eagerly, +# so the whole module must skip when pyarrow is unavailable. +pa = pytest.importorskip("pyarrow") + from databricks.sql.backend.kernel.result_set import KernelResultSet from databricks.sql.backend.types import CommandId, CommandState diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py index 3c6fe9b15..5ab5bde74 100644 --- a/tests/unit/test_kernel_type_mapping.py +++ b/tests/unit/test_kernel_type_mapping.py @@ -2,9 +2,14 @@ from __future__ import annotations -import pyarrow as pa import pytest +# pyarrow is an optional connector dep; the default-deps CI test +# job runs without it. The kernel backend itself imports pyarrow +# at module load, so any test that touches the backend must skip +# when pyarrow is unavailable. 
+pa = pytest.importorskip("pyarrow") + from databricks.sql.backend.kernel.type_mapping import ( _arrow_type_to_dbapi_string, description_from_arrow_schema, From 58afea4f67a0e7546d58670da853e1787ac47bbb Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 15:56:16 +0000 Subject: [PATCH 06/16] fix(backend/kernel): make package importable without the kernel wheel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The +PyArrow CI matrix installs pyarrow but not the databricks-sql-kernel wheel (the wheel isn't on PyPI yet, and the [kernel] extra is deferred — see commit 31ca581c). The previous fix gated unit tests on `pytest.importorskip("pyarrow")` but test_kernel_auth_bridge.py was still pulled into a kernel-wheel ImportError because: src/databricks/sql/backend/kernel/__init__.py -> from databricks.sql.backend.kernel.client import KernelDatabricksClient -> import databricks_sql_kernel # ImportError on +PyArrow CI The eager re-export from `__init__.py` was a convenience that broke every consumer that only needed a submodule (type_mapping, result_set, auth_bridge) — they all triggered the kernel wheel import for no reason. Fix: - Drop the eager re-export from `kernel/__init__.py`. Comment documents why and points callers (= session.py::_create_backend, already this shape) at the direct `from .client import ...`. - Drop the no-longer-needed `pytest.importorskip("pyarrow")` from test_kernel_auth_bridge.py — auth_bridge.py itself depends on neither pyarrow nor the kernel wheel, so the test now runs on every CI matrix variant. - test_kernel_result_set.py and test_kernel_type_mapping.py keep the pyarrow importorskip because they themselves use pyarrow. Verified locally across the three matrix shapes: - both pyarrow + kernel installed: 39 pass. - pyarrow only (no kernel wheel — the +PyArrow CI shape): 39 pass. - neither: 9 pass (auth_bridge only), 2 modules skip (the others use pyarrow).
Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- src/databricks/sql/backend/kernel/__init__.py | 18 ++++++++++++++---- tests/unit/test_kernel_auth_bridge.py | 12 +++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/databricks/sql/backend/kernel/__init__.py b/src/databricks/sql/backend/kernel/__init__.py index a0de1861c..4a1ad8205 100644 --- a/src/databricks/sql/backend/kernel/__init__.py +++ b/src/databricks/sql/backend/kernel/__init__.py @@ -6,10 +6,20 @@ switch its default transport (SEA REST → SEA gRPC → …) without renaming this module. +This ``__init__`` deliberately does **not** re-export +``KernelDatabricksClient`` from ``.client``. Importing ``.client`` +loads the ``databricks_sql_kernel`` PyO3 extension at module-import +time; doing that eagerly here would make ``import +databricks.sql.backend.kernel.type_mapping`` (used by tests / by +``KernelResultSet`` consumers) require the kernel wheel even when +the caller never plans to open a kernel-backed session. Callers +that need the client import it directly: + + from databricks.sql.backend.kernel.client import KernelDatabricksClient + +``session.py::_create_backend`` already does this lazy import under +the ``use_sea=True`` branch. + See ``docs/designs/pysql-kernel-integration.md`` in ``databricks-sql-kernel`` for the full integration design. """ - -from databricks.sql.backend.kernel.client import KernelDatabricksClient - -__all__ = ["KernelDatabricksClient"] diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py index 01789898a..57f1ecaaf 100644 --- a/tests/unit/test_kernel_auth_bridge.py +++ b/tests/unit/test_kernel_auth_bridge.py @@ -15,13 +15,11 @@ import pytest -# The kernel backend's result_set + type_mapping modules transitively -# import pyarrow; the connector's default-deps test job doesn't -# install pyarrow, so importing the auth_bridge in that environment -# would fail at module-collection time. 
Gate the whole module on -# pyarrow availability — matches the convention the connector uses -# for pyarrow-dependent tests. -pytest.importorskip("pyarrow") +# auth_bridge.py itself has no pyarrow or kernel-wheel deps. The +# `databricks.sql.backend.kernel` package's __init__.py deliberately +# does *not* eagerly re-export from .client either (which would +# require the kernel wheel). So this test can run on the +# default-deps CI matrix without any extras. No importorskip needed. from databricks.sql.auth.authenticators import ( AccessTokenAuthProvider, From c5a5162164acc6c174bea09a06396d846ec5e19d Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Thu, 14 May 2026 16:12:10 +0000 Subject: [PATCH 07/16] test(e2e): skip use_sea=True parametrized cases when kernel wheel missing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The connector's coverage CI job runs the full e2e suite, several of whose test classes parametrize ``extra_params`` over ``{}`` and ``{"use_sea": True}``. With ``use_sea=True`` now routing through the Rust kernel via PyO3, those cases die at ``connect()`` with our pointed ImportError because the ``databricks-sql-kernel`` wheel isn't yet on PyPI — and that CI job (sensibly) doesn't try to build it from a sibling repo. Fix: ``pytest_collection_modifyitems`` hook in the top-level ``conftest.py`` that adds a ``skip`` marker to any parametrize case with ``extra_params={"use_sea": True, ...}`` when ``importlib.util.find_spec("databricks_sql_kernel")`` returns ``None``. Behavior change is CI-only — local dev with the kernel wheel installed (via ``maturin develop`` from the kernel repo) runs those cases as before. Once the kernel wheel is published, the [kernel] extra in pyproject.toml gets enabled (see comment block there) and the default-deps CI matrix will install it; the skip then becomes a no-op. 
Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- conftest.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/conftest.py b/conftest.py index c8b350bee..748f73443 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,41 @@ +import importlib.util import os import pytest +def _kernel_wheel_available() -> bool: + """The ``use_sea=True`` code path now routes through the Rust + kernel via PyO3. The ``databricks_sql_kernel`` wheel is not + yet on PyPI (built from a separate repo); CI environments + without it should skip ``use_sea=True`` parametrized cases + rather than fail with a hard ImportError.""" + return importlib.util.find_spec("databricks_sql_kernel") is not None + + +def pytest_collection_modifyitems(config, items): + """Skip parametrized test cases that pass ``use_sea=True`` when + the kernel wheel isn't installed. + + The existing e2e suite uses ``@pytest.mark.parametrize( + "extra_params", [{}, {"use_sea": True}])`` to exercise both + backends. When the kernel wheel is missing those cases die at + ``connect()`` time with our pointed ImportError; mark them + skipped at collection time so CI signal stays accurate. 
+ """ + if _kernel_wheel_available(): + return + skip_marker = pytest.mark.skip( + reason="use_sea=True requires databricks-sql-kernel (not installed)" + ) + for item in items: + params = getattr(item, "callspec", None) + if params is None: + continue + extra_params = params.params.get("extra_params") + if isinstance(extra_params, dict) and extra_params.get("use_sea") is True: + item.add_marker(skip_marker) + + @pytest.fixture(scope="session") def host(): return os.getenv("DATABRICKS_SERVER_HOSTNAME") From 92243251fd59c3f394a59a1be7d2445027c5a4cb Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 15 May 2026 10:45:15 +0000 Subject: [PATCH 08/16] =?UTF-8?q?refactor(backend/kernel):=20address=20rev?= =?UTF-8?q?iew=20feedback=20=E2=80=94=20mechanical=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleanup pass on the kernel-backend PR addressing reviewer feedback that doesn't change observable behaviour: - result_set.py: replace O(M²) `_buffered_rows` with running counter `_buffered_count` maintained by pull/take/drain (perf F6). - result_set.py: docstring corrections — drop nonexistent `fetch_all_arrow` from kernel-handle contract (F20); document `buffer_size_bytes` as no-op on the kernel backend (F21). - client.py: tighten `_reraise_kernel_error` signature to `_kernel.KernelError` only; drop dead passthrough branch and the defensive setattr try/except (F17). - client.py: drop unused `_use_arrow_native_complex_types` kwarg (F18). - client.py: collapse three `KernelResultSet(...)` construction sites through `_make_result_set` (renamed from `_metadata_result`) (F19). - client.py: drop `metadata-` prefix from synthetic CommandId; use a plain `uuid.uuid4().hex` so anything reading `cursor.query_id` downstream sees a UUID-shaped string (F14). 
- client.py: clear the raw access token from `_auth_kwargs` after the kernel session is constructed — kernel owns the credential from then on, no need to retain a cleartext copy on the connector instance (F24). - auth_bridge.py: reject bearer tokens containing ASCII control characters at extraction time (defense-in-depth against header injection if a misbehaving HTTP stack ever places the token back into a header without scrubbing) (F25). - tests/unit/test_kernel_auth_bridge.py: construct a real `TokenFederationProvider(http_client=Mock())` instead of bypassing `__init__` with `__new__` + monkey-patching `add_headers`. Exercises the real federation passthrough path the bridge sees in production (F12). Drop unused `MagicMock` import (F27). - tests/e2e/test_kernel_backend.py: drop misleading CloudFetch claim on `test_drain_large_range_to_arrow` — 10000 BIGINT rows is ~80 KB, single inline chunk on a typical warehouse (F26). All 39 existing kernel unit tests pass. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- .../sql/backend/kernel/auth_bridge.py | 16 ++++- src/databricks/sql/backend/kernel/client.py | 72 +++++++++---------- .../sql/backend/kernel/result_set.py | 35 +++++---- tests/e2e/test_kernel_backend.py | 6 +- tests/unit/test_kernel_auth_bridge.py | 24 ++++--- 5 files changed, 92 insertions(+), 61 deletions(-) diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py index bb94dddf1..4721a3b04 100644 --- a/src/databricks/sql/backend/kernel/auth_bridge.py +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -19,6 +19,7 @@ from __future__ import annotations import logging +import re from typing import Any, Dict, Optional from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider @@ -30,6 +31,13 @@ _BEARER_PREFIX = "Bearer " +# Defense-in-depth: reject tokens containing ASCII control characters. 
+# A token with embedded CR/LF/NUL would let a misbehaving HTTP stack +# split or terminate the Authorization header line, opening a header- +# injection sink. Real PATs and federation-exchanged tokens never +# contain these. +_CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]") + def _is_pat(auth_provider: AuthProvider) -> bool: """Return True iff this provider ultimately wraps an @@ -69,7 +77,13 @@ def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: return None if not auth.startswith(_BEARER_PREFIX): return None - return auth[len(_BEARER_PREFIX) :] + token = auth[len(_BEARER_PREFIX) :] + if _CONTROL_CHAR_RE.search(token): + raise ValueError( + "Bearer token contains ASCII control characters; refusing to " + "forward it to the kernel auth bridge." + ) + return token def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 6d62e986a..6466070b4 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -104,9 +104,9 @@ } -def _reraise_kernel_error(exc: BaseException) -> "Error": +def _reraise_kernel_error(exc: "_kernel.KernelError") -> "Error": """Convert a ``databricks_sql_kernel.KernelError`` to a PEP 249 - exception. Other exception types fall through unchanged. + exception. Kernel errors carry their structured attrs (``code``, ``message``, ``sql_state``, ``error_code``, ``query_id`` …) as @@ -114,8 +114,6 @@ def _reraise_kernel_error(exc: BaseException) -> "Error": callers can branch on them without reaching back through ``__cause__``. 
""" - if not isinstance(exc, _kernel.KernelError): - return exc # type: ignore[return-value] code = getattr(exc, "code", "Unknown") cls = _CODE_TO_EXCEPTION.get(code, DatabaseError) new = cls(getattr(exc, "message", str(exc))) @@ -130,10 +128,7 @@ def _reraise_kernel_error(exc: BaseException) -> "Error": "retryable", "query_id", ): - try: - setattr(new, attr, getattr(exc, attr)) - except (AttributeError, TypeError): # pragma: no cover - defensive - pass + setattr(new, attr, getattr(exc, attr, None)) new.__cause__ = exc return new @@ -161,13 +156,12 @@ def __init__( schema: Optional[str] = None, http_headers=None, http_client=None, - _use_arrow_native_complex_types: Optional[bool] = True, **kwargs, ): # The connector hands us several fields the kernel doesn't # consume directly (ssl_options, http_headers, http_client, - # port, _use_arrow_native_complex_types). Kernel manages - # its own HTTP stack so we accept-and-ignore. + # port). Kernel manages its own HTTP stack so we + # accept-and-ignore. self._server_hostname = server_hostname self._http_path = http_path self._auth_provider = auth_provider @@ -211,6 +205,13 @@ def open_session( ) except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) + finally: + # Drop the raw access token from the instance once the + # kernel session is constructed (or failed). The kernel + # owns the credential from this point on; keeping a + # cleartext copy on a long-lived connector object risks + # accidental capture by pickling / debuggers / telemetry. + self._auth_kwargs.pop("access_token", None) # Use the kernel's real server-issued session id, not a # synthetic UUID. Matches what the native SEA backend does. 
@@ -296,14 +297,7 @@ def execute_command( command_id = CommandId.from_sea_statement_id(executed.statement_id) cursor.active_command_id = command_id - return KernelResultSet( - connection=cursor.connection, - backend=self, - kernel_handle=executed, - command_id=command_id, - arraysize=cursor.arraysize, - buffer_size_bytes=cursor.buffer_size_bytes, - ) + return self._make_result_set(executed, cursor, command_id) def cancel_command(self, command_id: CommandId) -> None: handle = self._async_handles.get(command_id.guid) @@ -363,22 +357,23 @@ def get_execution_result( stream = handle.await_result() except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) - return KernelResultSet( - connection=cursor.connection, - backend=self, - kernel_handle=stream, - command_id=command_id, - arraysize=cursor.arraysize, - buffer_size_bytes=cursor.buffer_size_bytes, - ) + return self._make_result_set(stream, cursor, command_id) # ── Metadata ─────────────────────────────────────────────────── - def _metadata_result(self, stream, cursor, command_id): + def _make_result_set( + self, + kernel_handle: Any, + cursor: "Cursor", + command_id: CommandId, + ) -> "ResultSet": + """Build a ``KernelResultSet`` from any kernel handle. 
Used + by sync execute, ``get_execution_result``, and all metadata + paths to keep construction in one place.""" return KernelResultSet( connection=cursor.connection, backend=self, - kernel_handle=stream, + kernel_handle=kernel_handle, command_id=command_id, arraysize=cursor.arraysize, buffer_size_bytes=cursor.buffer_size_bytes, @@ -386,9 +381,14 @@ def _metadata_result(self, stream, cursor, command_id): def _synthetic_command_id(self) -> CommandId: """Metadata calls don't produce a server statement id; mint - a synthetic one so the ``ResultSet`` still has a stable - identifier the cursor can attribute logs to.""" - return CommandId.from_sea_statement_id(f"metadata-{uuid.uuid4()}") + a synthetic UUID so the ``ResultSet`` still has a stable + identifier the cursor can attribute logs to. + + Plain ``uuid.uuid4().hex`` (no prefix) — anything that + consumes ``cursor.query_id`` downstream (telemetry, log + ingestion) sees a UUID-shaped string rather than a + connector-internal magic prefix it cannot parse.""" + return CommandId.from_sea_statement_id(uuid.uuid4().hex) def get_catalogs( self, @@ -403,7 +403,7 @@ def get_catalogs( stream = self._kernel_session.metadata().list_catalogs() except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) - return self._metadata_result(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) def get_schemas( self, @@ -423,7 +423,7 @@ def get_schemas( ) except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) - return self._metadata_result(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) def get_tables( self, @@ -457,7 +457,7 @@ def get_tables( ) except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) - return self._metadata_result(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) def 
get_columns( self, @@ -488,7 +488,7 @@ def get_columns( ) except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) - return self._metadata_result(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) # ── Misc ─────────────────────────────────────────────────────── diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index 0ee85c2be..40181f236 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -10,16 +10,22 @@ cancel. Both implement the same three methods this class actually calls: -``arrow_schema() / fetch_next_batch() / fetch_all_arrow() / close()``. -``KernelResultSet`` takes either via the ``kernel_handle`` parameter -and treats them uniformly — the connector's ``ResultSet`` contract -doesn't need to distinguish them. +``arrow_schema() / fetch_next_batch() / close()``. ``KernelResultSet`` +takes either via the ``kernel_handle`` parameter and treats them +uniformly — the connector's ``ResultSet`` contract doesn't need to +distinguish them. Buffer shape mirrors the prior ADBC POC's ``AdbcResultSet``: a FIFO of pyarrow ``RecordBatch``es, fed one batch at a time from the kernel as the connector calls ``fetch*``. ``fetchmany(n)`` slices within a batch when ``n`` is smaller than the kernel's natural batch size; ``fetchall`` drains the whole stream. + +Note: ``buffer_size_bytes`` is accepted by the constructor for +contract compatibility with the base ``ResultSet`` but is not +consulted — the kernel backend currently caps buffering by rows +pulled, not bytes. Memory ceilings should be controlled by the +kernel-side batch sizing. """ from __future__ import annotations @@ -84,6 +90,11 @@ def __init__( # re-fetch from the kernel. 
self._buffer: Deque[pyarrow.RecordBatch] = deque() self._buffer_offset: int = 0 + # Running count of rows currently buffered (sum of batch + # sizes minus the head-batch offset). Maintained by + # _pull_one_batch / _take_buffered / _drain so _buffered_rows + # stays O(1) instead of walking the deque. + self._buffered_count: int = 0 self._exhausted: bool = False # ----- internal helpers ----- @@ -102,22 +113,19 @@ def _pull_one_batch(self) -> bool: return False if batch.num_rows > 0: self._buffer.append(batch) + self._buffered_count += batch.num_rows return True def _ensure_buffered(self, n_rows: int) -> int: """Pull batches until ``n_rows`` are buffered or the kernel is exhausted. Returns total rows currently buffered.""" - while self._buffered_rows() < n_rows: + while self._buffered_count < n_rows: if not self._pull_one_batch(): break - return self._buffered_rows() + return self._buffered_count def _buffered_rows(self) -> int: - if not self._buffer: - return 0 - first = self._buffer[0].num_rows - self._buffer_offset - rest = sum(b.num_rows for b in list(self._buffer)[1:]) - return first + rest + return self._buffered_count def _take_buffered(self, n: int) -> pyarrow.Table: """Slice up to ``n`` rows out of the buffer; advances state.""" @@ -133,7 +141,9 @@ def _take_buffered(self, n: int) -> pyarrow.Table: if self._buffer_offset >= head.num_rows: self._buffer.popleft() self._buffer_offset = 0 - self._next_row_index += n - remaining + taken = n - remaining + self._buffered_count -= taken + self._next_row_index += taken if not slices: return pyarrow.Table.from_batches([], schema=self._schema) return pyarrow.Table.from_batches(slices, schema=self._schema) @@ -161,6 +171,7 @@ def _drain(self) -> pyarrow.Table: if batch.num_rows > 0: chunks.append(batch) rows = sum(c.num_rows for c in chunks) + self._buffered_count = 0 self._next_row_index += rows if not chunks: return pyarrow.Table.from_batches([], schema=self._schema) diff --git a/tests/e2e/test_kernel_backend.py 
b/tests/e2e/test_kernel_backend.py index 19fa5072f..32b1e94d6 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -85,8 +85,10 @@ def test_select_one(conn): def test_drain_large_range_to_arrow(conn): """SELECT * FROM range(10000) drains as a pyarrow Table with - 10000 rows. Exercises the CloudFetch / multi-batch path on the - kernel side.""" + 10000 rows. Exercises end-of-stream drain over multiple + ``fetch_next_batch`` calls; not large enough to cross a + CloudFetch chunk boundary — see test_driver for CloudFetch + coverage.""" with conn.cursor() as cur: cur.execute("SELECT * FROM range(10000)") rows = cur.fetchall() diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py index 57f1ecaaf..a5e2e756b 100644 --- a/tests/unit/test_kernel_auth_bridge.py +++ b/tests/unit/test_kernel_auth_bridge.py @@ -11,7 +11,7 @@ from __future__ import annotations -from unittest.mock import MagicMock +from unittest.mock import Mock import pytest @@ -78,18 +78,22 @@ def test_federation_wrapped_pat_routes_to_kernel_pat(self): the base provider in a ``TokenFederationProvider``, so the PAT case never reaches us unwrapped in practice. The bridge must look through the federation wrapper to find the - underlying ``AccessTokenAuthProvider``.""" + underlying ``AccessTokenAuthProvider``. + + Construct a real ``TokenFederationProvider`` (with a mock + http_client — `_exchange_token` never fires for a plain + ``dapi-…`` PAT because it isn't a JWT, so the mock is never + called). This exercises the real ``add_headers`` path the + bridge sees in production. + """ from databricks.sql.auth.token_federation import TokenFederationProvider base = AccessTokenAuthProvider("dapi-abc") - # TokenFederationProvider's __init__ requires an http_client - # to construct cleanly; for this unit test we only exercise - # the add_headers passthrough + the external_provider - # attribute. 
Bypass __init__ with __new__ and stash just - # the fields the bridge touches. - federated = TokenFederationProvider.__new__(TokenFederationProvider) - federated.external_provider = base - federated.add_headers = base.add_headers + federated = TokenFederationProvider( + hostname="https://example.cloud.databricks.com", + external_provider=base, + http_client=Mock(), + ) kwargs = kernel_auth_kwargs(federated) assert kwargs == {"auth_type": "pat", "access_token": "dapi-abc"} From ea1ba45e061ef86e30d951dde4db5139946c2321 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Fri, 15 May 2026 10:55:50 +0000 Subject: [PATCH 09/16] feat(backend/kernel): introduce dedicated use_kernel flag + substantive review fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major change: route the kernel backend through a new ``use_kernel=True`` connection kwarg instead of repurposing ``use_sea=True``. ``use_sea=True`` once again routes to the native pure-Python SEA backend (no behaviour change); ``use_kernel=True`` routes to the Rust kernel via PyO3. The two flags are mutually exclusive. This addresses the largest reviewer concern from the multi-agent review: silently hijacking a documented public flag broke OAuth / federation / parameter-binding callers on ``use_sea=True`` who had no opt-out. With the new flag, the kernel backend is fully opt-in and existing ``use_sea=True`` users continue to get the native SEA backend they signed up for. Other substantive fixes: - session.py: restore ``SeaDatabricksClient`` import + routing. Reject ``use_kernel=True`` + ``use_sea=True`` together with a clear ``ValueError``. - client.py (kernel ``Cursor.columns``): update docstring to flag the ``catalog_name=None`` divergence — kernel requires a catalog, Thrift / native SEA do not (F13). - conftest.py: drop the collection-time ``pytest_collection_modifyitems`` hook that was skipping ``extra_params={"use_sea": True}`` cases. 
With ``use_sea=True`` back on the native SEA backend, those cases run as they did before this PR (F8). - kernel/client.py: ``get_tables`` now applies the ``table_types`` filter client-side using ``ResultSetFilter._filter_arrow_table`` (the same helper the native SEA backend uses), wrapped in a tiny ``_StaticArrowHandle`` that flows the filtered table back through the normal ``KernelResultSet`` path. Replaces the previous "log a warning and return unfiltered" behaviour (F4). - kernel/client.py: guard ``_async_handles`` with ``threading.RLock`` so concurrent cursors on the same connection don't race on submit / close / close-session (F15). - kernel/result_set.py: ``KernelResultSet.close()`` now drops the entry from ``backend._async_handles`` so async-submitted statements don't leave stale references behind (F5). - kernel/{__init__,client,auth_bridge}.py, tests/e2e/test_kernel_backend.py: update docstrings, error messages, and the e2e fixture to refer to ``use_kernel=True`` instead of ``use_sea=True``. - client.py (``Connection`` docstring): document the new ``use_kernel`` kwarg + its Phase-1 limitations. New tests: - tests/unit/test_kernel_client.py (38 cases): cover the 14-entry ``_CODE_TO_EXCEPTION`` table, ``_reraise_kernel_error`` attribute forwarding, the 6-entry ``_STATE_TO_COMMAND_STATE`` table, the no-open-session guards on every method, ``open_session`` double-open, ``parameters`` / ``query_tags`` rejection, ``get_columns``' catalog-required check, ``cancel_command`` / ``close_command`` no-handle tolerance, ``get_query_state`` sync-path SUCCEEDED, the Failed-state re-raise, the synthetic-command-id UUID shape, and ``close_session`` cleanup even when per-handle close errors fire. Uses a fake ``databricks_sql_kernel`` module installed into ``sys.modules`` so the test runs with no Rust extension dependency (F9). 77/77 kernel unit tests pass. 
Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- conftest.py | 34 -- src/databricks/sql/backend/kernel/__init__.py | 4 +- .../sql/backend/kernel/auth_bridge.py | 6 +- src/databricks/sql/backend/kernel/client.py | 112 ++++- .../sql/backend/kernel/result_set.py | 14 + src/databricks/sql/client.py | 18 +- src/databricks/sql/session.py | 31 +- tests/e2e/test_kernel_backend.py | 14 +- tests/unit/test_kernel_client.py | 397 ++++++++++++++++++ 9 files changed, 550 insertions(+), 80 deletions(-) create mode 100644 tests/unit/test_kernel_client.py diff --git a/conftest.py b/conftest.py index 748f73443..c8b350bee 100644 --- a/conftest.py +++ b/conftest.py @@ -1,41 +1,7 @@ -import importlib.util import os import pytest -def _kernel_wheel_available() -> bool: - """The ``use_sea=True`` code path now routes through the Rust - kernel via PyO3. The ``databricks_sql_kernel`` wheel is not - yet on PyPI (built from a separate repo); CI environments - without it should skip ``use_sea=True`` parametrized cases - rather than fail with a hard ImportError.""" - return importlib.util.find_spec("databricks_sql_kernel") is not None - - -def pytest_collection_modifyitems(config, items): - """Skip parametrized test cases that pass ``use_sea=True`` when - the kernel wheel isn't installed. - - The existing e2e suite uses ``@pytest.mark.parametrize( - "extra_params", [{}, {"use_sea": True}])`` to exercise both - backends. When the kernel wheel is missing those cases die at - ``connect()`` time with our pointed ImportError; mark them - skipped at collection time so CI signal stays accurate. 
- """ - if _kernel_wheel_available(): - return - skip_marker = pytest.mark.skip( - reason="use_sea=True requires databricks-sql-kernel (not installed)" - ) - for item in items: - params = getattr(item, "callspec", None) - if params is None: - continue - extra_params = params.params.get("extra_params") - if isinstance(extra_params, dict) and extra_params.get("use_sea") is True: - item.add_marker(skip_marker) - - @pytest.fixture(scope="session") def host(): return os.getenv("DATABRICKS_SERVER_HOSTNAME") diff --git a/src/databricks/sql/backend/kernel/__init__.py b/src/databricks/sql/backend/kernel/__init__.py index 4a1ad8205..230af47f2 100644 --- a/src/databricks/sql/backend/kernel/__init__.py +++ b/src/databricks/sql/backend/kernel/__init__.py @@ -1,6 +1,6 @@ """Backend that delegates to the Databricks SQL Kernel (Rust) via PyO3. -Routed when ``use_sea=True`` is passed to ``databricks.sql.connect``. +Routed when ``use_kernel=True`` is passed to ``databricks.sql.connect``. The module's identity is "delegates to the kernel" — not the wire protocol the kernel happens to use today (SEA REST). The kernel may switch its default transport (SEA REST → SEA gRPC → …) without @@ -18,7 +18,7 @@ from databricks.sql.backend.kernel.client import KernelDatabricksClient ``session.py::_create_backend`` already does this lazy import under -the ``use_sea=True`` branch. +the ``use_kernel=True`` branch. See ``docs/designs/pysql-kernel-integration.md`` in ``databricks-sql-kernel`` for the full integration design. 
diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py index 4721a3b04..01123b96c 100644 --- a/src/databricks/sql/backend/kernel/auth_bridge.py +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -105,7 +105,7 @@ def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: return {"auth_type": "pat", "access_token": token} raise NotSupportedError( - f"The kernel backend (use_sea=True) currently only supports PAT auth, " - f"but got {type(auth_provider).__name__}. Use use_sea=False (Thrift) " - "for OAuth / federation / custom credential providers." + f"The kernel backend (use_kernel=True) currently only supports PAT auth, " + f"but got {type(auth_provider).__name__}. Use the Thrift backend " + "(default) for OAuth / federation / custom credential providers." ) diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 6466070b4..2bc70c618 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -1,6 +1,6 @@ """``DatabricksClient`` backed by the Rust kernel via PyO3. -Routed when ``use_sea=True``. Constructor takes the connector's +Routed when ``use_kernel=True``. Constructor takes the connector's already-built ``auth_provider`` and forwards everything else to the kernel's ``Session``. Every kernel call goes through this thin wrapper; this module is the single seam between the connector's @@ -34,6 +34,7 @@ from __future__ import annotations import logging +import threading import uuid from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union @@ -71,7 +72,7 @@ # (doing so breaks `poetry lock`). Once published the install # hint will move to `pip install 'databricks-sql-connector[kernel]'`. raise ImportError( - "use_sea=True requires the databricks-sql-kernel package. Install it with:\n" + "use_kernel=True requires the databricks-sql-kernel package. 
Install it with:\n" " pip install databricks-sql-kernel\n" "or for local development from the kernel repo:\n" " cd databricks-sql-kernel/pyo3 && maturin develop --release" @@ -176,7 +177,10 @@ def __init__( self._session_id: Optional[SessionId] = None # Async-exec handles keyed by CommandId.guid. Populated by # ``execute_command(async_op=True)``; drained by ``close_command``. + # Guarded by ``_async_handles_lock`` so concurrent cursors on the + # same connection don't race on submit / close / close-session. self._async_handles: Dict[str, Any] = {} + self._async_handles_lock = threading.RLock() # ── Session lifecycle ────────────────────────────────────────── @@ -226,14 +230,16 @@ def close_session(self, session_id: SessionId) -> None: return # Close any tracked async handles first so they fire their # server-side CloseStatement before the session goes away. - for handle in list(self._async_handles.values()): + with self._async_handles_lock: + handles_to_close = list(self._async_handles.values()) + self._async_handles.clear() + for handle in handles_to_close: try: handle.close() except _kernel.KernelError as exc: logger.warning( "Error closing async handle during session close: %s", exc ) - self._async_handles.clear() try: self._kernel_session.close() except _kernel.KernelError as exc: @@ -280,7 +286,8 @@ def execute_command( async_exec = stmt.submit() command_id = CommandId.from_sea_statement_id(async_exec.statement_id) cursor.active_command_id = command_id - self._async_handles[command_id.guid] = async_exec + with self._async_handles_lock: + self._async_handles[command_id.guid] = async_exec return None executed = stmt.execute() except _kernel.KernelError as exc: @@ -300,7 +307,8 @@ def execute_command( return self._make_result_set(executed, cursor, command_id) def cancel_command(self, command_id: CommandId) -> None: - handle = self._async_handles.get(command_id.guid) + with self._async_handles_lock: + handle = self._async_handles.get(command_id.guid) if handle is 
None: # Sync-execute paths fully materialise the result before # ``execute_command`` returns, so by the time @@ -314,7 +322,8 @@ def cancel_command(self, command_id: CommandId) -> None: raise _reraise_kernel_error(exc) def close_command(self, command_id: CommandId) -> None: - handle = self._async_handles.pop(command_id.guid, None) + with self._async_handles_lock: + handle = self._async_handles.pop(command_id.guid, None) if handle is None: logger.debug("close_command: no tracked handle for %s", command_id) return @@ -324,7 +333,8 @@ def close_command(self, command_id: CommandId) -> None: raise _reraise_kernel_error(exc) def get_query_state(self, command_id: CommandId) -> CommandState: - handle = self._async_handles.get(command_id.guid) + with self._async_handles_lock: + handle = self._async_handles.get(command_id.guid) if handle is None: # No tracked async handle means execute_command ran # sync and the result was materialised before returning; @@ -347,7 +357,8 @@ def get_execution_result( command_id: CommandId, cursor: "Cursor", ) -> "ResultSet": - handle = self._async_handles.get(command_id.guid) + with self._async_handles_lock: + handle = self._async_handles.get(command_id.guid) if handle is None: raise ProgrammingError( "get_execution_result called for an unknown command_id; " @@ -438,16 +449,6 @@ def get_tables( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_tables requires an open session.") - if table_types: - # Documented gap: native SEA backend filters here, but - # its filter is keyed on SeaResultSet. Day-1 we surface - # the unfiltered result; a small follow-up ports the - # filter to operate on KernelResultSet. 
- logger.warning( - "get_tables: client-side table_types filter not yet implemented " - "on the kernel backend; returning unfiltered rows for %r", - table_types, - ) try: stream = self._kernel_session.metadata().list_tables( catalog=catalog_name, @@ -457,7 +458,27 @@ def get_tables( ) except _kernel.KernelError as exc: raise _reraise_kernel_error(exc) - return self._make_result_set(stream, cursor, self._synthetic_command_id()) + if not table_types: + return self._make_result_set(stream, cursor, self._synthetic_command_id()) + # The kernel today returns the unfiltered ``SHOW TABLES`` shape + # regardless of ``table_types``. Drain to a single Arrow table + # and apply the same client-side filter the native SEA backend + # uses (column index 5 is TABLE_TYPE, case-sensitive). Cheap + # because metadata result sets are small. + from databricks.sql.backend.sea.utils.filters import ResultSetFilter + + full_table = _drain_kernel_handle(stream) + filtered_table = ResultSetFilter._filter_arrow_table( + full_table, + column_name=full_table.schema.field(5).name, + allowed_values=table_types, + case_sensitive=True, + ) + return self._make_result_set( + _StaticArrowHandle(filtered_table), + cursor, + self._synthetic_command_id(), + ) def get_columns( self, @@ -496,7 +517,7 @@ def get_columns( def max_download_threads(self) -> int: # CloudFetch parallelism lives kernel-side. This property is # consulted by Thrift code paths that don't run for - # use_sea=True; return a non-zero default so anything that + # use_kernel=True; return a non-zero default so anything that # peeks at it does not divide by zero. return 10 @@ -509,3 +530,52 @@ def max_download_threads(self) -> int: "Cancelled": CommandState.CANCELLED, "Closed": CommandState.CLOSED, } + + +def _drain_kernel_handle(handle: Any) -> Any: + """Drain a kernel ResultStream / ExecutedStatement into a single + ``pyarrow.Table``. 
Used by ``get_tables`` to apply a client-side + ``table_types`` filter on a metadata result; cheap because + metadata streams are small.""" + import pyarrow + + schema = handle.arrow_schema() + batches = [] + while True: + batch = handle.fetch_next_batch() + if batch is None: + break + if batch.num_rows > 0: + batches.append(batch) + try: + handle.close() + except _kernel.KernelError: + pass + return pyarrow.Table.from_batches(batches, schema=schema) + + +class _StaticArrowHandle: + """Duck-typed kernel handle that replays a pre-built + ``pyarrow.Table`` through ``arrow_schema()`` / + ``fetch_next_batch()`` / ``close()``. Used to wrap a + post-processed table (e.g., the ``table_types``-filtered output + of ``get_tables``) so it flows back through the normal + ``KernelResultSet`` path.""" + + def __init__(self, table: Any) -> None: + self._schema = table.schema + self._batches = list(table.to_batches()) + self._idx = 0 + + def arrow_schema(self) -> Any: + return self._schema + + def fetch_next_batch(self) -> Optional[Any]: + if self._idx >= len(self._batches): + return None + batch = self._batches[self._idx] + self._idx += 1 + return batch + + def close(self) -> None: + self._batches = [] diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index 40181f236..2cc665656 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -226,7 +226,21 @@ def close(self) -> None: # level; log and swallow so the cursor's __del__ / # connection close path stays clean. logger.warning("Error closing kernel handle: %s", exc) + # Drop the entry from the backend's async-handle map (if + # present) — for async-submitted statements the handle is + # tracked there and the base ``ResultSet.close`` path would + # otherwise leave a stale entry pointing at a closed handle. + # No-op for the sync-execute and metadata paths, which never + # register in ``_async_handles``. 
+ guid = getattr(self.command_id, "guid", None) + if guid is not None: + self.backend._async_handles_lock.acquire() + try: + self.backend._async_handles.pop(guid, None) + finally: + self.backend._async_handles_lock.release() self._buffer.clear() + self._buffered_count = 0 self._kernel_handle = None self._exhausted = True self.has_been_closed_server_side = True diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index fe52f0c79..e3c25fe65 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -115,7 +115,17 @@ def __init__( Parameters: :param use_sea: `bool`, optional (default is False) - Use the SEA backend instead of the Thrift backend. + Use the native pure-Python SEA backend instead of + the Thrift backend. + :param use_kernel: `bool`, optional (default is False) + Route the connection through the Rust kernel + (``databricks-sql-kernel`` via PyO3). Requires the + kernel wheel to be installed separately + (``pip install databricks-sql-kernel``); raises + ImportError otherwise. In active development — + PAT auth only today; OAuth / federation / external + credentials and native parameter binding land in + follow-ups. Mutually exclusive with ``use_sea``. :param use_hybrid_disposition: `bool`, optional (default is False) Use the hybrid disposition instead of the inline disposition. :param server_hostname: Databricks instance host name. @@ -1575,6 +1585,12 @@ def columns( Get columns corresponding to the catalog_name, schema_name, table_name and column_name. Names can contain % wildcards. + + Note: on ``use_kernel=True``, ``catalog_name`` is required — + the kernel's underlying ``SHOW COLUMNS`` cannot span catalogs. + Passing ``catalog_name=None`` raises ``ProgrammingError``. The + Thrift and native SEA backends accept ``catalog_name=None``. 
+ :returns self """ self._check_not_closed() diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index be2bdb4c2..97790e4d9 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -9,6 +9,7 @@ from databricks.sql import __version__ from databricks.sql import USER_AGENT_NAME from databricks.sql.backend.thrift_backend import ThriftDatabricksClient +from databricks.sql.backend.sea.backend import SeaDatabricksClient from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.types import SessionId, BackendType from databricks.sql.common.unified_http_client import UnifiedHttpClient @@ -121,22 +122,21 @@ def _create_backend( ) -> DatabricksClient: """Create and return the appropriate backend client.""" self.use_sea = kwargs.get("use_sea", False) + self.use_kernel = kwargs.get("use_kernel", False) - if self.use_sea: - # `use_sea=True` now routes through the Rust kernel via - # PyO3. The native pure-Python SEA backend - # (`backend/sea/`) is no longer reachable through this - # flag; whether it's removed is tracked separately. See - # `docs/designs/pysql-kernel-integration.md` in the - # databricks-sql-kernel repo. - # + if self.use_kernel and self.use_sea: + raise ValueError( + "use_kernel and use_sea are mutually exclusive — pick one." + ) + + if self.use_kernel: # Lazy import so the connector doesn't ImportError at # startup when the kernel wheel isn't installed — the # error surfaces only when a caller actually requests - # use_sea=True. + # use_kernel=True. 
from databricks.sql.backend.kernel.client import KernelDatabricksClient - logger.debug("Creating kernel-backed client for use_sea=True") + logger.debug("Creating kernel-backed client for use_kernel=True") return KernelDatabricksClient( server_hostname=server_hostname, http_path=http_path, @@ -148,7 +148,14 @@ def _create_backend( schema=kwargs.get("schema"), ) - logger.debug("Creating Thrift backend client") + databricks_client_class: Type[DatabricksClient] + if self.use_sea: + logger.debug("Creating SEA backend client") + databricks_client_class = SeaDatabricksClient + else: + logger.debug("Creating Thrift backend client") + databricks_client_class = ThriftDatabricksClient + common_args = { "server_hostname": server_hostname, "port": self.port, @@ -160,7 +167,7 @@ def _create_backend( "_use_arrow_native_complex_types": _use_arrow_native_complex_types, **kwargs, } - return ThriftDatabricksClient(**common_args) + return databricks_client_class(**common_args) @staticmethod def _extract_spog_headers(http_path, existing_headers): diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py index 32b1e94d6..0c0722b91 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -1,4 +1,4 @@ -"""E2E tests for ``use_sea=True`` (routes through the Rust kernel +"""E2E tests for ``use_kernel=True`` (routes through the Rust kernel via the PyO3 ``databricks_sql_kernel`` module). PAT auth only. Anything else surfaces as ``NotSupportedError`` @@ -8,8 +8,8 @@ - The standard ``DATABRICKS_SERVER_HOSTNAME`` / ``HTTP_PATH`` / ``TOKEN`` creds aren't set (existing connector convention). - ``databricks_sql_kernel`` isn't importable (the wheel hasn't - been installed; run ``pip install - 'databricks-sql-connector[kernel]'`` or, for local dev, + been installed; run ``pip install databricks-sql-kernel`` or, + for local dev, ``cd databricks-sql-kernel/pyo3 && maturin develop --release`` into this venv). 
@@ -30,13 +30,13 @@ # Skip the whole module unless the kernel wheel is importable. pytest.importorskip( "databricks_sql_kernel", - reason="use_sea=True requires the databricks-sql-kernel package", + reason="use_kernel=True requires the databricks-sql-kernel package", ) @pytest.fixture(scope="module") def kernel_conn_params(connection_details): - """Live-cred check + connection params for use_sea=True. + """Live-cred check + connection params for use_kernel=True. Skips the module if any cred is missing rather than letting every test fail with a confusing connect-time error. @@ -53,7 +53,7 @@ def kernel_conn_params(connection_details): "server_hostname": host, "http_path": http_path, "access_token": token, - "use_sea": True, + "use_kernel": True, } @@ -68,7 +68,7 @@ def conn(kernel_conn_params): c.close() -def test_connect_with_use_sea_opens_a_session(conn): +def test_connect_with_use_kernel_opens_a_session(conn): assert conn.open, "connection should report open after connect()" diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py new file mode 100644 index 000000000..b23365c6e --- /dev/null +++ b/tests/unit/test_kernel_client.py @@ -0,0 +1,397 @@ +"""Unit tests for ``KernelDatabricksClient`` — the error mapping, +state-mapping, async-handle bookkeeping, and method-level guards +that don't require a live kernel session. + +The connector's ``databricks.sql.backend.kernel.client`` module +imports the ``databricks_sql_kernel`` extension at import time, so +this test installs a fake module into ``sys.modules`` *before* +importing the client. The fake exposes the minimum surface the +client touches (``Session``, ``KernelError``, ``Statement``, +``ExecutedStatement``, ``ExecutedAsyncStatement``, ``ResultStream``, +``metadata``). 
+""" + +from __future__ import annotations + +import sys +import types +from typing import Optional +from unittest.mock import MagicMock + +import pytest + +# pyarrow is an optional dep; the kernel client's result_set imports +# it eagerly, so the whole module must skip when pyarrow is missing. +pa = pytest.importorskip("pyarrow") + + +# --------------------------------------------------------------------------- +# Fake databricks_sql_kernel module — installed before client.py imports. +# --------------------------------------------------------------------------- + + +class _FakeKernelError(Exception): + """Stand-in for ``databricks_sql_kernel.KernelError``. Carries + the structured attrs the connector forwards onto the re-raised + PEP 249 exception.""" + + def __init__( + self, + code: str = "Unknown", + message: str = "boom", + sql_state: Optional[str] = None, + query_id: Optional[str] = None, + ) -> None: + super().__init__(message) + self.code = code + self.message = message + self.sql_state = sql_state + self.error_code = None + self.vendor_code = None + self.http_status = None + self.retryable = False + self.query_id = query_id + + +_fake_kernel_module = types.ModuleType("databricks_sql_kernel") +_fake_kernel_module.KernelError = _FakeKernelError # type: ignore[attr-defined] +_fake_kernel_module.Session = MagicMock() # type: ignore[attr-defined] +sys.modules.setdefault("databricks_sql_kernel", _fake_kernel_module) + + +# Importing the client now picks up the fake module via +# ``import databricks_sql_kernel as _kernel`` at the top of client.py. 
+from databricks.sql.auth.authenticators import AccessTokenAuthProvider +from databricks.sql.backend.kernel import client as kernel_client +from databricks.sql.backend.types import CommandId, CommandState +from databricks.sql.exc import ( + DatabaseError, + InterfaceError, + NotSupportedError, + OperationalError, + ProgrammingError, +) + + +# --------------------------------------------------------------------------- +# Error mapping +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "code, expected_cls", + [ + ("InvalidArgument", ProgrammingError), + ("Unauthenticated", OperationalError), + ("PermissionDenied", OperationalError), + ("NotFound", ProgrammingError), + ("ResourceExhausted", OperationalError), + ("Unavailable", OperationalError), + ("Timeout", OperationalError), + ("Cancelled", OperationalError), + ("DataLoss", DatabaseError), + ("Internal", DatabaseError), + ("InvalidStatementHandle", ProgrammingError), + ("NetworkError", OperationalError), + ("SqlError", DatabaseError), + ("Unknown", DatabaseError), + ], +) +def test_code_to_exception_mapping(code, expected_cls): + """Every entry in ``_CODE_TO_EXCEPTION`` maps to the documented + PEP 249 class.""" + err = _FakeKernelError(code=code, message=f"{code} boom") + out = kernel_client._reraise_kernel_error(err) + assert isinstance(out, expected_cls) + assert "boom" in str(out) + assert out.__cause__ is err + + +def test_unknown_code_falls_back_to_database_error(): + err = _FakeKernelError(code="SomethingNew", message="…") + out = kernel_client._reraise_kernel_error(err) + assert isinstance(out, DatabaseError) + + +def test_reraise_forwards_structured_attributes(): + err = _FakeKernelError( + code="SqlError", + message="table not found", + sql_state="42P01", + query_id="q-123", + ) + out = kernel_client._reraise_kernel_error(err) + assert out.code == "SqlError" + assert out.sql_state == "42P01" + assert out.query_id == "q-123" + # Optional fields 
default to None on the source exception and + # come through verbatim on the re-raised side. + for attr in ("error_code", "vendor_code", "http_status"): + assert getattr(out, attr) is None + assert out.retryable is False + + +# --------------------------------------------------------------------------- +# State mapping +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "kernel_state, expected", + [ + ("Pending", CommandState.PENDING), + ("Running", CommandState.RUNNING), + ("Succeeded", CommandState.SUCCEEDED), + ("Failed", CommandState.FAILED), + ("Cancelled", CommandState.CANCELLED), + ("Closed", CommandState.CLOSED), + ], +) +def test_state_to_command_state_mapping(kernel_state, expected): + assert kernel_client._STATE_TO_COMMAND_STATE[kernel_state] == expected + + +# --------------------------------------------------------------------------- +# Client lifecycle / guards (no live session) +# --------------------------------------------------------------------------- + + +def _make_client() -> kernel_client.KernelDatabricksClient: + """Build a client with a PAT auth provider; the kernel ``Session`` + isn't opened until ``open_session`` runs.""" + return kernel_client.KernelDatabricksClient( + server_hostname="example.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc", + auth_provider=AccessTokenAuthProvider("dapi-test"), + ssl_options=None, + ) + + +def test_no_open_session_guards_raise_interface_error(): + """Every method that depends on an open kernel session must + raise ``InterfaceError`` before any kernel call.""" + c = _make_client() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + with pytest.raises(InterfaceError, match="open session"): + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + 
enforce_embedded_schema_correctness=False, + ) + + for method, kwargs in [ + ("get_catalogs", {}), + ("get_schemas", {}), + ("get_tables", {}), + ("get_columns", {"catalog_name": "main"}), + ]: + with pytest.raises(InterfaceError): + getattr(c, method)( + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + cursor=cursor, + **kwargs, + ) + + +def test_open_session_rejects_double_open(monkeypatch): + """Two ``open_session`` calls on the same client must fail — + the kernel session is bound to a single open call.""" + c = _make_client() + c._kernel_session = MagicMock() # pretend already open + with pytest.raises(InterfaceError, match="already has an open session"): + c.open_session(session_configuration=None, catalog=None, schema=None) + + +def test_execute_command_rejects_parameters(): + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + with pytest.raises(NotSupportedError, match="Parameter binding"): + c.execute_command( + operation="SELECT ?", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[object()], # any non-empty list + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + +def test_execute_command_rejects_query_tags(): + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + with pytest.raises(NotSupportedError, match="query_tags"): + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + query_tags={"team": "x"}, + ) + + +def test_get_columns_requires_catalog(): + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + 
with pytest.raises(ProgrammingError, match="catalog_name"): + c.get_columns( + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + cursor=cursor, + catalog_name=None, + ) + + +# --------------------------------------------------------------------------- +# Async handle bookkeeping +# --------------------------------------------------------------------------- + + +def test_cancel_command_tolerant_when_handle_missing(): + """``cancel_command`` is documented to be a no-op when there's + no tracked async handle (matches Thrift's tolerance).""" + c = _make_client() + fake_command_id = CommandId.from_sea_statement_id("not-tracked") + c.cancel_command(fake_command_id) # must not raise + + +def test_close_command_tolerant_when_handle_missing(): + c = _make_client() + fake_command_id = CommandId.from_sea_statement_id("not-tracked") + c.close_command(fake_command_id) # must not raise + + +def test_get_query_state_returns_succeeded_when_handle_missing(): + """Sync-execute paths never register an async handle; by the + time ``get_query_state`` could be called the command is + terminal-by-construction. 
The client returns SUCCEEDED so the + cursor's polling loop terminates cleanly.""" + c = _make_client() + fake_command_id = CommandId.from_sea_statement_id("sync-only") + assert c.get_query_state(fake_command_id) == CommandState.SUCCEEDED + + +def test_get_execution_result_raises_for_unknown_command_id(): + """The kernel backend only tracks async-submitted statements; + a ``get_execution_result`` call for an unknown id is a + programming error.""" + c = _make_client() + fake_command_id = CommandId.from_sea_statement_id("unknown") + with pytest.raises(ProgrammingError, match="unknown command_id"): + c.get_execution_result(fake_command_id, cursor=MagicMock()) + + +def test_cancel_command_reraises_kernel_error(): + c = _make_client() + fake_handle = MagicMock() + fake_handle.cancel.side_effect = _FakeKernelError(code="Unavailable") + cid = CommandId.from_sea_statement_id("abc") + c._async_handles[cid.guid] = fake_handle + with pytest.raises(OperationalError): + c.cancel_command(cid) + + +def test_close_command_reraises_kernel_error(): + c = _make_client() + fake_handle = MagicMock() + fake_handle.close.side_effect = _FakeKernelError(code="Internal") + cid = CommandId.from_sea_statement_id("abc") + c._async_handles[cid.guid] = fake_handle + with pytest.raises(DatabaseError): + c.close_command(cid) + # The handle is popped before the kernel call, so a subsequent + # close_command is tolerantly a no-op. 
+ c.close_command(cid) + + +def test_get_query_state_raises_on_failed_state_with_failure(): + c = _make_client() + fake_handle = MagicMock() + fake_handle.status.return_value = ( + "Failed", + _FakeKernelError(code="SqlError", message="bad"), + ) + cid = CommandId.from_sea_statement_id("abc") + c._async_handles[cid.guid] = fake_handle + with pytest.raises(DatabaseError, match="bad"): + c.get_query_state(cid) + + +def test_get_query_state_returns_state_when_no_failure(): + c = _make_client() + fake_handle = MagicMock() + fake_handle.status.return_value = ("Running", None) + cid = CommandId.from_sea_statement_id("abc") + c._async_handles[cid.guid] = fake_handle + assert c.get_query_state(cid) == CommandState.RUNNING + + +# --------------------------------------------------------------------------- +# Misc +# --------------------------------------------------------------------------- + + +def test_max_download_threads_is_nonzero(): + """Property is consulted by Thrift code paths that don't run for + ``use_kernel=True``; a non-zero default avoids divide-by-zero.""" + c = _make_client() + assert c.max_download_threads > 0 + + +def test_synthetic_command_id_is_uuid_shaped(): + """Synthetic metadata command IDs are plain hex UUIDs (no + ``metadata-`` prefix) so anything reading ``cursor.query_id`` + downstream sees a parseable shape.""" + c = _make_client() + cid = c._synthetic_command_id() + # 32-char lowercase hex + assert len(cid.guid) == 32 + int(cid.guid, 16) # raises if non-hex + + +def test_close_session_clears_async_handles_even_if_close_fails(): + """Per-handle close errors are logged but don't prevent the + rest of the close-session sweep from completing, and the dict + is cleared either way.""" + c = _make_client() + good = MagicMock() + bad = MagicMock() + bad.close.side_effect = _FakeKernelError(code="Unavailable") + c._async_handles["a"] = good + c._async_handles["b"] = bad + c._kernel_session = MagicMock() + c.close_session(MagicMock()) + assert 
c._async_handles == {} + assert good.close.called + assert bad.close.called From 5a6f1f0a9e2f724d3eabf98d320bbc9ae99d4c4c Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 05:36:05 +0000 Subject: [PATCH 10/16] =?UTF-8?q?fix(backend/kernel):=20CI-greening=20?= =?UTF-8?q?=E2=80=94=20mypy=20+=20e2e=20module=20skip?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - src/databricks/sql/backend/kernel/result_set.py: fix the 3 mypy errors at L237/239/241 by casting ``self.backend`` to ``KernelDatabricksClient`` (the base ``DatabricksClient`` doesn't declare ``_async_handles`` / ``_async_handles_lock``). Folds in gopalldb's nit (3249904284) — replace the explicit ``acquire()/try/finally/release()`` with a ``with`` block to match the rest of the file. - tests/e2e/test_kernel_backend.py: harden the module-level skip so the suite doesn't run when the kernel wheel is absent in CI. The unit suite installs a fake ``databricks_sql_kernel`` ``ModuleType`` into ``sys.modules`` so the connector's import-time ``import databricks_sql_kernel`` succeeds without the Rust extension; that fake leaks across into the same pytest session and ``pytest.importorskip`` happily returns it. A real wheel exposes ``__file__`` (compiled extension on disk); the fake does not. Skip the module when ``__file__`` is missing. 
Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- src/databricks/sql/backend/kernel/result_set.py | 10 ++++------ tests/e2e/test_kernel_backend.py | 17 +++++++++++++++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index 2cc665656..3aaaf7696 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -32,7 +32,7 @@ import logging from collections import deque -from typing import Any, Deque, List, Optional, TYPE_CHECKING +from typing import Any, Deque, List, Optional, TYPE_CHECKING, cast import pyarrow @@ -234,11 +234,9 @@ def close(self) -> None: # register in ``_async_handles``. guid = getattr(self.command_id, "guid", None) if guid is not None: - self.backend._async_handles_lock.acquire() - try: - self.backend._async_handles.pop(guid, None) - finally: - self.backend._async_handles_lock.release() + backend = cast("KernelDatabricksClient", self.backend) + with backend._async_handles_lock: + backend._async_handles.pop(guid, None) self._buffer.clear() self._buffered_count = 0 self._kernel_handle = None diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py index 0c0722b91..67f6e858d 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -27,11 +27,24 @@ from databricks.sql.exc import DatabaseError -# Skip the whole module unless the kernel wheel is importable. -pytest.importorskip( +# Skip the whole module unless the kernel wheel is genuinely installed. +# ``pytest.importorskip`` alone isn't enough: the kernel unit tests inject a +# fake ``databricks_sql_kernel`` ModuleType into ``sys.modules`` so the +# connector's import-time ``import databricks_sql_kernel`` succeeds without +# the Rust extension. 
In the same pytest session that fake module is still +# in ``sys.modules`` when this e2e file is collected, and importorskip +# happily returns it. A real wheel exposes ``__file__`` (the compiled +# extension on disk); the fake ModuleType does not. +_kernel_mod = pytest.importorskip( "databricks_sql_kernel", reason="use_kernel=True requires the databricks-sql-kernel package", ) +if not getattr(_kernel_mod, "__file__", None): + pytest.skip( + "databricks_sql_kernel is a test stub (no __file__); " + "install the real wheel to run kernel e2e tests", + allow_module_level=True, + ) @pytest.fixture(scope="module") From da2bc446ed6bd32616d481a31c43008bd98146b1 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 05:41:25 +0000 Subject: [PATCH 11/16] fix(backend/kernel): address gopalldb minor review comments (m1, m4) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit m1 — install hint (comment 3249904266): The ``databricks-sql-kernel`` wheel is not yet published on PyPI; ``pip install databricks-sql-kernel`` either finds nothing or pulls a squatted package. Drop the misleading hint from ``ImportError`` and from the ``use_kernel`` docstring on ``databricks.sql.connect``; point users at the ``maturin develop --release`` dev path until the wheel ships. m4a — auth_bridge ValueError → ProgrammingError (comment 3249904276): Two sites in ``_extract_bearer_token`` / ``kernel_auth_kwargs`` were raising bare ``ValueError`` for caller-misuse cases (control chars in the token, PAT provider that produced no Authorization header). The rest of the kernel-backend error surface uses PEP 249 exception types — code paths that catch ``DatabaseError`` / ``ProgrammingError`` would miss these. Convert to ``ProgrammingError`` and update the unit test. m4b — description null_ok (comment 3249904282): ``description_from_arrow_schema`` was hardcoding the 7th tuple element to ``None`` even though ``pyarrow.Field.nullable`` is available. 
PEP 249 §Cursor.description defines ``null_ok`` as "True if NULL values are allowed"; callers branching on it would have lost useful information the kernel already provides. Now emits ``field.nullable``; added a unit test covering both nullable and non-nullable fields; updated the two existing tests that asserted the old all-``None`` shape. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- .../sql/backend/kernel/auth_bridge.py | 6 +++--- src/databricks/sql/backend/kernel/client.py | 19 ++++++++++------- .../sql/backend/kernel/type_mapping.py | 8 +++---- src/databricks/sql/client.py | 16 ++++++++------ tests/unit/test_kernel_auth_bridge.py | 9 ++++---- tests/unit/test_kernel_result_set.py | 3 ++- tests/unit/test_kernel_type_mapping.py | 21 +++++++++++++++---- 7 files changed, 52 insertions(+), 30 deletions(-) diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py index 01123b96c..f382284d2 100644 --- a/src/databricks/sql/backend/kernel/auth_bridge.py +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -24,7 +24,7 @@ from databricks.sql.auth.authenticators import AccessTokenAuthProvider, AuthProvider from databricks.sql.auth.token_federation import TokenFederationProvider -from databricks.sql.exc import NotSupportedError +from databricks.sql.exc import NotSupportedError, ProgrammingError logger = logging.getLogger(__name__) @@ -79,7 +79,7 @@ def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: return None token = auth[len(_BEARER_PREFIX) :] if _CONTROL_CHAR_RE.search(token): - raise ValueError( + raise ProgrammingError( "Bearer token contains ASCII control characters; refusing to " "forward it to the kernel auth bridge." 
) @@ -98,7 +98,7 @@ def kernel_auth_kwargs(auth_provider: AuthProvider) -> Dict[str, Any]: if _is_pat(auth_provider): token = _extract_bearer_token(auth_provider) if not token: - raise ValueError( + raise ProgrammingError( "PAT auth provider did not produce a Bearer Authorization " "header; cannot route through the kernel's PAT path" ) diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 2bc70c618..1b88cde60 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -67,15 +67,18 @@ try: import databricks_sql_kernel as _kernel # type: ignore[import-not-found] except ImportError as exc: # pragma: no cover - import-time error surfaces clearly - # The `databricks-sql-kernel` wheel is not yet on PyPI, so we - # don't yet declare it as an optional extra in pyproject.toml - # (doing so breaks `poetry lock`). Once published the install - # hint will move to `pip install 'databricks-sql-connector[kernel]'`. + # The ``databricks-sql-kernel`` wheel is not yet on PyPI, so the + # dev-install path is the only working one today. ``pip install + # databricks-sql-kernel`` would either find nothing or pull a + # squatted package, so we deliberately do not suggest it. Once + # the wheel is published the hint will move to + # ``pip install 'databricks-sql-connector[kernel]'``. raise ImportError( - "use_kernel=True requires the databricks-sql-kernel package. Install it with:\n" - " pip install databricks-sql-kernel\n" - "or for local development from the kernel repo:\n" - " cd databricks-sql-kernel/pyo3 && maturin develop --release" + "use_kernel=True requires the databricks-sql-kernel extension, which " + "is not yet published on PyPI. Build and install it locally from the " + "databricks-sql-kernel repo:\n" + " cd databricks-sql-kernel/pyo3 && maturin develop --release\n" + "(into the same venv as databricks-sql-connector)." 
) from exc diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index a91160d17..83e55ed55 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -61,9 +61,9 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: """Build a PEP 249 ``description`` list from a pyarrow Schema. Each tuple is ``(name, type_code, display_size, internal_size, - precision, scale, null_ok)``. The kernel does not report the - last five so they're all ``None`` — same shape the existing - ADBC / Thrift result paths produce. + precision, scale, null_ok)``. ``null_ok`` is taken from + ``field.nullable``; the other four are not reported by the + kernel today. """ return [ ( @@ -73,7 +73,7 @@ def description_from_arrow_schema(schema: pyarrow.Schema) -> List[Tuple]: None, None, None, - None, + field.nullable, ) for field in schema ] diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e3c25fe65..7fc815cd8 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -120,12 +120,16 @@ def __init__( :param use_kernel: `bool`, optional (default is False) Route the connection through the Rust kernel (``databricks-sql-kernel`` via PyO3). Requires the - kernel wheel to be installed separately - (``pip install databricks-sql-kernel``); raises - ImportError otherwise. In active development — - PAT auth only today; OAuth / federation / external - credentials and native parameter binding land in - follow-ups. Mutually exclusive with ``use_sea``. + kernel extension to be installed separately — the + wheel is not yet published on PyPI, so today the + only supported install path is a local + ``maturin develop --release`` build from the + ``databricks-sql-kernel`` repo into the same venv. + Raises ``ImportError`` if the extension is not + available. 
In active development — PAT auth only + today; OAuth / federation / external credentials + and native parameter binding land in follow-ups. + Mutually exclusive with ``use_sea``. :param use_hybrid_disposition: `bool`, optional (default is False) Use the hybrid disposition instead of the inline disposition. :param server_hostname: Databricks instance host name. diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py index a5e2e756b..dfad26ede 100644 --- a/tests/unit/test_kernel_auth_bridge.py +++ b/tests/unit/test_kernel_auth_bridge.py @@ -31,7 +31,7 @@ _extract_bearer_token, kernel_auth_kwargs, ) -from databricks.sql.exc import NotSupportedError +from databricks.sql.exc import NotSupportedError, ProgrammingError class _FakeOAuthProvider(AuthProvider): @@ -97,13 +97,14 @@ def test_federation_wrapped_pat_routes_to_kernel_pat(self): kwargs = kernel_auth_kwargs(federated) assert kwargs == {"auth_type": "pat", "access_token": "dapi-abc"} - def test_pat_with_silent_provider_raises_value_error(self): + def test_pat_with_silent_provider_raises_programming_error(self): """An AccessTokenAuthProvider that produces no Authorization header is misconfigured; surface that at bridge-build time, - not on the first kernel HTTP request.""" + not on the first kernel HTTP request. 
``ProgrammingError`` so + the bridge's error surface is uniformly PEP 249.""" broken = AccessTokenAuthProvider("dapi-x") broken.add_headers = lambda h: None # type: ignore[method-assign] - with pytest.raises(ValueError, match="Bearer"): + with pytest.raises(ProgrammingError, match="Bearer"): kernel_auth_kwargs(broken) def test_generic_oauth_provider_raises_not_supported(self): diff --git a/tests/unit/test_kernel_result_set.py b/tests/unit/test_kernel_result_set.py index c83bfce94..2078441c4 100644 --- a/tests/unit/test_kernel_result_set.py +++ b/tests/unit/test_kernel_result_set.py @@ -77,7 +77,8 @@ def int_schema(): def test_description_built_from_kernel_schema(int_schema): handle = _FakeKernelHandle(int_schema, []) rs = _make_rs(handle) - assert rs.description == [("n", "bigint", None, None, None, None, None)] + # null_ok slot reflects pyarrow.Field.nullable (True by default). + assert rs.description == [("n", "bigint", None, None, None, None, True)] def test_fetchall_arrow_drains_all_batches(int_schema): diff --git a/tests/unit/test_kernel_type_mapping.py b/tests/unit/test_kernel_type_mapping.py index 5ab5bde74..82f62559a 100644 --- a/tests/unit/test_kernel_type_mapping.py +++ b/tests/unit/test_kernel_type_mapping.py @@ -65,9 +65,22 @@ def test_description_from_schema_preserves_field_names_and_order(): ("name", "string"), ("created_at", "timestamp"), ] - # PEP 249 says all 7-tuples; the last 5 slots are None for the - # kernel backend (we don't report display_size / precision / - # scale / nullability). + # PEP 249 says 7-tuples. We don't report display_size / + # internal_size / precision / scale (all None); ``null_ok`` is + # taken from ``pyarrow.Field.nullable`` — True by default for + # schemas built from (name, type) pairs. 
for d in desc: assert len(d) == 7 - assert d[2:] == (None, None, None, None, None) + assert d[2:] == (None, None, None, None, True) + + +def test_description_from_schema_reports_non_nullable_fields(): + schema = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.string(), nullable=True), + ] + ) + desc = description_from_arrow_schema(schema) + assert desc[0][6] is False + assert desc[1][6] is True From 1d0f7b6596843fa3cf5e8b0d78973ec8e9e15ad5 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 05:51:58 +0000 Subject: [PATCH 12/16] =?UTF-8?q?fix(backend/kernel):=20substantive=20revi?= =?UTF-8?q?ew=20fixes=20=E2=80=94=20M1,=20M2,=20M3,=20m2,=20m3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses gopalldb's four major / two minor remaining review comments. The shared error-mapping primitives move to a new ``_errors.py`` module so both ``client.py`` and ``result_set.py`` can use them without ``result_set.py`` importing from ``client.py``. M1 — async handle leak in get_execution_result (3249904251): ``ResultStream`` from ``await_result()`` is wrapped in ``KernelResultSet``; the underlying ``ExecutedAsyncStatement`` has no further role once the stream is in hand. Close it immediately, drop the entry from ``_async_handles``, and add the guid to ``_closed_commands``. A failed ``async_exec.close()`` is logged but doesn't break the result-set return — the kernel's Drop impl reaps server-side state. M2 — PyO3 native exceptions wrapped as OperationalError (3249904255): Added ``kernel_call(what)`` context manager (in the new ``_errors.py``). ``KernelError`` flows through ``reraise_kernel_error`` as before; anything else (``TypeError`` / ``OverflowError`` / ``ValueError`` from PyO3 argument conversion, extension-internal errors) is wrapped in ``OperationalError`` so DB-API callers only ever see PEP 249 exception types. 
Applied at every PyO3 call site: ``open_session``, ``execute_command``, ``cancel_command``, ``close_command``, ``get_query_state``, ``get_execution_result``, ``get_catalogs`` / ``get_schemas`` / ``get_tables`` / ``get_columns``, plus ``fetch_next_batch`` / ``arrow_schema`` in ``KernelResultSet``. M3 — KernelError during result-set construction (3249904259): ``KernelResultSet.__init__`` calls ``kernel_handle.arrow_schema()`` which can itself raise. Every call to ``_make_result_set`` is now inside a ``kernel_call`` scope so the schema-fetch error becomes a mapped PEP 249 exception instead of leaking raw ``KernelError``. m3 — get_query_state of a closed async command (3249904273): Added ``_closed_commands: Set[str]`` (guarded by the existing ``_async_handles_lock``). ``close_command`` records the guid; ``close_session`` records every swept guid; ``get_execution_result`` records its own command after closing the async_exec. ``get_query_state`` now returns ``CommandState.CLOSED`` instead of falling through to ``SUCCEEDED`` for these. m2 — unit test for get_tables table_types client-side filter (3249904269): Added ``test_get_tables_with_table_types_filters_rows`` and ``test_get_tables_without_table_types_returns_full_stream`` in ``tests/unit/test_kernel_client.py``. The first feeds a fake stream with mixed ``TABLE`` / ``VIEW`` rows and asserts only ``TABLE`` survives; the second confirms the no-filter path bypasses the drain-and-rewrap and returns all rows unchanged. 
Plus new tests for every change above: - test_pyo3_native_exception_wrapped_as_operational_error (M2) - test_pyo3_native_exception_wrapped_for_metadata_calls (M2) - test_kernel_error_during_result_set_construction_is_mapped (M3) - test_get_execution_result_closes_async_exec_and_drops_tracking (M1) - test_get_execution_result_does_not_raise_on_async_exec_close_failure (M1) - test_get_query_state_returns_closed_after_close_command (m3) - test_close_session_marks_swept_handles_as_closed (m3) 87/87 kernel unit tests pass (added 9). Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- src/databricks/sql/backend/kernel/_errors.py | 124 ++++++++ src/databricks/sql/backend/kernel/client.py | 289 ++++++++---------- .../sql/backend/kernel/result_set.py | 10 +- tests/unit/test_kernel_client.py | 251 +++++++++++++++ 4 files changed, 510 insertions(+), 164 deletions(-) create mode 100644 src/databricks/sql/backend/kernel/_errors.py diff --git a/src/databricks/sql/backend/kernel/_errors.py b/src/databricks/sql/backend/kernel/_errors.py new file mode 100644 index 000000000..7a0a0e783 --- /dev/null +++ b/src/databricks/sql/backend/kernel/_errors.py @@ -0,0 +1,124 @@ +"""Shared error-mapping primitives for the kernel backend. + +The PyO3 boundary can produce two flavours of exception: + +- ``databricks_sql_kernel.KernelError`` — the kernel's own + structured error type. Carries ``code`` / ``message`` / + ``sql_state`` / ``query_id`` / ``http_status`` / ``retryable`` / + ``vendor_code`` / ``error_code`` as attributes; mapped to a PEP + 249 exception class via ``_CODE_TO_EXCEPTION`` with the + attributes forwarded onto the re-raised exception so callers can + branch on ``err.code`` / ``err.sql_state`` without reaching + through ``__cause__``. +- Anything else — ``TypeError`` / ``OverflowError`` / + ``ValueError`` from PyO3 argument conversion, or arbitrary + extension-internal Python errors. 
These would otherwise propagate + raw to connector callers, breaking the DB-API contract that says + "only PEP 249 exception types cross the boundary". Wrapped in + ``OperationalError`` here. + +These primitives live in their own module so both ``client.py`` +(which orchestrates PyO3 calls) and ``result_set.py`` (which calls +``fetch_next_batch`` on the same kernel handles) can share them +without ``result_set.py`` importing from ``client.py`` — that +direction would be a layering violation. +""" + +from __future__ import annotations + +import contextlib +from typing import Iterator + +from databricks.sql.exc import ( + DatabaseError, + Error, + OperationalError, + ProgrammingError, +) + + +try: + import databricks_sql_kernel as _kernel # type: ignore[import-not-found] +except ImportError as exc: # pragma: no cover - same hint as client.py + raise ImportError( + "use_kernel=True requires the databricks-sql-kernel extension, which " + "is not yet published on PyPI. Build and install it locally from the " + "databricks-sql-kernel repo:\n" + " cd databricks-sql-kernel/pyo3 && maturin develop --release\n" + "(into the same venv as databricks-sql-connector)." + ) from exc + + +# Map a kernel `code` slug to the PEP 249 exception class that best +# captures it. The match isn't a perfect 1:1 — PEP 249 has a +# narrower taxonomy than the kernel — so several kernel codes +# collapse onto the same Python exception. This table is the only +# place that mapping lives. 
+_CODE_TO_EXCEPTION = { + "InvalidArgument": ProgrammingError, + "Unauthenticated": OperationalError, + "PermissionDenied": OperationalError, + "NotFound": ProgrammingError, + "ResourceExhausted": OperationalError, + "Unavailable": OperationalError, + "Timeout": OperationalError, + "Cancelled": OperationalError, + "DataLoss": DatabaseError, + "Internal": DatabaseError, + "InvalidStatementHandle": ProgrammingError, + "NetworkError": OperationalError, + "SqlError": DatabaseError, + "Unknown": DatabaseError, +} + + +def reraise_kernel_error(exc: "_kernel.KernelError") -> "Error": + """Convert a ``databricks_sql_kernel.KernelError`` to a PEP 249 + exception with the kernel's structured attributes forwarded onto + the new instance.""" + code = getattr(exc, "code", "Unknown") + cls = _CODE_TO_EXCEPTION.get(code, DatabaseError) + new = cls(getattr(exc, "message", str(exc))) + for attr in ( + "code", + "sql_state", + "error_code", + "vendor_code", + "http_status", + "retryable", + "query_id", + ): + setattr(new, attr, getattr(exc, attr, None)) + new.__cause__ = exc + return new + + +@contextlib.contextmanager +def kernel_call(what: str) -> Iterator[None]: + """Context manager that wraps a span of PyO3 calls so any error + crossing the Python/Rust boundary surfaces as a PEP 249 + exception. + + ``KernelError`` flows through ``reraise_kernel_error`` (the + structured-attribute path). Anything else is wrapped in + ``OperationalError`` so DB-API callers see a uniform exception + surface and never have to catch native Python exceptions to + handle a connector-level failure. + + ``what`` is a short tag used only in the ``OperationalError`` + message for the non-``KernelError`` path; keep it caller-named + (e.g. ``"execute_command"``). + """ + try: + yield + except _kernel.KernelError as exc: + raise reraise_kernel_error(exc) from exc + except Error: + # Already a PEP 249 error (e.g. 
a nested ``kernel_call`` or + # the cursor-side guard re-raising one); let it propagate + # unchanged. + raise + except Exception as exc: + raise OperationalError( + f"Unexpected error from databricks_sql_kernel during {what}: {exc!r}" + ) from exc diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 1b88cde60..c7259aaba 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -36,9 +36,14 @@ import logging import threading import uuid -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union from databricks.sql.backend.databricks_client import DatabricksClient +from databricks.sql.backend.kernel._errors import ( + _kernel, + kernel_call as _kernel_call, + reraise_kernel_error as _reraise_kernel_error, +) from databricks.sql.backend.kernel.auth_bridge import kernel_auth_kwargs from databricks.sql.backend.kernel.result_set import KernelResultSet from databricks.sql.backend.types import ( @@ -48,11 +53,8 @@ SessionId, ) from databricks.sql.exc import ( - DatabaseError, - Error, InterfaceError, NotSupportedError, - OperationalError, ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -64,79 +66,6 @@ logger = logging.getLogger(__name__) -try: - import databricks_sql_kernel as _kernel # type: ignore[import-not-found] -except ImportError as exc: # pragma: no cover - import-time error surfaces clearly - # The ``databricks-sql-kernel`` wheel is not yet on PyPI, so the - # dev-install path is the only working one today. ``pip install - # databricks-sql-kernel`` would either find nothing or pull a - # squatted package, so we deliberately do not suggest it. Once - # the wheel is published the hint will move to - # ``pip install 'databricks-sql-connector[kernel]'``. 
- raise ImportError( - "use_kernel=True requires the databricks-sql-kernel extension, which " - "is not yet published on PyPI. Build and install it locally from the " - "databricks-sql-kernel repo:\n" - " cd databricks-sql-kernel/pyo3 && maturin develop --release\n" - "(into the same venv as databricks-sql-connector)." - ) from exc - - -# ─── Error mapping ────────────────────────────────────────────────────────── - - -# Map a kernel `code` slug to the PEP 249 exception class that best -# captures it. The match isn't a perfect 1:1 — PEP 249 has a -# narrower taxonomy than the kernel — so several kernel codes -# collapse onto the same Python exception. This table is the only -# place that mapping lives. -_CODE_TO_EXCEPTION = { - "InvalidArgument": ProgrammingError, - "Unauthenticated": OperationalError, - "PermissionDenied": OperationalError, - "NotFound": ProgrammingError, - "ResourceExhausted": OperationalError, - "Unavailable": OperationalError, - "Timeout": OperationalError, - "Cancelled": OperationalError, - "DataLoss": DatabaseError, - "Internal": DatabaseError, - "InvalidStatementHandle": ProgrammingError, - "NetworkError": OperationalError, - "SqlError": DatabaseError, - "Unknown": DatabaseError, -} - - -def _reraise_kernel_error(exc: "_kernel.KernelError") -> "Error": - """Convert a ``databricks_sql_kernel.KernelError`` to a PEP 249 - exception. - - Kernel errors carry their structured attrs (``code``, - ``message``, ``sql_state``, ``error_code``, ``query_id`` …) as - plain attributes — we copy them onto the re-raised exception so - callers can branch on them without reaching back through - ``__cause__``. - """ - code = getattr(exc, "code", "Unknown") - cls = _CODE_TO_EXCEPTION.get(code, DatabaseError) - new = cls(getattr(exc, "message", str(exc))) - # Forward the structured fields so connector users can read - # err.sql_state / err.query_id / etc. without a type-switch. 
- for attr in ( - "code", - "sql_state", - "error_code", - "vendor_code", - "http_status", - "retryable", - "query_id", - ): - setattr(new, attr, getattr(exc, attr, None)) - new.__cause__ = exc - return new - - # ─── Client ───────────────────────────────────────────────────────────────── @@ -183,6 +112,12 @@ def __init__( # Guarded by ``_async_handles_lock`` so concurrent cursors on the # same connection don't race on submit / close / close-session. self._async_handles: Dict[str, Any] = {} + # CommandId.guids of async commands that have already been + # closed (via ``close_command`` or ``close_session``). Lets + # ``get_query_state`` report ``CLOSED`` for them rather than + # the SUCCEEDED fall-through used for the never-tracked sync + # path. Same lock as ``_async_handles``. + self._closed_commands: Set[str] = set() self._async_handles_lock = threading.RLock() # ── Session lifecycle ────────────────────────────────────────── @@ -202,16 +137,15 @@ def open_session( if session_configuration: session_conf = {k: str(v) for k, v in session_configuration.items()} try: - self._kernel_session = _kernel.Session( - host=self._server_hostname, - http_path=self._http_path, - catalog=catalog or self._catalog, - schema=schema or self._schema, - session_conf=session_conf, - **self._auth_kwargs, - ) - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) + with _kernel_call("open_session"): + self._kernel_session = _kernel.Session( + host=self._server_hostname, + http_path=self._http_path, + catalog=catalog or self._catalog, + schema=schema or self._schema, + session_conf=session_conf, + **self._auth_kwargs, + ) finally: # Drop the raw access token from the instance once the # kernel session is constructed (or failed). The kernel @@ -234,18 +168,24 @@ def close_session(self, session_id: SessionId) -> None: # Close any tracked async handles first so they fire their # server-side CloseStatement before the session goes away. 
with self._async_handles_lock: - handles_to_close = list(self._async_handles.values()) + tracked = list(self._async_handles.items()) self._async_handles.clear() - for handle in handles_to_close: + for guid, _ in tracked: + self._closed_commands.add(guid) + for _, handle in tracked: + # Per-handle close errors are non-fatal — PEP 249 + # discourages raising from session close — so log and + # move on. Any non-KernelError that crosses the PyO3 + # boundary also gets caught here for the same reason. try: handle.close() - except _kernel.KernelError as exc: + except Exception as exc: logger.warning( "Error closing async handle during session close: %s", exc ) try: self._kernel_session.close() - except _kernel.KernelError as exc: + except Exception as exc: # Surface as a non-fatal warning — the kernel's Drop # impl will retry the close fire-and-forget. PEP 249 # discourages raising from connection.close(). @@ -282,32 +222,39 @@ def execute_command( "Statement-level query_tags are not yet supported on the kernel backend." ) - stmt = self._kernel_session.statement() + with _kernel_call("execute_command"): + stmt = self._kernel_session.statement() try: - stmt.set_sql(operation) - if async_op: - async_exec = stmt.submit() - command_id = CommandId.from_sea_statement_id(async_exec.statement_id) - cursor.active_command_id = command_id - with self._async_handles_lock: - self._async_handles[command_id.guid] = async_exec - return None - executed = stmt.execute() - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) + with _kernel_call("execute_command"): + stmt.set_sql(operation) + if async_op: + async_exec = stmt.submit() + command_id = CommandId.from_sea_statement_id(async_exec.statement_id) + cursor.active_command_id = command_id + with self._async_handles_lock: + self._async_handles[command_id.guid] = async_exec + return None + executed = stmt.execute() finally: # ``Statement`` is a lifecycle owner separate from the # executed handle it produces. 
Drop it here so the # parent doesn't keep the handle alive longer than the - # caller expects. + # caller expects. Swallow all close errors (including + # PyO3 native exceptions) — a failed stmt.close() is + # not actionable for the caller. try: stmt.close() - except _kernel.KernelError: + except Exception: pass command_id = CommandId.from_sea_statement_id(executed.statement_id) cursor.active_command_id = command_id - return self._make_result_set(executed, cursor, command_id) + # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which + # can itself raise ``KernelError`` (or, in principle, a PyO3 + # native exception) — wrap the construction so callers see a + # mapped PEP 249 exception. + with _kernel_call("execute_command"): + return self._make_result_set(executed, cursor, command_id) def cancel_command(self, command_id: CommandId) -> None: with self._async_handles_lock: @@ -319,34 +266,38 @@ def cancel_command(self, command_id: CommandId) -> None: # Match the Thrift backend's tolerant behaviour. logger.debug("cancel_command: no in-flight async handle for %s", command_id) return - try: + with _kernel_call("cancel_command"): handle.cancel() - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) def close_command(self, command_id: CommandId) -> None: with self._async_handles_lock: handle = self._async_handles.pop(command_id.guid, None) + if handle is not None: + # Record the close so ``get_query_state`` can report + # ``CLOSED`` (not ``SUCCEEDED``) for this command. 
+ self._closed_commands.add(command_id.guid) if handle is None: logger.debug("close_command: no tracked handle for %s", command_id) return - try: + with _kernel_call("close_command"): handle.close() - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) def get_query_state(self, command_id: CommandId) -> CommandState: with self._async_handles_lock: handle = self._async_handles.get(command_id.guid) + already_closed = command_id.guid in self._closed_commands if handle is None: - # No tracked async handle means execute_command ran - # sync and the result was materialised before returning; - # the command is terminal by construction. + if already_closed: + # We tracked this async handle and have since closed + # it; the command is no longer queryable on the + # server but the connector still has the id. + return CommandState.CLOSED + # No tracked async handle and never closed: execute_command + # ran sync and the result was materialised before + # returning. Terminal by construction. return CommandState.SUCCEEDED - try: + with _kernel_call("get_query_state"): state, failure = handle.status() - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) if state == "Failed" and failure is not None: # Surface server-reported failure as a database error so # the cursor's polling loop terminates with the right @@ -361,17 +312,35 @@ def get_execution_result( cursor: "Cursor", ) -> "ResultSet": with self._async_handles_lock: - handle = self._async_handles.get(command_id.guid) - if handle is None: + async_exec = self._async_handles.get(command_id.guid) + if async_exec is None: raise ProgrammingError( "get_execution_result called for an unknown command_id; " "the kernel backend only tracks async-submitted statements." 
) + with _kernel_call("get_execution_result"): + stream = async_exec.await_result() + # The async-exec handle's role ends once it has produced the + # ``ResultStream`` — keeping it around (and tracked in + # ``_async_handles``) would leak the server-side + # ``ExecutedAsyncStatement`` until ``close_session`` swept it + # up, since ``KernelResultSet.close`` only closes the stream + # it wraps. Drop tracking and fire-and-forget the close. + with self._async_handles_lock: + self._async_handles.pop(command_id.guid, None) + self._closed_commands.add(command_id.guid) try: - stream = handle.await_result() - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) - return self._make_result_set(stream, cursor, command_id) + async_exec.close() + except Exception as exc: + logger.warning( + "Error closing async_exec after await_result for %s: %s", + command_id, + exc, + ) + # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which + # can raise — keep that in the mapped-exception scope. 
+ with _kernel_call("get_execution_result"): + return self._make_result_set(stream, cursor, command_id) # ── Metadata ─────────────────────────────────────────────────── @@ -413,11 +382,9 @@ def get_catalogs( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_catalogs requires an open session.") - try: + with _kernel_call("get_catalogs"): stream = self._kernel_session.metadata().list_catalogs() - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) - return self._make_result_set(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) def get_schemas( self, @@ -430,14 +397,12 @@ def get_schemas( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_schemas requires an open session.") - try: + with _kernel_call("get_schemas"): stream = self._kernel_session.metadata().list_schemas( catalog=catalog_name, schema_pattern=schema_name, ) - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) - return self._make_result_set(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) def get_tables( self, @@ -452,36 +417,37 @@ def get_tables( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_tables requires an open session.") - try: + with _kernel_call("get_tables"): stream = self._kernel_session.metadata().list_tables( catalog=catalog_name, schema_pattern=schema_name, table_pattern=table_name, table_types=table_types, ) - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) - if not table_types: - return self._make_result_set(stream, cursor, self._synthetic_command_id()) - # The kernel today returns the unfiltered ``SHOW TABLES`` shape - # regardless of ``table_types``. Drain to a single Arrow table - # and apply the same client-side filter the native SEA backend - # uses (column index 5 is TABLE_TYPE, case-sensitive). 
Cheap - # because metadata result sets are small. - from databricks.sql.backend.sea.utils.filters import ResultSetFilter - - full_table = _drain_kernel_handle(stream) - filtered_table = ResultSetFilter._filter_arrow_table( - full_table, - column_name=full_table.schema.field(5).name, - allowed_values=table_types, - case_sensitive=True, - ) - return self._make_result_set( - _StaticArrowHandle(filtered_table), - cursor, - self._synthetic_command_id(), - ) + if not table_types: + return self._make_result_set( + stream, cursor, self._synthetic_command_id() + ) + # The kernel today returns the unfiltered ``SHOW TABLES`` + # shape regardless of ``table_types``. Drain to a single + # Arrow table and apply the same client-side filter the + # native SEA backend uses (column index 5 is + # ``TABLE_TYPE``, case-sensitive). Cheap because metadata + # result sets are small. + from databricks.sql.backend.sea.utils.filters import ResultSetFilter + + full_table = _drain_kernel_handle(stream) + filtered_table = ResultSetFilter._filter_arrow_table( + full_table, + column_name=full_table.schema.field(5).name, + allowed_values=table_types, + case_sensitive=True, + ) + return self._make_result_set( + _StaticArrowHandle(filtered_table), + cursor, + self._synthetic_command_id(), + ) def get_columns( self, @@ -503,16 +469,14 @@ def get_columns( raise ProgrammingError( "get_columns requires catalog_name on the kernel backend." 
) - try: + with _kernel_call("get_columns"): stream = self._kernel_session.metadata().list_columns( catalog=catalog_name, schema_pattern=schema_name, table_pattern=table_name, column_pattern=column_name, ) - except _kernel.KernelError as exc: - raise _reraise_kernel_error(exc) - return self._make_result_set(stream, cursor, self._synthetic_command_id()) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) # ── Misc ─────────────────────────────────────────────────────── @@ -552,7 +516,10 @@ def _drain_kernel_handle(handle: Any) -> Any: batches.append(batch) try: handle.close() - except _kernel.KernelError: + except Exception: + # Non-fatal — the surrounding ``get_tables`` call has already + # captured the result data, and the handle's server-side + # state will be reaped by the kernel's Drop impl. pass return pyarrow.Table.from_batches(batches, schema=schema) diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index 3aaaf7696..64e10c3b5 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -36,6 +36,7 @@ import pyarrow +from databricks.sql.backend.kernel._errors import kernel_call from databricks.sql.backend.kernel.type_mapping import description_from_arrow_schema from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.result_set import ResultSet @@ -67,7 +68,8 @@ def __init__( arraysize: int, buffer_size_bytes: int, ): - schema = kernel_handle.arrow_schema() + with kernel_call("KernelResultSet.arrow_schema"): + schema = kernel_handle.arrow_schema() super().__init__( connection=connection, backend=backend, @@ -105,7 +107,8 @@ def _pull_one_batch(self) -> bool: is exhausted.""" if self._exhausted: return False - batch = self._kernel_handle.fetch_next_batch() + with kernel_call("fetch_next_batch"): + batch = self._kernel_handle.fetch_next_batch() if batch is None: self._exhausted = True 
self.has_more_rows = False @@ -162,7 +165,8 @@ def _drain(self) -> pyarrow.Table: chunks.append(self._buffer.popleft()) if not self._exhausted: while True: - batch = self._kernel_handle.fetch_next_batch() + with kernel_call("fetch_next_batch"): + batch = self._kernel_handle.fetch_next_batch() if batch is None: self._exhausted = True self.has_more_rows = False diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index b23365c6e..cdce33147 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -395,3 +395,254 @@ def test_close_session_clears_async_handles_even_if_close_fails(): assert c._async_handles == {} assert good.close.called assert bad.close.called + + +def test_close_session_marks_swept_handles_as_closed(): + """Close-session pre-populates ``_closed_commands`` for every + swept async handle so a subsequent ``get_query_state`` reports + ``CLOSED`` instead of falling through to the SUCCEEDED + sync-default.""" + c = _make_client() + handle = MagicMock() + cid = CommandId.from_sea_statement_id("xyz") + c._async_handles[cid.guid] = handle + c._kernel_session = MagicMock() + c.close_session(MagicMock()) + assert cid.guid in c._closed_commands + + +# --------------------------------------------------------------------------- +# CLOSED command-state for previously-tracked async handles (m3) +# --------------------------------------------------------------------------- + + +def test_get_query_state_returns_closed_after_close_command(): + """After ``close_command`` on a tracked async handle, the + subsequent ``get_query_state`` lookup must report ``CLOSED``, + not fall through to the SUCCEEDED sync-default — the command + was tracked then closed; SUCCEEDED would lie about its history.""" + c = _make_client() + handle = MagicMock() + cid = CommandId.from_sea_statement_id("async-1") + c._async_handles[cid.guid] = handle + c.close_command(cid) + assert handle.close.called + assert c.get_query_state(cid) == 
CommandState.CLOSED + + +# --------------------------------------------------------------------------- +# PyO3 native exceptions (M2) — non-KernelError wrapping +# --------------------------------------------------------------------------- + + +def test_pyo3_native_exception_wrapped_as_operational_error(): + """A PyO3 boundary error that is *not* a ``KernelError`` (e.g. + ``TypeError`` from argument conversion) must surface as a PEP + 249 exception, not propagate raw to connector callers.""" + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + # Statement chain succeeds, but ``execute`` raises a raw + # ``TypeError`` (simulating PyO3 argument-conversion failure). + stmt = MagicMock() + stmt.execute.side_effect = TypeError("argument 'foo' must be str, not int") + c._kernel_session.statement.return_value = stmt + with pytest.raises(OperationalError, match="Unexpected error from databricks_sql_kernel"): + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + +def test_pyo3_native_exception_wrapped_for_metadata_calls(): + """Same wrapping for every metadata method.""" + c = _make_client() + c._kernel_session = MagicMock() + md = c._kernel_session.metadata.return_value + md.list_catalogs.side_effect = ValueError("bad PyO3 arg") + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + with pytest.raises(OperationalError): + c.get_catalogs( + session_id=MagicMock(), max_rows=1, max_bytes=1, cursor=cursor + ) + + +# --------------------------------------------------------------------------- +# Schema-on-construct race (M3) — KernelError during arrow_schema() +# --------------------------------------------------------------------------- + + +def 
test_kernel_error_during_result_set_construction_is_mapped(): + """``KernelResultSet.__init__`` calls + ``kernel_handle.arrow_schema()`` which can itself raise a + ``KernelError``. The connector must catch that and surface a + mapped PEP 249 exception, not let the raw ``KernelError`` + escape.""" + c = _make_client() + c._kernel_session = MagicMock() + md = c._kernel_session.metadata.return_value + bad_stream = MagicMock() + bad_stream.arrow_schema.side_effect = _FakeKernelError( + code="SqlError", message="schema unavailable" + ) + md.list_catalogs.return_value = bad_stream + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + with pytest.raises(DatabaseError, match="schema unavailable"): + c.get_catalogs( + session_id=MagicMock(), max_rows=1, max_bytes=1, cursor=cursor + ) + + +# --------------------------------------------------------------------------- +# Async leak in get_execution_result (M1) +# --------------------------------------------------------------------------- + + +def test_get_execution_result_closes_async_exec_and_drops_tracking(): + """The ``ExecutedAsyncStatement`` handle's role ends once it + produces a ``ResultStream`` via ``await_result()``. 
The client + must close the async_exec and drop the tracking entry there — + otherwise ``KernelResultSet.close()`` (which only closes the + stream) leaves the executed handle leaked server-side until + ``close_session`` sweeps.""" + c = _make_client() + c._kernel_session = MagicMock() + async_exec = MagicMock() + fake_stream = MagicMock() + fake_stream.arrow_schema.return_value = pa.schema([("n", pa.int64())]) + async_exec.await_result.return_value = fake_stream + cid = CommandId.from_sea_statement_id("async-leak-test") + c._async_handles[cid.guid] = async_exec + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + c.get_execution_result(cid, cursor=cursor) + + # async_exec must be closed and dropped from tracking; the + # closed-commands set records it. + assert async_exec.close.called + assert cid.guid not in c._async_handles + assert cid.guid in c._closed_commands + + +def test_get_execution_result_does_not_raise_on_async_exec_close_failure(): + """A failure to close the async_exec is non-fatal — the result + stream has already been returned by ``await_result()`` and the + kernel's Drop will reap server-side state.""" + c = _make_client() + c._kernel_session = MagicMock() + async_exec = MagicMock() + fake_stream = MagicMock() + fake_stream.arrow_schema.return_value = pa.schema([("n", pa.int64())]) + async_exec.await_result.return_value = fake_stream + async_exec.close.side_effect = _FakeKernelError(code="Unavailable") + cid = CommandId.from_sea_statement_id("async-close-fail") + c._async_handles[cid.guid] = async_exec + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + # Must not raise. 
+ rs = c.get_execution_result(cid, cursor=cursor) + assert rs is not None + assert cid.guid not in c._async_handles + + +# --------------------------------------------------------------------------- +# get_tables table_types client-side filter (m2) +# --------------------------------------------------------------------------- + + +def _make_tables_stream() -> MagicMock: + """Build a fake stream that mimics the kernel's ``list_tables`` + output shape (6 cols ending in TABLE_TYPE at index 5 — the + connector matches what SEA produces, which has 5 metadata cols + before TABLE_TYPE). Returns a fixed table with mixed table types + so the filter has something to discriminate.""" + schema = pa.schema( + [ + ("TABLE_CAT", pa.string()), + ("TABLE_SCHEM", pa.string()), + ("TABLE_NAME", pa.string()), + ("EXTRA_1", pa.string()), + ("EXTRA_2", pa.string()), + ("TABLE_TYPE", pa.string()), + ] + ) + table = pa.table( + { + "TABLE_CAT": ["main", "main", "main"], + "TABLE_SCHEM": ["s", "s", "s"], + "TABLE_NAME": ["t1", "t2", "v1"], + "EXTRA_1": ["", "", ""], + "EXTRA_2": ["", "", ""], + "TABLE_TYPE": ["TABLE", "TABLE", "VIEW"], + }, + schema=schema, + ) + batches = table.to_batches() + stream = MagicMock() + stream.arrow_schema.return_value = schema + # First call returns the batch; second returns None (exhausted). 
+ stream.fetch_next_batch.side_effect = batches + [None] + return stream + + +def test_get_tables_with_table_types_filters_rows(): + c = _make_client() + c._kernel_session = MagicMock() + c._kernel_session.metadata.return_value.list_tables.return_value = ( + _make_tables_stream() + ) + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + rs = c.get_tables( + session_id=MagicMock(), + max_rows=10, + max_bytes=1, + cursor=cursor, + table_types=["TABLE"], + ) + table = rs.fetchall_arrow() + assert table.num_rows == 2 + assert set(table.column("TABLE_TYPE").to_pylist()) == {"TABLE"} + + +def test_get_tables_without_table_types_returns_full_stream(): + """No filter → kernel result flows through unchanged via the + normal ``KernelResultSet`` path (no drain-and-rewrap).""" + c = _make_client() + c._kernel_session = MagicMock() + c._kernel_session.metadata.return_value.list_tables.return_value = ( + _make_tables_stream() + ) + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + rs = c.get_tables( + session_id=MagicMock(), + max_rows=10, + max_bytes=1, + cursor=cursor, + table_types=None, + ) + table = rs.fetchall_arrow() + assert table.num_rows == 3 From 089e27156a7f9a41362d6f12092a7853326fb230 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 06:00:09 +0000 Subject: [PATCH 13/16] refactor(backend/kernel): replace kernel_call context manager with explicit try/except MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Using a ``with`` block for error translation is bad form: ``with`` conventionally signals resource lifecycle (locks, files), so ``with kernel_call("X"):`` hides the fact that the block raises a mapped exception. Replace with explicit ``try/except Exception as exc: raise _wrap_kernel_exception("X", exc) from exc`` at every PyO3 call site. 
What changed: - ``_errors.py``: drop the ``kernel_call`` context manager; export ``wrap_kernel_exception(what, exc)`` — a pure function that maps a raw exception to a PEP 249 one (KernelError → mapped class via ``reraise_kernel_error``; existing Error → passthrough; anything else → OperationalError). - ``client.py``: replace 12 ``with _kernel_call(...):`` blocks with inline try/except calling the helper. - ``result_set.py``: same for the 3 sites (arrow_schema on construct, fetch_next_batch in _pull_one_batch, fetch_next_batch in _drain). Behaviour is unchanged — same KernelError → PEP 249 mapping, same non-KernelError → OperationalError wrapping. Just spelled in a way that makes control flow visible at the call site and keeps tracebacks one frame shorter (no ``__exit__`` frame). 87/87 kernel unit tests still pass. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- src/databricks/sql/backend/kernel/_errors.py | 63 ++++++++++--------- src/databricks/sql/backend/kernel/client.py | 56 ++++++++++++----- .../sql/backend/kernel/result_set.py | 14 +++-- 3 files changed, 83 insertions(+), 50 deletions(-) diff --git a/src/databricks/sql/backend/kernel/_errors.py b/src/databricks/sql/backend/kernel/_errors.py index 7a0a0e783..3988298e3 100644 --- a/src/databricks/sql/backend/kernel/_errors.py +++ b/src/databricks/sql/backend/kernel/_errors.py @@ -20,15 +20,24 @@ These primitives live in their own module so both ``client.py`` (which orchestrates PyO3 calls) and ``result_set.py`` (which calls ``fetch_next_batch`` on the same kernel handles) can share them -without ``result_set.py`` importing from ``client.py`` — that -direction would be a layering violation. +without ``result_set.py`` importing from ``client.py``. + +Usage at every PyO3 call site is a plain try/except: + + try: + stmt.execute() + except Exception as exc: + raise wrap_kernel_exception("execute_command", exc) from exc + +The helper returns the mapped exception; callers raise it. 
Plain +``try/except`` is preferred over a context manager: the control +flow is visible at the call site, the helper is a pure function +(trivial to test), and tracebacks don't carry an extra +``__exit__`` frame. """ from __future__ import annotations -import contextlib -from typing import Iterator - from databricks.sql.exc import ( DatabaseError, Error, @@ -93,32 +102,24 @@ def reraise_kernel_error(exc: "_kernel.KernelError") -> "Error": return new -@contextlib.contextmanager -def kernel_call(what: str) -> Iterator[None]: - """Context manager that wraps a span of PyO3 calls so any error - crossing the Python/Rust boundary surfaces as a PEP 249 - exception. +def wrap_kernel_exception(what: str, exc: BaseException) -> "Error": + """Map any exception from a PyO3 call site to a PEP 249 exception. - ``KernelError`` flows through ``reraise_kernel_error`` (the - structured-attribute path). Anything else is wrapped in - ``OperationalError`` so DB-API callers see a uniform exception - surface and never have to catch native Python exceptions to - handle a connector-level failure. + - ``KernelError`` → mapped class with structured attrs forwarded. + - Already-PEP-249 ``Error`` (e.g. raised by an inner caller that + already mapped) → passed through unchanged. + - Anything else (``TypeError`` / ``ValueError`` / etc. from PyO3 + argument conversion, extension-internal errors) → wrapped in + ``OperationalError``. - ``what`` is a short tag used only in the ``OperationalError`` - message for the non-``KernelError`` path; keep it caller-named - (e.g. ``"execute_command"``). + Returned, not raised — the caller decides whether to ``raise`` + or ``raise ... from exc``. ``what`` is a short tag (the calling + method name) used only in the ``OperationalError`` message. """ - try: - yield - except _kernel.KernelError as exc: - raise reraise_kernel_error(exc) from exc - except Error: - # Already a PEP 249 error (e.g. 
a nested ``kernel_call`` or - # the cursor-side guard re-raising one); let it propagate - # unchanged. - raise - except Exception as exc: - raise OperationalError( - f"Unexpected error from databricks_sql_kernel during {what}: {exc!r}" - ) from exc + if isinstance(exc, _kernel.KernelError): + return reraise_kernel_error(exc) + if isinstance(exc, Error): + return exc + return OperationalError( + f"Unexpected error from databricks_sql_kernel during {what}: {exc!r}" + ) diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index c7259aaba..8e7ca54da 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -41,8 +41,8 @@ from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.kernel._errors import ( _kernel, - kernel_call as _kernel_call, reraise_kernel_error as _reraise_kernel_error, + wrap_kernel_exception as _wrap_kernel_exception, ) from databricks.sql.backend.kernel.auth_bridge import kernel_auth_kwargs from databricks.sql.backend.kernel.result_set import KernelResultSet @@ -137,7 +137,7 @@ def open_session( if session_configuration: session_conf = {k: str(v) for k, v in session_configuration.items()} try: - with _kernel_call("open_session"): + try: self._kernel_session = _kernel.Session( host=self._server_hostname, http_path=self._http_path, @@ -146,6 +146,8 @@ def open_session( session_conf=session_conf, **self._auth_kwargs, ) + except Exception as exc: + raise _wrap_kernel_exception("open_session", exc) from exc finally: # Drop the raw access token from the instance once the # kernel session is constructed (or failed). The kernel @@ -222,10 +224,12 @@ def execute_command( "Statement-level query_tags are not yet supported on the kernel backend." 
) - with _kernel_call("execute_command"): + try: stmt = self._kernel_session.statement() + except Exception as exc: + raise _wrap_kernel_exception("execute_command", exc) from exc try: - with _kernel_call("execute_command"): + try: stmt.set_sql(operation) if async_op: async_exec = stmt.submit() @@ -235,6 +239,8 @@ def execute_command( self._async_handles[command_id.guid] = async_exec return None executed = stmt.execute() + except Exception as exc: + raise _wrap_kernel_exception("execute_command", exc) from exc finally: # ``Statement`` is a lifecycle owner separate from the # executed handle it produces. Drop it here so the @@ -253,8 +259,10 @@ def execute_command( # can itself raise ``KernelError`` (or, in principle, a PyO3 # native exception) — wrap the construction so callers see a # mapped PEP 249 exception. - with _kernel_call("execute_command"): + try: return self._make_result_set(executed, cursor, command_id) + except Exception as exc: + raise _wrap_kernel_exception("execute_command", exc) from exc def cancel_command(self, command_id: CommandId) -> None: with self._async_handles_lock: @@ -266,8 +274,10 @@ def cancel_command(self, command_id: CommandId) -> None: # Match the Thrift backend's tolerant behaviour. 
logger.debug("cancel_command: no in-flight async handle for %s", command_id) return - with _kernel_call("cancel_command"): + try: handle.cancel() + except Exception as exc: + raise _wrap_kernel_exception("cancel_command", exc) from exc def close_command(self, command_id: CommandId) -> None: with self._async_handles_lock: @@ -279,8 +289,10 @@ def close_command(self, command_id: CommandId) -> None: if handle is None: logger.debug("close_command: no tracked handle for %s", command_id) return - with _kernel_call("close_command"): + try: handle.close() + except Exception as exc: + raise _wrap_kernel_exception("close_command", exc) from exc def get_query_state(self, command_id: CommandId) -> CommandState: with self._async_handles_lock: @@ -296,8 +308,10 @@ def get_query_state(self, command_id: CommandId) -> CommandState: # ran sync and the result was materialised before # returning. Terminal by construction. return CommandState.SUCCEEDED - with _kernel_call("get_query_state"): + try: state, failure = handle.status() + except Exception as exc: + raise _wrap_kernel_exception("get_query_state", exc) from exc if state == "Failed" and failure is not None: # Surface server-reported failure as a database error so # the cursor's polling loop terminates with the right @@ -318,8 +332,10 @@ def get_execution_result( "get_execution_result called for an unknown command_id; " "the kernel backend only tracks async-submitted statements." ) - with _kernel_call("get_execution_result"): + try: stream = async_exec.await_result() + except Exception as exc: + raise _wrap_kernel_exception("get_execution_result", exc) from exc # The async-exec handle's role ends once it has produced the # ``ResultStream`` — keeping it around (and tracked in # ``_async_handles``) would leak the server-side @@ -338,9 +354,11 @@ def get_execution_result( exc, ) # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which - # can raise — keep that in the mapped-exception scope. 
- with _kernel_call("get_execution_result"): + # can raise — map that to PEP 249 too. + try: return self._make_result_set(stream, cursor, command_id) + except Exception as exc: + raise _wrap_kernel_exception("get_execution_result", exc) from exc # ── Metadata ─────────────────────────────────────────────────── @@ -382,9 +400,11 @@ def get_catalogs( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_catalogs requires an open session.") - with _kernel_call("get_catalogs"): + try: stream = self._kernel_session.metadata().list_catalogs() return self._make_result_set(stream, cursor, self._synthetic_command_id()) + except Exception as exc: + raise _wrap_kernel_exception("get_catalogs", exc) from exc def get_schemas( self, @@ -397,12 +417,14 @@ def get_schemas( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_schemas requires an open session.") - with _kernel_call("get_schemas"): + try: stream = self._kernel_session.metadata().list_schemas( catalog=catalog_name, schema_pattern=schema_name, ) return self._make_result_set(stream, cursor, self._synthetic_command_id()) + except Exception as exc: + raise _wrap_kernel_exception("get_schemas", exc) from exc def get_tables( self, @@ -417,7 +439,7 @@ def get_tables( ) -> "ResultSet": if self._kernel_session is None: raise InterfaceError("get_tables requires an open session.") - with _kernel_call("get_tables"): + try: stream = self._kernel_session.metadata().list_tables( catalog=catalog_name, schema_pattern=schema_name, @@ -448,6 +470,8 @@ def get_tables( cursor, self._synthetic_command_id(), ) + except Exception as exc: + raise _wrap_kernel_exception("get_tables", exc) from exc def get_columns( self, @@ -469,7 +493,7 @@ def get_columns( raise ProgrammingError( "get_columns requires catalog_name on the kernel backend." 
) - with _kernel_call("get_columns"): + try: stream = self._kernel_session.metadata().list_columns( catalog=catalog_name, schema_pattern=schema_name, @@ -477,6 +501,8 @@ def get_columns( column_pattern=column_name, ) return self._make_result_set(stream, cursor, self._synthetic_command_id()) + except Exception as exc: + raise _wrap_kernel_exception("get_columns", exc) from exc # ── Misc ─────────────────────────────────────────────────────── diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index 64e10c3b5..c2b721a62 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -36,7 +36,7 @@ import pyarrow -from databricks.sql.backend.kernel._errors import kernel_call +from databricks.sql.backend.kernel._errors import wrap_kernel_exception from databricks.sql.backend.kernel.type_mapping import description_from_arrow_schema from databricks.sql.backend.types import CommandId, CommandState from databricks.sql.result_set import ResultSet @@ -68,8 +68,10 @@ def __init__( arraysize: int, buffer_size_bytes: int, ): - with kernel_call("KernelResultSet.arrow_schema"): + try: schema = kernel_handle.arrow_schema() + except Exception as exc: + raise wrap_kernel_exception("KernelResultSet.arrow_schema", exc) from exc super().__init__( connection=connection, backend=backend, @@ -107,8 +109,10 @@ def _pull_one_batch(self) -> bool: is exhausted.""" if self._exhausted: return False - with kernel_call("fetch_next_batch"): + try: batch = self._kernel_handle.fetch_next_batch() + except Exception as exc: + raise wrap_kernel_exception("fetch_next_batch", exc) from exc if batch is None: self._exhausted = True self.has_more_rows = False @@ -165,8 +169,10 @@ def _drain(self) -> pyarrow.Table: chunks.append(self._buffer.popleft()) if not self._exhausted: while True: - with kernel_call("fetch_next_batch"): + try: batch = self._kernel_handle.fetch_next_batch() + except 
Exception as exc: + raise wrap_kernel_exception("fetch_next_batch", exc) from exc if batch is None: self._exhausted = True self.has_more_rows = False From 14357c50bc1183c73c44c9a14b0151f647b627fd Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 06:05:01 +0000 Subject: [PATCH 14/16] style(backend/kernel): black format client.py Fixes ``check-linting`` CI: one long line in ``execute_command``'s async-submit branch needed to wrap (the ``CommandId.from_sea_statement_id`` call). Pure formatting; no behaviour change. Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- src/databricks/sql/backend/kernel/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 8e7ca54da..59d8f45d6 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -233,7 +233,9 @@ def execute_command( stmt.set_sql(operation) if async_op: async_exec = stmt.submit() - command_id = CommandId.from_sea_statement_id(async_exec.statement_id) + command_id = CommandId.from_sea_statement_id( + async_exec.statement_id + ) cursor.active_command_id = command_id with self._async_handles_lock: self._async_handles[command_id.guid] = async_exec From 0e1a250b6bde8aaf2b1b47876132560842aac899 Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 09:09:53 +0000 Subject: [PATCH 15/16] fix(backend/kernel): address gopalldb's P1 review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Local kernel-package fixes from the follow-up review pass on PR #787 (https://github.com/databricks/databricks-sql-python/pull/787#issuecomment-4475165482). 
The two cross-cutting / pre-existing issues (P0 #1 async leak in shared ``Cursor.close()``, P1 #8/#9 fetch-after-close not raising ``InterfaceError``) are tracked separately as #791 and #792 — they affect Thrift and SEA equally and are out of scope for this PR. P1 #2 — call ``backend.close_command`` from ``KernelResultSet.close()``: The override previously bypassed the base ``ResultSet.close()`` entirely. Honour the contract by invoking ``backend.close_command(self.command_id)`` after the per-handle close + ``_async_handles`` pop. Our own ``close_command`` is tolerant of already-popped guids (no-op), so this is safe even though the per-handle close above already released server-side state. Doesn't go through ``super().close()`` directly because the base path warns when ``self.results`` is ``None`` (which it is for kernel result sets) — replicate the meaningful part of the base contract without the noisy warning. P1 #3 — case-insensitive ``Bearer`` prefix in auth_bridge: RFC 6750 §2.1 says the Authorization scheme is case-insensitive. Match leniently in case a federation proxy or future provider normalises the casing differently — failing closed would surface as a confusing ``ProgrammingError`` from the bridge. P1 #4 — drop redundant ``__cause__`` set in ``reraise_kernel_error``: ``raise wrap_kernel_exception(...) from exc`` already sets ``__cause__`` at the call site; the manual assignment in ``reraise_kernel_error`` was redundant. Updated the test that asserted on it; added ``test_kernel_error_chains_through_wrap`` to cover the end-to-end chain. P1 #5 — ``get_tables`` filter looks up TABLE_TYPE by name: Replaced ``schema.field(5).name`` (positional) with the literal ``"TABLE_TYPE"`` plus a missing-column guard. A future kernel reshape of ``SHOW TABLES`` now surfaces an explicit ``OperationalError`` instead of silently filtering the wrong column. The case-sensitive contract is now documented in the surrounding comment (matches SEA + warehouse). 
P1 #6 — ``KernelResultSet.close()`` guards on ``connection.open``: ``__del__``-driven close arriving after the parent connection is already closed previously issued a kernel call into a disposed session. Skip the kernel call entirely in that case; still mark the result set ``CLOSED`` locally so ``__del__`` is idempotent. P1 #7 — defer ``kernel_auth_kwargs`` to ``open_session``: ``KernelDatabricksClient.__init__`` previously called ``kernel_auth_kwargs(auth_provider)`` and stored the bearer token on ``self._auth_kwargs`` indefinitely. If ``open_session`` never ran (test paths, error paths, lazy retries) the token stayed resident on the connector object. Build the kwargs locally inside ``open_session`` now — local variable, GC-eligible the moment ``open_session`` returns. Also tightened the install-hint comment in ``pyproject.toml`` to match the rest of the codebase (the wheel isn't on PyPI; only the ``maturin develop`` path is supported today). 88/88 kernel unit tests pass (added 1). Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- pyproject.toml | 6 +- src/databricks/sql/backend/kernel/_errors.py | 8 ++- .../sql/backend/kernel/auth_bridge.py | 11 ++- src/databricks/sql/backend/kernel/client.py | 67 ++++++++++++------- .../sql/backend/kernel/result_set.py | 37 ++++++++-- tests/unit/test_kernel_client.py | 24 ++++++- 6 files changed, 115 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6868919d2..c81747996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,10 +46,10 @@ pyarrow = ["pyarrow"] # [tool.poetry.extras] # kernel = ["databricks-sql-kernel"] # -# Until then, install the kernel separately: -# pip install databricks-sql-kernel -# or (local dev): +# Until then, the wheel is not on PyPI and the only supported +# install path is local dev: # cd databricks-sql-kernel/pyo3 && maturin develop --release +# (into the same venv as databricks-sql-connector). 
[tool.poetry.group.dev.dependencies] pytest = "^7.1.2" diff --git a/src/databricks/sql/backend/kernel/_errors.py b/src/databricks/sql/backend/kernel/_errors.py index 3988298e3..a844ff716 100644 --- a/src/databricks/sql/backend/kernel/_errors.py +++ b/src/databricks/sql/backend/kernel/_errors.py @@ -84,7 +84,12 @@ def reraise_kernel_error(exc: "_kernel.KernelError") -> "Error": """Convert a ``databricks_sql_kernel.KernelError`` to a PEP 249 exception with the kernel's structured attributes forwarded onto - the new instance.""" + the new instance. + + The returned exception is raised by callers with ``raise ... from + exc``; the ``from`` clause is what sets ``__cause__``, so we don't + touch it here. + """ code = getattr(exc, "code", "Unknown") cls = _CODE_TO_EXCEPTION.get(code, DatabaseError) new = cls(getattr(exc, "message", str(exc))) @@ -98,7 +103,6 @@ def reraise_kernel_error(exc: "_kernel.KernelError") -> "Error": "query_id", ): setattr(new, attr, getattr(exc, attr, None)) - new.__cause__ = exc return new diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py index f382284d2..827545b0a 100644 --- a/src/databricks/sql/backend/kernel/auth_bridge.py +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -29,7 +29,12 @@ logger = logging.getLogger(__name__) -_BEARER_PREFIX = "Bearer " +# RFC 6750 §2.1 defines the Authorization scheme as case-insensitive. +# The connector's auth providers all emit ``Bearer `` exactly today, +# but we match leniently in case a federation proxy or future provider +# normalises the casing differently — failing closed here would surface +# as a confusing ``ProgrammingError`` from the bridge. +_BEARER_PREFIX_LEN = len("Bearer ") # Defense-in-depth: reject tokens containing ASCII control characters. 
# A token with embedded CR/LF/NUL would let a misbehaving HTTP stack @@ -75,9 +80,9 @@ def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: auth = headers.get("Authorization") if not auth: return None - if not auth.startswith(_BEARER_PREFIX): + if not auth[:_BEARER_PREFIX_LEN].lower() == "bearer ": return None - token = auth[len(_BEARER_PREFIX) :] + token = auth[_BEARER_PREFIX_LEN:] if _CONTROL_CHAR_RE.search(token): raise ProgrammingError( "Bearer token contains ASCII control characters; refusing to " diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index 59d8f45d6..f1093a070 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -55,6 +55,7 @@ from databricks.sql.exc import ( InterfaceError, NotSupportedError, + OperationalError, ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -100,7 +101,15 @@ def __init__( self._auth_provider = auth_provider self._catalog = catalog self._schema = schema - self._auth_kwargs = kernel_auth_kwargs(auth_provider) + # NB: don't call ``kernel_auth_kwargs`` here. That call + # materialises the bearer token in-process; keeping a + # cleartext copy on a long-lived connector object that may + # never have ``open_session`` invoked (test paths, error + # paths, lazy retries) widens the window where a debugger + # dump or accidental pickle could capture the credential. + # Resolved inside ``open_session`` instead, then immediately + # cleared once the kernel ``Session`` owns it. 
+ # # Open ``databricks_sql_kernel.Session`` lazily in # ``open_session`` so the Session lifecycle gates the # underlying connection setup — same shape as Thrift's @@ -136,25 +145,28 @@ def open_session( session_conf: Optional[Dict[str, str]] = None if session_configuration: session_conf = {k: str(v) for k, v in session_configuration.items()} + # Build auth kwargs here (not in ``__init__``) so the bearer + # token has the shortest possible in-process lifetime: a + # local kwargs dict is GC-eligible the moment this method + # returns, regardless of whether the kernel ``Session()`` + # call succeeded or raised. + auth_kwargs = kernel_auth_kwargs(self._auth_provider) try: - try: - self._kernel_session = _kernel.Session( - host=self._server_hostname, - http_path=self._http_path, - catalog=catalog or self._catalog, - schema=schema or self._schema, - session_conf=session_conf, - **self._auth_kwargs, - ) - except Exception as exc: - raise _wrap_kernel_exception("open_session", exc) from exc + self._kernel_session = _kernel.Session( + host=self._server_hostname, + http_path=self._http_path, + catalog=catalog or self._catalog, + schema=schema or self._schema, + session_conf=session_conf, + **auth_kwargs, + ) + except Exception as exc: + raise _wrap_kernel_exception("open_session", exc) from exc finally: - # Drop the raw access token from the instance once the - # kernel session is constructed (or failed). The kernel - # owns the credential from this point on; keeping a - # cleartext copy on a long-lived connector object risks - # accidental capture by pickling / debuggers / telemetry. - self._auth_kwargs.pop("access_token", None) + # Best-effort scrub of the local dict before it goes out + # of scope. The kernel ``Session`` (if construction + # succeeded) now owns its own copy of the credential. + auth_kwargs.pop("access_token", None) # Use the kernel's real server-issued session id, not a # synthetic UUID. Matches what the native SEA backend does. 
@@ -319,7 +331,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState: # the cursor's polling loop terminates with the right # exception class — matches the Thrift backend's # behaviour on TOperationState::ERROR_STATE. - raise _reraise_kernel_error(failure) + raise _reraise_kernel_error(failure) from failure return _STATE_TO_COMMAND_STATE.get(state, CommandState.FAILED) def get_execution_result( @@ -455,15 +467,24 @@ def get_tables( # The kernel today returns the unfiltered ``SHOW TABLES`` # shape regardless of ``table_types``. Drain to a single # Arrow table and apply the same client-side filter the - # native SEA backend uses (column index 5 is - # ``TABLE_TYPE``, case-sensitive). Cheap because metadata - # result sets are small. + # native SEA backend uses. The filter is **case-sensitive** + # — matches the SEA backend's documented behaviour, and + # mirrors how the warehouse reports the values + # (``TABLE`` / ``VIEW`` / ``SYSTEM_TABLE`` — uppercase). + # Look the column up by name rather than positional index + # so a future kernel reshape of ``SHOW TABLES`` doesn't + # silently filter the wrong column. 
from databricks.sql.backend.sea.utils.filters import ResultSetFilter full_table = _drain_kernel_handle(stream) + if "TABLE_TYPE" not in full_table.schema.names: + raise OperationalError( + "kernel get_tables result is missing a TABLE_TYPE " + f"column; got {full_table.schema.names!r}" + ) filtered_table = ResultSetFilter._filter_arrow_table( full_table, - column_name=full_table.schema.field(5).name, + column_name="TABLE_TYPE", allowed_values=table_types, case_sensitive=True, ) diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index c2b721a62..7e978c3bf 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -223,12 +223,27 @@ def fetchall(self) -> List[Row]: return self._convert_arrow_table(self._drain()) def close(self) -> None: - """Close the underlying kernel handle. Idempotent — the - kernel's own ``close()`` is idempotent, and we guard against - repeated calls so partially-drained streams don't double- - decrement reference counts.""" + """Close the underlying kernel handle and notify the backend. + + Idempotent — the kernel's own ``close()`` is idempotent, and + we guard against repeated calls so partially-drained streams + don't double-decrement reference counts. + + Skipped entirely when the parent connection is already + closed. A ``__del__``-driven close arriving after + connection-close would otherwise issue a kernel call into an + already-disposed session. + """ if self._kernel_handle is None: return + if not self.connection.open: + self._kernel_handle = None + self._buffer.clear() + self._buffered_count = 0 + self._exhausted = True + self.has_been_closed_server_side = True + self.status = CommandState.CLOSED + return try: self._kernel_handle.close() except Exception as exc: @@ -242,11 +257,23 @@ def close(self) -> None: # otherwise leave a stale entry pointing at a closed handle. 
# No-op for the sync-execute and metadata paths, which never # register in ``_async_handles``. + backend = cast("KernelDatabricksClient", self.backend) guid = getattr(self.command_id, "guid", None) if guid is not None: - backend = cast("KernelDatabricksClient", self.backend) with backend._async_handles_lock: backend._async_handles.pop(guid, None) + # Honor the base ``ResultSet`` contract: notify the backend + # so any cross-cutting bookkeeping (telemetry, command-state + # tracking) sees the close. Our own ``close_command`` is + # tolerant of unknown command_ids (no-op), so this is safe + # even though the per-handle close above already released + # server-side state. + try: + backend.close_command(self.command_id) + except Exception as exc: + logger.warning( + "backend.close_command from result-set close failed: %s", exc + ) self._buffer.clear() self._buffered_count = 0 self._kernel_handle = None diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index cdce33147..99d867ca4 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -99,12 +99,13 @@ def __init__( ) def test_code_to_exception_mapping(code, expected_cls): """Every entry in ``_CODE_TO_EXCEPTION`` maps to the documented - PEP 249 class.""" + PEP 249 class. Cause chaining happens at the ``raise ... from exc`` + call site, not inside ``_reraise_kernel_error`` — verified + separately by ``test_kernel_error_chains_through_wrap``.""" err = _FakeKernelError(code=code, message=f"{code} boom") out = kernel_client._reraise_kernel_error(err) assert isinstance(out, expected_cls) assert "boom" in str(out) - assert out.__cause__ is err def test_unknown_code_falls_back_to_database_error(): @@ -131,6 +132,25 @@ def test_reraise_forwards_structured_attributes(): assert out.retryable is False +def test_kernel_error_chains_through_wrap(): + """``raise wrap_kernel_exception(...) 
from exc`` is the call-site + pattern; ``__cause__`` must be set to the original ``KernelError`` + so users can dig out the structured fields via ``e.__cause__``.""" + src = _FakeKernelError(code="SqlError", message="boom", sql_state="42P01") + try: + try: + raise src + except Exception as exc: + from databricks.sql.backend.kernel._errors import wrap_kernel_exception + + raise wrap_kernel_exception("test_site", exc) from exc + except DatabaseError as out: + assert out.__cause__ is src + assert getattr(out, "sql_state", None) == "42P01" + else: + raise AssertionError("expected DatabaseError") + + # --------------------------------------------------------------------------- # State mapping # --------------------------------------------------------------------------- From a05781f671dac7c7b94980b51864bb65bd2648eb Mon Sep 17 00:00:00 2001 From: Vikrant Puppala Date: Mon, 18 May 2026 09:57:59 +0000 Subject: [PATCH 16/16] fix(backend/kernel): address gopalldb's follow-up P1/P2 review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Review comment: https://github.com/databricks/databricks-sql-python/pull/787#issuecomment-4476419134 P1.1 — get_query_state handles non-BaseException failure: Previously ``raise _reraise_kernel_error(failure) from failure`` would explode with ``TypeError: exception causes must derive from BaseException`` if the kernel's ``status()`` ever returned a ``failure`` that wasn't a real ``KernelError`` (struct, dict — kernel API drift). Now route through ``_wrap_kernel_exception``, which isinstance-checks and falls through to ``OperationalError`` for non-PEP-249-shaped values. New unit test ``test_get_query_state_handles_non_baseexception_failure``. P1.2 — regression tests for prior fixes: - ``test_bearer_prefix_is_case_insensitive`` (P1 #3 from the earlier review): parametrised over "Bearer "/"bearer "/ "BEARER "/"BeArEr ". RFC 6750 §2.1 compliance was added but not covered by a test. 
- ``test_close_skips_kernel_call_when_connection_already_closed`` (P1 #6 from the earlier review): exercises the ``connection.open is False`` branch in ``KernelResultSet.close()`` — asserts that neither the kernel handle's ``close`` nor ``backend.close_command`` fires, but the result set still ends in ``CLOSED`` so ``__del__`` is idempotent. - ``test_token_with_control_chars_or_whitespace_rejected`` (pre-existing security guard): parametrised over NUL / CR / LF / DEL / space / tab — the old regex already caught the control characters (tab included) but missed space (0x20); the guard is now extended and covered by this test. P2.1 — wrap session_id extraction in open_session: ``SessionId.from_sea_session_id(self._kernel_session.session_id)`` was outside the ``try/except _wrap_kernel_exception`` scope. A raw PyO3 attribute-conversion error on the ``self._kernel_session.session_id`` access could escape unwrapped. Now wrapped. P2.2 — drop redundant _async_handles.pop in result-set close: After the M1 fix (``get_execution_result`` pops the guid before constructing the result set), the pop in ``KernelResultSet.close`` is dead code — every call misses. Sync-execute and metadata paths never registered in ``_async_handles`` to begin with. Dropped the per-close pop and rewrote the surrounding comment so ``backend.close_command`` is now the single bookkeeping seam. P2.3 — control-char regex includes whitespace: Extended ``[\x00-\x1f\x7f]`` → ``[\x00-\x20\x7f]`` and renamed ``_CONTROL_CHAR_RE`` → ``_TOKEN_REJECT_RE``. RFC 6750 forbids whitespace within the credential token itself; a header like ``"Bearer  doubled-space-token"`` (two spaces after the scheme, which leaves a leading space in the extracted token) previously slipped past the injection guard. Covered by the parametrised test described above. type_mapping reuse — use SqlType constants: Replaced literal type strings ("bigint", "string", …) in ``_arrow_type_to_dbapi_string`` with the ``SqlType`` constants from ``databricks.sql.backend.sea.utils.conversion`` — same single source of truth the SEA backend already uses, so the kernel and SEA backends emit byte-identical type-code strings.
The Arrow → SqlType lookup itself stays local to the kernel (SEA receives type-text from the server and normalises it; the kernel receives Arrow schemas directly), but the names are now shared. 100/100 kernel unit tests pass (added 12). Co-authored-by: Isaac Signed-off-by: Vikrant Puppala --- .../sql/backend/kernel/auth_bridge.py | 20 ++++---- src/databricks/sql/backend/kernel/client.py | 22 +++++++-- .../sql/backend/kernel/result_set.py | 24 +++------ .../sql/backend/kernel/type_mapping.py | 49 ++++++++++++------- tests/unit/test_kernel_auth_bridge.py | 44 +++++++++++++++++ tests/unit/test_kernel_client.py | 20 ++++++++ tests/unit/test_kernel_result_set.py | 26 ++++++++++ 7 files changed, 159 insertions(+), 46 deletions(-) diff --git a/src/databricks/sql/backend/kernel/auth_bridge.py b/src/databricks/sql/backend/kernel/auth_bridge.py index 827545b0a..a9acc2655 100644 --- a/src/databricks/sql/backend/kernel/auth_bridge.py +++ b/src/databricks/sql/backend/kernel/auth_bridge.py @@ -36,12 +36,14 @@ # as a confusing ``ProgrammingError`` from the bridge. _BEARER_PREFIX_LEN = len("Bearer ") -# Defense-in-depth: reject tokens containing ASCII control characters. -# A token with embedded CR/LF/NUL would let a misbehaving HTTP stack -# split or terminate the Authorization header line, opening a header- -# injection sink. Real PATs and federation-exchanged tokens never -# contain these. -_CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]") +# Defense-in-depth: reject tokens containing ASCII control characters +# or whitespace. CR/LF/NUL in a token would let a misbehaving HTTP +# stack split or terminate the Authorization header line, opening a +# header-injection sink. Space (0x20) is included so leading-/ +# embedded-whitespace tokens (e.g. ``"Bearer doubled-space-token"``, +# tab-prefixed token) get rejected too — RFC 6750 §2.1 forbids +# whitespace within the credential token itself. 
+_TOKEN_REJECT_RE = re.compile(r"[\x00-\x20\x7f]") def _is_pat(auth_provider: AuthProvider) -> bool: @@ -83,10 +85,10 @@ def _extract_bearer_token(auth_provider: AuthProvider) -> Optional[str]: if not auth[:_BEARER_PREFIX_LEN].lower() == "bearer ": return None token = auth[_BEARER_PREFIX_LEN:] - if _CONTROL_CHAR_RE.search(token): + if _TOKEN_REJECT_RE.search(token): raise ProgrammingError( - "Bearer token contains ASCII control characters; refusing to " - "forward it to the kernel auth bridge." + "Bearer token contains ASCII control characters or whitespace; " + "refusing to forward it to the kernel auth bridge." ) return token diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index f1093a070..fba814fc3 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -170,8 +170,13 @@ def open_session( # Use the kernel's real server-issued session id, not a # synthetic UUID. Matches what the native SEA backend does. - # Bind to a local first so mypy sees a non-Optional return. - session_id = SessionId.from_sea_session_id(self._kernel_session.session_id) + # ``session_id`` is a PyO3 attribute access — also wrapped so + # any conversion error surfaces as a mapped PEP 249 exception + # instead of bubbling raw from the boundary. + try: + session_id = SessionId.from_sea_session_id(self._kernel_session.session_id) + except Exception as exc: + raise _wrap_kernel_exception("open_session", exc) from exc self._session_id = session_id logger.info("Opened kernel-backed session %s", session_id) return session_id @@ -330,8 +335,17 @@ def get_query_state(self, command_id: CommandId) -> CommandState: # Surface server-reported failure as a database error so # the cursor's polling loop terminates with the right # exception class — matches the Thrift backend's - # behaviour on TOperationState::ERROR_STATE. 
- raise _reraise_kernel_error(failure) from failure + # behaviour on TOperationState::ERROR_STATE. Routed + # through ``_wrap_kernel_exception`` rather than + # ``_reraise_kernel_error`` directly so a non- + # ``KernelError``-shaped ``failure`` (kernel API drift — + # struct, dict, etc.) still produces a mapped PEP 249 + # exception instead of a confusing + # ``TypeError: exception causes must derive from + # BaseException`` from the ``from`` clause. + if isinstance(failure, BaseException): + raise _wrap_kernel_exception("get_query_state", failure) from failure + raise _wrap_kernel_exception("get_query_state", Exception(repr(failure))) return _STATE_TO_COMMAND_STATE.get(state, CommandState.FAILED) def get_execution_result( diff --git a/src/databricks/sql/backend/kernel/result_set.py b/src/databricks/sql/backend/kernel/result_set.py index 7e978c3bf..ed98984c8 100644 --- a/src/databricks/sql/backend/kernel/result_set.py +++ b/src/databricks/sql/backend/kernel/result_set.py @@ -251,23 +251,15 @@ def close(self) -> None: # level; log and swallow so the cursor's __del__ / # connection close path stays clean. logger.warning("Error closing kernel handle: %s", exc) - # Drop the entry from the backend's async-handle map (if - # present) — for async-submitted statements the handle is - # tracked there and the base ``ResultSet.close`` path would - # otherwise leave a stale entry pointing at a closed handle. - # No-op for the sync-execute and metadata paths, which never - # register in ``_async_handles``. + # Honor the base ``ResultSet`` contract: notify the backend. + # ``backend.close_command`` also drops the ``_async_handles`` + # entry and records the guid in ``_closed_commands`` — no + # separate pop needed here. Sync-execute and metadata paths + # never registered in ``_async_handles`` to begin with, and + # ``get_execution_result`` pops the async path before the + # result set is even constructed (see the M1 fix), so this + # call is the single bookkeeping seam. 
backend = cast("KernelDatabricksClient", self.backend) - guid = getattr(self.command_id, "guid", None) - if guid is not None: - with backend._async_handles_lock: - backend._async_handles.pop(guid, None) - # Honor the base ``ResultSet`` contract: notify the backend - # so any cross-cutting bookkeeping (telemetry, command-state - # tracking) sees the close. Our own ``close_command`` is - # tolerant of unknown command_ids (no-op), so this is safe - # even though the per-handle close above already released - # server-side state. try: backend.close_command(self.command_id) except Exception as exc: diff --git a/src/databricks/sql/backend/kernel/type_mapping.py b/src/databricks/sql/backend/kernel/type_mapping.py index 83e55ed55..bedcdcebd 100644 --- a/src/databricks/sql/backend/kernel/type_mapping.py +++ b/src/databricks/sql/backend/kernel/type_mapping.py @@ -6,6 +6,13 @@ flattens the conversion so ``KernelResultSet`` and any future kernel-result wrapper share the same mapping. +The string constants come from ``SqlType`` in the SEA backend's +``conversion`` module — same single source of truth both backends +already use. The Arrow → ``SqlType`` lookup itself is kernel- +specific (SEA receives type-text from the server and normalises it; +the kernel receives Arrow schemas directly), so the mapping +function stays local but the names are shared. + Parameter binding (``TSparkParameter`` → kernel ``TypedValue``) is not yet implemented — the PyO3 ``Statement`` doesn't expose a ``bind_param`` method on this branch. It'll land in a follow-up @@ -18,42 +25,50 @@ import pyarrow +from databricks.sql.backend.sea.utils.conversion import SqlType + def _arrow_type_to_dbapi_string(arrow_type: pyarrow.DataType) -> str: """Map a pyarrow type to the Databricks SQL type name used in - PEP 249 ``description``. Names match what the Thrift backend - produces so consumers can branch on them identically. + PEP 249 ``description``. 
Names come from ``SqlType`` so the + kernel and SEA backends emit identical type-code strings; + consumers can branch on them identically. """ if pyarrow.types.is_boolean(arrow_type): - return "boolean" + return SqlType.BOOLEAN if pyarrow.types.is_int8(arrow_type): - return "tinyint" + return SqlType.TINYINT if pyarrow.types.is_int16(arrow_type): - return "smallint" + return SqlType.SMALLINT if pyarrow.types.is_int32(arrow_type): - return "int" + return SqlType.INT if pyarrow.types.is_int64(arrow_type): - return "bigint" + return SqlType.BIGINT if pyarrow.types.is_float32(arrow_type): - return "float" + return SqlType.FLOAT if pyarrow.types.is_float64(arrow_type): - return "double" + return SqlType.DOUBLE if pyarrow.types.is_decimal(arrow_type): - return "decimal" + return SqlType.DECIMAL if pyarrow.types.is_string(arrow_type) or pyarrow.types.is_large_string(arrow_type): - return "string" + return SqlType.STRING if pyarrow.types.is_binary(arrow_type) or pyarrow.types.is_large_binary(arrow_type): - return "binary" + return SqlType.BINARY if pyarrow.types.is_date(arrow_type): - return "date" + return SqlType.DATE if pyarrow.types.is_timestamp(arrow_type): - return "timestamp" + return SqlType.TIMESTAMP if pyarrow.types.is_list(arrow_type) or pyarrow.types.is_large_list(arrow_type): - return "array" + return SqlType.ARRAY if pyarrow.types.is_struct(arrow_type): - return "struct" + return SqlType.STRUCT if pyarrow.types.is_map(arrow_type): - return "map" + return SqlType.MAP + # Fallback for types the kernel hasn't been observed to emit yet + # (time32/time64, unsigned ints, dictionary, string_view, + # binary_view, fixed_size_*). ``str(arrow_type)`` produces shapes + # like ``"fixed_size_binary[16]"`` — distinguishable from the + # canonical slugs above, so callers can detect the unknown. 
return str(arrow_type) diff --git a/tests/unit/test_kernel_auth_bridge.py b/tests/unit/test_kernel_auth_bridge.py index dfad26ede..c6bda4f87 100644 --- a/tests/unit/test_kernel_auth_bridge.py +++ b/tests/unit/test_kernel_auth_bridge.py @@ -73,6 +73,50 @@ def test_pat_routes_to_kernel_pat(self): kwargs = kernel_auth_kwargs(AccessTokenAuthProvider("dapi-xyz")) assert kwargs == {"auth_type": "pat", "access_token": "dapi-xyz"} + @pytest.mark.parametrize( + "scheme", + ["Bearer ", "bearer ", "BEARER ", "BeArEr "], + ids=["title", "lower", "upper", "mixed"], + ) + def test_bearer_prefix_is_case_insensitive(self, scheme): + """RFC 6750 §2.1: the Authorization scheme is case-insensitive. + A provider that emits ``bearer`` (lower) or ``BEARER`` (upper) + must route through PAT, not fall through to a confusing + ``ProgrammingError`` from the missing-header check.""" + + class _CustomCaseProvider(AccessTokenAuthProvider): + def add_headers(self, request_headers): + request_headers["Authorization"] = f"{scheme}dapi-xyz" + + kwargs = kernel_auth_kwargs(_CustomCaseProvider("dapi-xyz")) + assert kwargs == {"auth_type": "pat", "access_token": "dapi-xyz"} + + @pytest.mark.parametrize( + "bad_token", + [ + "dapi\x00null", # NUL + "dapi\rfoo", # CR + "dapi\nfoo", # LF + "dapi\x7fdel", # DEL + "dapi has space", # space inside token + "dapi\tfoo", # tab + ], + ids=["nul", "cr", "lf", "del", "space", "tab"], + ) + def test_token_with_control_chars_or_whitespace_rejected(self, bad_token): + """Defense-in-depth: a Bearer token containing CR/LF/NUL would + let a misbehaving HTTP stack split or terminate the + Authorization header line. Space/tab are also rejected + because RFC 6750 forbids whitespace inside the credential + token. 
Surface as ``ProgrammingError`` at bridge-build time.""" + + class _BadTokenProvider(AccessTokenAuthProvider): + def add_headers(self, request_headers): + request_headers["Authorization"] = f"Bearer {bad_token}" + + with pytest.raises(ProgrammingError, match="control characters or whitespace"): + kernel_auth_kwargs(_BadTokenProvider("ignored")) + def test_federation_wrapped_pat_routes_to_kernel_pat(self): """``get_python_sql_connector_auth_provider`` always wraps the base provider in a ``TokenFederationProvider``, so the diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index 99d867ca4..f43d8c7c7 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -368,6 +368,26 @@ def test_get_query_state_raises_on_failed_state_with_failure(): c.get_query_state(cid) +def test_get_query_state_handles_non_baseexception_failure(): + """If the kernel's status() ever returns a ``failure`` that isn't + a real ``KernelError`` (struct, dict, custom type — kernel API + drift), ``get_query_state`` must still surface a mapped PEP 249 + exception. The naive ``raise ... from failure`` would raise + ``TypeError: exception causes must derive from BaseException``; + the wrap helper deals with it.""" + c = _make_client() + fake_handle = MagicMock() + # ``failure`` is a plain dict (not BaseException) — simulates a + # kernel binding that exposes the failure as a structured value. + fake_handle.status.return_value = ("Failed", {"code": "Internal", "msg": "weird"}) + cid = CommandId.from_sea_statement_id("xyz") + c._async_handles[cid.guid] = fake_handle + # Must surface as a PEP 249 exception (OperationalError via the + # wrap helper's fallback path), not TypeError. 
+ with pytest.raises(OperationalError): + c.get_query_state(cid) + + def test_get_query_state_returns_state_when_no_failure(): c = _make_client() fake_handle = MagicMock() diff --git a/tests/unit/test_kernel_result_set.py b/tests/unit/test_kernel_result_set.py index 2078441c4..87fa4f0d2 100644 --- a/tests/unit/test_kernel_result_set.py +++ b/tests/unit/test_kernel_result_set.py @@ -168,3 +168,29 @@ def test_close_swallows_handle_close_failures(int_schema): rs = _make_rs(handle) rs.close() # must not raise assert rs.status == CommandState.CLOSED + + +def test_close_skips_kernel_call_when_connection_already_closed(int_schema): + """``__del__``-driven close arriving after the parent connection + is already closed must not issue any kernel call (the kernel + session is disposed). The result set still marks itself + ``CLOSED`` locally so the close path stays idempotent.""" + handle = _FakeKernelHandle(int_schema, []) + connection = MagicMock() + connection.open = False + backend = MagicMock() + rs = KernelResultSet( + connection=connection, + backend=backend, + kernel_handle=handle, + command_id=CommandId.from_sea_statement_id("conn-closed-test"), + arraysize=100, + buffer_size_bytes=1024, + ) + rs.close() + # No kernel-side calls fired: + assert handle.closed is False + assert backend.close_command.called is False + # …but local state is still terminal so __del__ is idempotent: + assert rs.status == CommandState.CLOSED + assert rs.has_been_closed_server_side is True