SNOW-2000181: Using COPY INTO with pyarrow.timestamp("ns", ...) results in tables with corrupted data in timestamp columns #2229

@pacificsky-hs

Description

Python version

Python 3.10.16 (main, Feb 28 2025, 23:03:01) [GCC 13.3.0]

Operating system and processor architecture

Linux-6.8.0-1024-aws-x86_64-with-glibc2.39

Installed packages

Package                    Version
-------------------------- ---------
asn1crypto                 1.5.1
certifi                    2025.1.31
cffi                       1.17.1
charset-normalizer         3.4.1
cryptography               44.0.2
filelock                   3.18.0
idna                       3.10
numpy                      2.2.4
packaging                  24.2
pip                        23.0.1
platformdirs               4.3.7
pyarrow                    16.1.0
pyarrow-hotfix             0.6
pycparser                  2.22
PyJWT                      2.10.1
pyOpenSSL                  24.3.0
pytz                       2025.1
requests                   2.32.3
setuptools                 65.5.0
snowflake-connector-python 3.13.2
sortedcontainers           2.4.0
tomlkit                    0.13.2
typing_extensions          4.12.2
urllib3                    2.3.0

What did you do?

# Replace the placeholder values below with appropriate values for your environment
# Run like so: uv run script.py

# /// script
# dependencies = [
#   "snowflake-connector-python==3.13.2",
#   "pyarrow==16.1.0",
# ]
# ///

from __future__ import annotations

import logging
import sys
from datetime import datetime, timezone
from os import getenv
from typing import Any

import pyarrow as pa
import snowflake.connector as sc
from pyarrow import parquet
from pyarrow.fs import S3FileSystem

# Script to test whether pyarrow timestamp columns are handled correctly when inserting
# data into Snowflake with the Snowflake Connector for Python, via a PyArrow parquet
# file upload to S3 followed by a bulk COPY INTO load


def initialize_logging(root_log_level: int = logging.INFO) -> None:
    """
    Initialize logging system with well-formatted stdout logger.
    """
    root = logging.getLogger()
    root.setLevel(root_log_level)

    # Override log level for the Snowflake Python connector, which is very verbose.
    logging.getLogger("snowflake").setLevel(logging.WARNING)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(root_log_level)
    formatter = logging.Formatter(
        "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s"
    )
    handler.setFormatter(formatter)
    # Overwrite any existing handlers. Our API servers will come up with handlers
    # already configured, which causes double-logs without this.
    root.handlers = [handler]


initialize_logging()
_LOGGER = logging.getLogger(__name__)

S3_BASE_DIR = "SOME_BUCKET_NAME/PATH/SUB_PATH/STAGE_ROOT"
# Ensure that stage is defined to point to STAGE_ROOT in the bucket path above
QUOTED_STAGE_NAME = "DATABASE.SCHEMA.STAGE_NAME"
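
# For reference, this script assumes the stage above was created to point at the
# STAGE_ROOT prefix inside S3_BASE_DIR. A hypothetical setup (names and storage
# integration are placeholders, adjust for your environment) would look like:
#
#   CREATE STAGE DATABASE.SCHEMA.STAGE_NAME
#     URL = 's3://SOME_BUCKET_NAME/PATH/SUB_PATH/STAGE_ROOT'
#     STORAGE_INTEGRATION = YOUR_STORAGE_INTEGRATION;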

_PYARROW_S3FS: S3FileSystem | None = None


def get_pyarrow_s3fs() -> S3FileSystem:
    """Returns the global singleton S3FileSystem."""
    global _PYARROW_S3FS
    if _PYARROW_S3FS is None:
        _PYARROW_S3FS = S3FileSystem()

    return _PYARROW_S3FS


def insert_from_pyarrow_table(
    file_name: str,
    arrow_table: pa.Table,
    table_name: str,
    column_names_and_types: list[tuple[str, str]],
    cursor: sc.SnowflakeCursor,
) -> None:
    full_file_path = f"{S3_BASE_DIR}/{file_name}"
    stage_file_path = f"STAGE_ROOT/{file_name}"

    s3 = get_pyarrow_s3fs()

    with s3.open_output_stream(full_file_path) as dest:
        parquet.write_table(arrow_table, dest)

    column_selects = [
        f"$1:{column_name.upper()}::{column_type} AS {column_name.upper()}"
        for column_name, column_type in column_names_and_types
    ]

    results = cursor.execute(
        f"""
        COPY INTO PUBLIC.{table_name.upper()}
        ({",".join([f"{col[0]}" for col in column_names_and_types])})
        FROM (
            SELECT
            {",".join(column_selects)}
            FROM @{QUOTED_STAGE_NAME}
        )
        FILES = (%(col1)s)
        ON_ERROR = ABORT_STATEMENT
        """,
        {"col1": stage_file_path},
    )

    status_row = results.fetchall()
    _LOGGER.info(f"Rows loaded: {status_row}")

    s3.delete_file(full_file_path)


def fetch_max_from_table(
    table_name: str,
    column_names_and_types: list[tuple[str, str]],
    cursor: sc.SnowflakeCursor,
) -> None:
    query = (
        f"SELECT MAX({column_names_and_types[1][0]}) FROM PUBLIC.{table_name.upper()}"
    )

    try:
        results = cursor.execute(query)
        data = results.fetchone()
        _LOGGER.info(f"Max value from {table_name}: {data[0]}")
    except sc.errors.ProgrammingError as e:
        _LOGGER.error(f"Error fetching max from {table_name}: {e}")
        _LOGGER.error(
            "This indicates the data was corrupted during the upload process."
        )
        _LOGGER.error("Verify this on app.snowflake.com by querying this table.")


def records_to_pyarrow_table(
    columns: list[list[Any]], fields: list[tuple[str, Any]]
) -> pa.Table:
    _fields = [(field.upper(), field_type) for field, field_type in fields]

    return pa.Table.from_arrays(columns, schema=pa.schema(_fields))


def datetime_now() -> datetime:
    return datetime.now(tz=timezone.utc)


def main() -> None:
    account = "SNOWFLAKE_ACCOUNT_NAME"
    user = "SNOWFLAKE_USER_NAME"
    password = getenv("SNOWFLAKE_PASSWORD")
    role = "SNOWFLAKE_ROLE_NAME"
    database = "SNOWFLAKE_DATABASE_NAME"
    warehouse = "SNOWFLAKE_WAREHOUSE_NAME"

    conn_params = {
        "account": account,
        "user": user,
        "password": password,
        "role": role,
        "database": database,
        "warehouse": warehouse,
    }

    ctx = sc.connect(**conn_params)
    cs = ctx.cursor()
    try:
        # Validate our connection with a simple query
        cs.execute("select current_version()")
        one = cs.fetchone()
        _LOGGER.info(one[0])

        cs.execute(
            "CREATE OR REPLACE TABLE public.timestamp_test_1 (base varchar, ts TIMESTAMP_TZ)"
        )
        cs.execute(
            "CREATE OR REPLACE TABLE public.timestamp_test_2 (base varchar, ts TIMESTAMP_TZ)"
        )
        cs.execute(
            "CREATE OR REPLACE TABLE public.timestamp_test_3 (base varchar, ts TIMESTAMP_TZ)"
        )
        cs.execute(
            "CREATE OR REPLACE TABLE public.timestamp_test_4 (base varchar, ts TIMESTAMP_TZ)"
        )

        dt1 = datetime_now()
        _LOGGER.info(f"dt1: {dt1}")
        dt2 = datetime_now()
        _LOGGER.info(f"dt2: {dt2}")
        dt3 = datetime_now()
        _LOGGER.info(f"dt3: {dt3}")

        table_snowflake_schema = [
            ("base", "VARCHAR"),
            ("ts", "TIMESTAMP_TZ"),
        ]

        # Insert data into timestamp_test_1 using pyarrow with ns precision
        # This produces corrupted data in the ts column in the Snowflake table, as demonstrated
        # by the subsequent SQL query.
        # Directly querying the column on app.snowflake.com shows "Invalid Date" for the
        # timestamp column.
        table1 = records_to_pyarrow_table(
            columns=[
                ["value1", "value2", "value3"],
                [dt1, dt2, dt3],
            ],
            fields=[
                ("base", pa.string()),
                ("ts", pa.timestamp("ns", tz="UTC")),
            ],
        )

        insert_from_pyarrow_table(
            file_name="timestamp_test_1.parquet",
            arrow_table=table1,
            table_name="timestamp_test_1",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Fetch max from timestamp_test_1 -- this fails
        fetch_max_from_table(
            table_name="timestamp_test_1",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Insert data into timestamp_test_2 using pyarrow with us precision
        # This works correctly, as the data is inserted with the correct precision
        table2 = records_to_pyarrow_table(
            columns=[
                ["value1", "value2", "value3"],
                [dt1, dt2, dt3],
            ],
            fields=[
                ("base", pa.string()),
                ("ts", pa.timestamp("us", tz="UTC")),
            ],
        )
        insert_from_pyarrow_table(
            file_name="timestamp_test_2.parquet",
            arrow_table=table2,
            table_name="timestamp_test_2",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Fetch max from timestamp_test_2
        fetch_max_from_table(
            table_name="timestamp_test_2",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Insert data into timestamp_test_3 using pyarrow with ms precision
        # This works correctly, as the data is inserted with the correct precision
        table3 = records_to_pyarrow_table(
            columns=[
                ["value1", "value2", "value3"],
                [dt1, dt2, dt3],
            ],
            fields=[
                ("base", pa.string()),
                ("ts", pa.timestamp("ms", tz="UTC")),
            ],
        )
        insert_from_pyarrow_table(
            file_name="timestamp_test_3.parquet",
            arrow_table=table3,
            table_name="timestamp_test_3",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Fetch max from timestamp_test_3
        fetch_max_from_table(
            table_name="timestamp_test_3",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Insert data into timestamp_test_4 using pyarrow with s precision
        # This works correctly, as the data is inserted with the correct precision
        table4 = records_to_pyarrow_table(
            columns=[
                ["value1", "value2", "value3"],
                [dt1, dt2, dt3],
            ],
            fields=[
                ("base", pa.string()),
                ("ts", pa.timestamp("s", tz="UTC")),
            ],
        )
        insert_from_pyarrow_table(
            file_name="timestamp_test_4.parquet",
            arrow_table=table4,
            table_name="timestamp_test_4",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

        # Fetch max from timestamp_test_4
        fetch_max_from_table(
            table_name="timestamp_test_4",
            column_names_and_types=table_snowflake_schema,
            cursor=cs,
        )

    finally:
        cs.close()
        ctx.close()


if __name__ == "__main__":
    main()

What did you expect to see?

One of the following:

  • Loading data into timestamp columns using pa.timestamp("ns") works correctly without data corruption, or
  • Loading data into timestamp columns using pa.timestamp("ns") fails with an error

Currently we're in the worst possible situation: the data load succeeds, but the loaded data is corrupt.

FWIW, this used to work with snowflake-connector-python==3.1.0 and pyarrow==10.0.1, so this is a regression. We discovered it in production.
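
A possible mitigation while this is open (a sketch under the assumption that only the nanosecond unit is mishandled; the helper below is mine, not part of the connector) is to down-cast timestamp("ns") columns to timestamp("us") before writing the parquet file:

import pyarrow as pa


def downcast_ns_timestamps(table: pa.Table) -> pa.Table:
    """Return `table` with any timestamp("ns") columns cast to timestamp("us")."""
    fields = [
        pa.field(f.name, pa.timestamp("us", tz=f.type.tz), nullable=f.nullable)
        if pa.types.is_timestamp(f.type) and f.type.unit == "ns"
        else f
        for f in table.schema
    ]
    # Safe cast; raises if sub-microsecond values would be truncated.
    return table.cast(pa.schema(fields))

In the repro above this would be table1 = downcast_ns_timestamps(table1) before calling insert_from_pyarrow_table. Alternatively, parquet.write_table(arrow_table, dest, coerce_timestamps="us", allow_truncated_timestamps=True) coerces the timestamps at write time instead.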

Can you set logging to DEBUG and collect the logs?
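
For reference, one way to capture those logs (plain stdlib logging; the handler and format here are just an example, and in the repro script above the WARNING override on the "snowflake" logger would also need to be dropped):

import logging
import sys

# Send DEBUG-level connector logs to stdout so they can be attached to the issue.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
logging.getLogger("snowflake.connector").setLevel(logging.DEBUG)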
