diff --git a/pyproject.toml b/pyproject.toml index 78f7f2b..a2c8123 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "pandas", "types-requests>=2.31.0", "pandas-stubs>=2.0.3.230814", + "sqlparse>=0.5.0", ] [project.optional-dependencies] diff --git a/tests/test_redaction.py b/tests/test_redaction.py new file mode 100644 index 0000000..a553372 --- /dev/null +++ b/tests/test_redaction.py @@ -0,0 +1,151 @@ +"""Tests for SQL literal redaction used in query logging.""" + +import pytest + +from wherobots.db.redaction import get_statement_type, redact_sql + + +def test_redacts_single_string_literal() -> None: + assert redact_sql("SELECT * FROM t WHERE name = 'alice'") == ( + "SELECT * FROM t WHERE name = ?" + ) + + +def test_redacts_string_with_escaped_quote() -> None: + # Doubled single-quote is an escaped quote inside the literal; the whole + # literal must collapse to one placeholder (no leakage of the second half). + assert redact_sql("SELECT * FROM t WHERE name = 'O''Brien'") == ( + "SELECT * FROM t WHERE name = ?" + ) + + +def test_redacts_string_with_backslash_escaped_quote() -> None: + assert redact_sql(r"SELECT * FROM t WHERE name = 'a\'b'") == ( + "SELECT * FROM t WHERE name = ?" + ) + + +def test_redacts_numeric_literals() -> None: + assert redact_sql("SELECT * FROM t WHERE age = 42 AND score > 3.14") == ( + "SELECT * FROM t WHERE age = ? AND score > ?" + ) + + +def test_redacts_scientific_notation() -> None: + assert redact_sql("SELECT * FROM t WHERE x < 1.5e10") == ( + "SELECT * FROM t WHERE x < ?" + ) + + +def test_double_quoted_identifier_left_intact() -> None: + # Double-quoted identifiers are column/table names, not value literals. + assert redact_sql('SELECT "user id", count(*) FROM "my table" WHERE x = 1') == ( + 'SELECT "user id", count(*) FROM "my table" WHERE x = ?' + ) + + +def test_backtick_identifier_left_intact() -> None: + assert redact_sql("SELECT `col` FROM `db`.`tbl` WHERE n = 5") == ( + "SELECT `col` FROM `db`.`tbl` WHERE n = ?" + ) + + +def test_show_tblproperties_statement() -> None: + # SHOW TBLPROPERTIES with a quoted property key: the single-quoted literal + # still redacts, while the table name (an identifier) stays intact. + assert redact_sql("SHOW TBLPROPERTIES my_db.my_table ('comment')") == ( + "SHOW TBLPROPERTIES my_db.my_table (?)" + ) + + +def test_select_where_secret_redacted() -> None: + statement = "SELECT id FROM users WHERE ssn = '123-45-6789'" + redacted = redact_sql(statement) + assert "123-45-6789" not in redacted + assert redacted == "SELECT id FROM users WHERE ssn = ?" + + +def test_identifier_with_digits_not_redacted() -> None: + # Column names containing digits must not be treated as numeric literals. + assert redact_sql("SELECT col1, t2.col3 FROM tbl4") == ( + "SELECT col1, t2.col3 FROM tbl4" + ) + + +def test_unterminated_string_does_not_leak() -> None: + redacted = redact_sql("SELECT * FROM t WHERE x = 'unterminated secret") + assert "secret" not in redacted + assert redacted == "SELECT * FROM t WHERE x = ?" + + +def test_unterminated_quote_does_not_leak_later_statements() -> None: + # Multi-statement input where the FIRST statement contains an unterminated + # quote (a lone ``"`` opener, which sqlparse emits as an Error token) and a + # LATER statement contains a secret literal. Redaction must fail closed and + # bail on the entire input, never appending a later statement verbatim -- + # otherwise the secret would leak into the logs. (sqlparse splits this into + # two statements, so this exercises the cross-statement path specifically.) + statement = "SELECT \" FROM t; SELECT v FROM creds WHERE token = 'topsecret123'" + redacted = redact_sql(statement) + assert "topsecret123" not in redacted + assert redacted == "SELECT ?" + + +@pytest.mark.parametrize( + ("statement", "expected", "secret"), + [ + # Hex/blob literal: sqlparse classifies the quoted body as String.Single, + # so the value collapses to ? (the bare X prefix is a Name and is kept). + ( + "SELECT * FROM t WHERE b = X'deadbeef'", + "SELECT * FROM t WHERE b = X?", + "deadbeef", + ), + # Unicode string literal (U&'...'): the quoted body is String.Single. + ( + r"SELECT * FROM t WHERE s = U&'\0041'", + "SELECT * FROM t WHERE s = U&?", + "0041", + ), + # National-character / symbol-prefixed literal (N'...'): String.Single. + ("SELECT * FROM t WHERE s = N'abc'", "SELECT * FROM t WHERE s = N?", "abc"), + ], +) +def test_prefixed_string_literals_redacted( + statement: str, expected: str, secret: str +) -> None: + # Hex (X'..'), unicode (U&'..') and symbol-prefixed (N'..') string literals + # all tokenize their value-bearing quoted body as String.Single, so the + # existing guard already redacts them -- no broadening to ``T.String`` is + # needed (and broadening would be harmful: double-quoted identifiers are + # String.Symbol, see test_double_quoted_identifier_left_intact). + redacted = redact_sql(statement) + assert secret not in redacted + assert redacted == expected + + +def test_multiple_literals_mixed() -> None: + statement = "INSERT INTO t (a, b) VALUES ('x', 10), ('y', 20)" + assert redact_sql(statement) == "INSERT INTO t (a, b) VALUES (?, ?), (?, ?)" + + +def test_line_comment_preserved_literal_in_comment_kept() -> None: + # We only redact value positions; a literal inside a comment is left as-is + # because comments are preserved verbatim (structure-preserving). + assert redact_sql("SELECT 1 -- note: keep this\n") == ( + "SELECT ? -- note: keep this\n" + ) + + +@pytest.mark.parametrize( + ("statement", "expected"), + [ + ("SELECT 1", "SELECT"), + (" show tables", "SHOW"), + ("describe t", "DESCRIBE"), + ("/* c */ SELECT 1", "UNKNOWN"), + ("", "UNKNOWN"), + ], +) +def test_get_statement_type(statement: str, expected: str) -> None: + assert get_statement_type(statement) == expected diff --git a/uv.lock b/uv.lock index 1b9602a..ad2e74b 100644 --- a/uv.lock +++ b/uv.lock @@ -942,6 +942,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "sqlparse" +version = "0.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/76/437d71068094df0726366574cf3432a4ed754217b436eb7429415cf2d480/sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e", size = 120815, upload-time = "2025-12-19T07:17:45.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/4b/359f28a903c13438ef59ebeee215fb25da53066db67b305c125f1c6d2a25/sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba", size = 46138, upload-time = "2025-12-19T07:17:46.573Z" }, +] + [[package]] name = "strenum" version = "0.4.15" @@ -1157,6 +1166,7 @@ dependencies = [ { name = "pandas-stubs" }, { name = "pyarrow" }, { name = "requests" }, + { name = "sqlparse" }, { name = "strenum" }, { name = "tenacity" }, { name = "types-requests" }, @@ -1188,6 +1198,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=14.0.2" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.2" }, { name = "requests", specifier = ">=2.31.0" }, + { name = "sqlparse", specifier = ">=0.5.0" }, { name = "strenum", specifier = ">=0.4.15,<0.5" }, { name = "tenacity", specifier = ">=8.2.3" }, { name = "types-requests", specifier = ">=2.31.0" }, diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 96666ca..cdb420b 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -6,6 +6,8 @@ from dataclasses import dataclass from typing import Any, Callable, Dict +from wherobots.db.redaction import get_statement_type, redact_sql + import pandas import pyarrow import cbor2 @@ -260,9 +262,26 @@ def _handle_results(self, execution_id: str, results: Dict[str, Any]) -> Any: def __send(self, message: Dict[str, Any]) -> None: request = json.dumps(message) - logging.debug("Request: %s", request) + # Only compute the redacted request (json.dumps + sqlparse parse) when + # DEBUG is actually enabled; the log argument is evaluated eagerly, so an + # unguarded call would redact on every request even with DEBUG off. + if logging.getLogger().isEnabledFor(logging.DEBUG): + logging.debug("Request: %s", self.__redacted_request(message)) self.__ws.send(request) + @staticmethod + def __redacted_request(message: Dict[str, Any]) -> str: + """Serialize a request for logging with any SQL statement redacted. + + The wire payload (sent verbatim by ``__send``) carries the raw + ``statement``; this driver is embedded by other services, so logging it + -- even at DEBUG -- would leak raw SQL into their log streams (WBC-139). + """ + statement = message.get("statement") + if isinstance(statement, str): + message = {**message, "statement": redact_sql(statement)} + return json.dumps(message) + def __recv(self) -> Dict[str, Any]: frame = self.__ws.recv(timeout=self.__read_timeout) if isinstance(frame, str): @@ -301,8 +320,13 @@ def __execute_sql( store=store, ) + # Redact literal values before logging: this driver is embedded by other + # services, so raw SQL here would leak into their log streams (WBC-139). logging.info( - "Executing SQL query %s: %s", execution_id, textwrap.shorten(sql, width=60) + "Executing SQL query %s (%s): %s", + execution_id, + get_statement_type(sql), + textwrap.shorten(redact_sql(sql), width=200), ) self.__send(request) return execution_id diff --git a/wherobots/db/redaction.py b/wherobots/db/redaction.py new file mode 100644 index 0000000..65996c8 --- /dev/null +++ b/wherobots/db/redaction.py @@ -0,0 +1,78 @@ +"""Redaction helpers for SQL statements logged by this driver. + +The driver logs SQL statements for observability (e.g. ``Executing SQL query +: ...``). Raw statements can embed literal PII (for example ``WHERE ssn = +'123-45-6789'``), and because this library is embedded by other services its log +output ends up in their log streams. This module replaces literal *values* +(single-quoted string literals and numeric literals) with a ``?`` placeholder +while preserving the statement structure: keywords, function names, and +identifiers (including double-quoted/back-quoted identifiers) are kept intact so +the redacted form is still useful for debugging and aggregation. + +The implementation tokenizes with ``sqlparse``, a lenient, pure-Python tokenizer +with zero transitive dependencies. It is dialect-agnostic, never raises on +malformed input, and classifies literals by token type, which lets us redact +precisely the value-bearing tokens while copying everything else (keywords, +identifiers, operators, comments, whitespace) through verbatim. + +This logic is duplicated from the ``sql-session`` service (PR #197); the two +repositories do not share a package, so the implementation is intentionally +replicated here rather than imported. + +This is a best-effort redaction for *logging*. It is not a security boundary and +must not be relied on to sanitize untrusted input for any other purpose. +""" + +import re + +import sqlparse +from sqlparse import tokens as T + +REDACTED_PLACEHOLDER = "?" + +# Leading SQL keyword -> statement type. Used purely for observability tagging. +# Intentionally a regex (not sqlparse's ``Statement.get_type()``) because +# ``get_type()`` returns "UNKNOWN" for SHOW/DESCRIBE/SET, which would regress the +# observability tagging this module exists to support. +_STATEMENT_TYPE_RE = re.compile(r"\s*([A-Za-z]+)") + + +def get_statement_type(statement: str) -> str: + """Return the upper-cased leading keyword of a statement (e.g. ``SELECT``). + + Returns ``"UNKNOWN"`` when the statement does not begin with a word. + """ + match = _STATEMENT_TYPE_RE.match(statement) + if match is None: + return "UNKNOWN" + return match.group(1).upper() + + +def redact_sql(statement: str) -> str: + """Return ``statement`` with string and numeric literals replaced by ``?``. + + Single-quoted string literals (``String.Single``) and numeric literals + (``Number.*``) are each collapsed to a single ``?`` placeholder. Everything + else -- keywords, function names, plain and quoted identifiers (``"col"`` / + `` `col` ``), comments, whitespace, and operators -- is preserved verbatim. + """ + out: list[str] = [] + for stmt in sqlparse.parse(statement): + # sqlparse ships no type stubs, so ``flatten`` is untyped under strict. + for token in stmt.flatten(): # type: ignore[no-untyped-call] + ttype = token.ttype + if ttype in T.String.Single or ttype in T.Number: + out.append(REDACTED_PLACEHOLDER) + elif ttype in T.Error and token.value in ("'", '"', "`"): + # An unterminated quote: sqlparse emits the lone opener as an + # Error token and tokenizes the trailing characters as ordinary + # text. Redact the opener and fail closed by bailing on the + # entire input -- returning here (rather than ``break``, which + # only exits the inner loop) ensures that for multi-statement + # input no later statement is appended verbatim, which would + # leak the literals this function exists to hide. + out.append(REDACTED_PLACEHOLDER) + return "".join(out) + else: + out.append(token.value) + return "".join(out)