diff --git a/.github/workflows/check_installation.yml b/.github/workflows/check_installation.yml
new file mode 100644
index 0000000000..4c8d31b80d
--- /dev/null
+++ b/.github/workflows/check_installation.yml
@@ -0,0 +1,56 @@
+name: Test Installation
+
+on:
+  push:
+    branches:
+      - master
+      - main
+  pull_request:
+    branches:
+      - '**'
+  workflow_dispatch:
+
+concurrency:
+  # older builds for the same pull request number or branch should be cancelled
+  cancel-in-progress: true
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+
+jobs:
+  test-installation:
+    name: Test Boto Dependency
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.12"
+
+      - name: Test default installation (should include boto)
+        shell: bash
+        run: |
+          python -m venv test_default_env
+          source test_default_env/bin/activate
+
+          python -m pip install .
+          pip freeze | grep boto || exit 1  # boto3/botocore should be installed by default
+
+          # Deactivate and clean up
+          deactivate
+          rm -rf test_default_env
+
+      - name: Test installation with SNOWFLAKE_NO_BOTO=1 (should exclude boto)
+        shell: bash
+        run: |
+          python -m venv test_no_boto_env
+          source test_no_boto_env/bin/activate
+
+          SNOWFLAKE_NO_BOTO=1 python -m pip install .
+
+          # Check that boto3 and botocore are NOT installed; grep finding a
+          # match here means the packages leaked into the environment
+          if pip freeze | grep boto; then exit 1; fi
+
+          # Deactivate and clean up
+          deactivate
+          rm -rf test_no_boto_env
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ccf3ceeea6..b62b8b868c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -61,6 +61,13 @@ repos:
             src/snowflake/connector/vendored/.*
         )$
         args: [--show-fixes]
+    -   id: check-optional-imports
+        name: Check for direct imports of modules that might be unavailable
+        entry: python ci/pre-commit/check_optional_imports.py
+        language: system
+        files: ^src/snowflake/connector/.*\.py$
+        exclude: src/snowflake/connector/options.py
+        args: [--show-fixes]
 -   repo: https://github.com/PyCQA/flake8
     rev: 7.1.1
     hooks:
diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 3e02e72660..bb556de9b6 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -10,6 +10,15 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
 - v3.18.0(TBD)
   - Added the `workload_identity_impersonation_path` parameter to support service account impersonation for Workload Identity Federation on GCP and AWS workloads only
   - Fixed `get_results_from_sfqid` when using `DictCursor` and executing multiple statements at once
+  - Added the `oauth_credentials_in_body` parameter, which enables sending the OAuth client credentials in the request body
+  - Fixed retry behavior for the `ECONNRESET` error
+  - Added an option to exclude the `botocore` and `boto3` dependencies by setting the `SNOWFLAKE_NO_BOTO` environment variable during installation
+  - Added support for pandas conversion of Day-Time and Year-Month Interval types
+
+- v3.17.4(September 22, 2025)
+  - Added support for intermediate certificates as roots when they are stored in the trust store
+  - Bumped up vendored `urllib3` to `2.5.0` and `requests` to `2.32.5`
+  - Dropped support for OpenSSL versions older than 1.1.1
 - v3.17.3(September 02,2025)
   - Enhanced configuration file permission warning messages.
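For reviewers: the two workflow steps above assert on `pip freeze | grep boto`. The same invariant can be checked from Python — a minimal, hypothetical sketch (run inside the freshly created virtual environment; the truthy set deliberately mirrors `_POSITIVE_VALUES` in setup.py below):

import importlib.util
import os

# Same truthy values that setup.py accepts for SNOWFLAKE_NO_BOTO (assumption:
# this script runs inside the venv whose install we want to verify).
_POSITIVE_VALUES = ("y", "yes", "t", "true", "1", "on")
no_boto = os.environ.get("SNOWFLAKE_NO_BOTO", "false").lower() in _POSITIVE_VALUES

# importlib.util.find_spec returns None when a package is absent -- a
# programmatic stand-in for `pip freeze | grep boto`.
has_boto3 = importlib.util.find_spec("boto3") is not None
has_botocore = importlib.util.find_spec("botocore") is not None

if no_boto:
    assert not (has_boto3 or has_botocore), "boto should be excluded under SNOWFLAKE_NO_BOTO"
else:
    assert has_boto3 and has_botocore, "boto3/botocore should be installed by default"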
diff --git a/ci/pre-commit/check_optional_imports.py b/ci/pre-commit/check_optional_imports.py
new file mode 100644
index 0000000000..f2a35f0927
--- /dev/null
+++ b/ci/pre-commit/check_optional_imports.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python3
+"""
+Pre-commit hook to ensure optional dependencies are always imported from the .options module.
+This ensures that the connector can operate in environments where these optional libraries are not available.
+"""
+import argparse
+import ast
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List
+
+CHECKED_MODULES = ["boto3", "botocore", "pandas", "pyarrow", "keyring"]
+
+
+@dataclass(frozen=True)
+class ImportViolation:
+    """Pretty-prints a violation of the import restrictions."""
+
+    filename: str
+    line: int
+    col: int
+    message: str
+
+    def __str__(self):
+        return f"{self.filename}:{self.line}:{self.col}: {self.message}"
+
+
+class ImportChecker(ast.NodeVisitor):
+    """Checks that optional modules are only imported from the .options module."""
+
+    def __init__(self, filename: str):
+        self.filename = filename
+        self.violations: List[ImportViolation] = []
+
+    def visit_If(self, node: ast.If):
+        # Imports inside "if TYPE_CHECKING:" blocks are only evaluated by type
+        # checkers, so skip the body and orelse of those blocks entirely
+        if getattr(node.test, "id", None) == "TYPE_CHECKING":
+            pass
+        else:
+            self.generic_visit(node)
+
+    def visit_Import(self, node: ast.Import):
+        """Check import statements."""
+        for alias in node.names:
+            self._check_import(alias.name, node.lineno, node.col_offset)
+        self.generic_visit(node)
+
+    def visit_ImportFrom(self, node: ast.ImportFrom):
+        """Check from...import statements."""
+        if node.module:
+            # Check if importing from a checked module directly. The sanctioned
+            # form "from .options import boto3" arrives here with
+            # node.module == "options" and node.level == 1, so it never matches.
+            for module in CHECKED_MODULES:
+                if node.module.startswith(module):
+                    self.violations.append(
+                        ImportViolation(
+                            self.filename,
+                            node.lineno,
+                            node.col_offset,
+                            f"Import from '{node.module}' is not allowed. Use 'from .options import {module}' instead",
+                        )
+                    )
+        self.generic_visit(node)
+
+    def _check_import(self, module_name: str, line: int, col: int):
+        """Check if a module import is for checked modules and not from .options."""
+        for module in CHECKED_MODULES:
+            if module_name.startswith(module):
+                self.violations.append(
+                    ImportViolation(
+                        self.filename,
+                        line,
+                        col,
+                        f"Direct import of '{module_name}' is not allowed. Use 'from .options import {module}' instead",
+                    )
+                )
+                break
+
+
+def check_file(filename: str) -> List[ImportViolation]:
+    """Check a file for optional import violations."""
+    try:
+        tree = ast.parse(Path(filename).read_text())
+    except SyntaxError:
+        # gracefully handle syntax errors
+        return []
+    checker = ImportChecker(filename)
+    checker.visit(tree)
+    return checker.violations
+
+
+def main():
+    """Main function for the pre-commit hook."""
+    parser = argparse.ArgumentParser(
+        description="Check that optional modules are only imported from the .options module"
+    )
+    parser.add_argument("filenames", nargs="*", help="Filenames to check")
+    parser.add_argument(
+        "--show-fixes", action="store_true", help="Show suggested fixes"
+    )
+    args = parser.parse_args()
+
+    all_violations = []
+    for filename in args.filenames:
+        if not filename.endswith(".py"):
+            continue
+        all_violations.extend(check_file(filename))
+
+    # Show violations
+    if all_violations:
+        print("Optional import violations found:")
+        print()
+
+        for violation in all_violations:
+            print(f" {violation}")
+
+        if args.show_fixes:
+            print()
+            print("How to fix:")
+            print(" - Import optional modules only from the .options module")
+            print(" - Example:")
+            print("   # CORRECT:")
+            print("   from .options import boto3, botocore, installed_boto")
+            print("   if installed_boto:")
+            print("       SigV4Auth = botocore.auth.SigV4Auth")
+            print()
+            print("   # INCORRECT:")
+            print("   import boto3")
+            print("   from botocore.auth import SigV4Auth")
+            print()
+            print(
+                " - This ensures the connector works in environments where optional libraries are not installed"
+            )
+
+        print()
+        print(f"Found {len(all_violations)} violation(s)")
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/setup.cfg b/setup.cfg
index 69f2f2c55b..d2ceb1f921 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,9 +43,9 @@ project_urls =
 python_requires = >=3.9
 packages = find_namespace:
 install_requires =
+    # the [boto] extra is added by default unless the SNOWFLAKE_NO_BOTO variable is set;
+    # see setup.py
    asn1crypto>0.24.0,<2.0.0
-    boto3>=1.24
-    botocore>=1.24
    cffi>=1.9,<2.0.0
    cryptography>=3.1.0
    pyOpenSSL>=22.0.0,<25.0.0
@@ -79,6 +79,9 @@ console_scripts =
    snowflake-dump-certs = snowflake.connector.tool.dump_certs:main

[options.extras_require]
+boto =
+    boto3>=1.24
+    botocore>=1.24
 development =
    Cython
    coverage
diff --git a/setup.py b/setup.py
index 37e9a96fe2..0b7ab60f4d 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,7 @@ import warnings

 from setuptools import Extension, setup
+from setuptools.command.egg_info import egg_info

 CONNECTOR_SRC_DIR = os.path.join("src", "snowflake", "connector")
 NANOARROW_SRC_DIR = os.path.join(CONNECTOR_SRC_DIR, "nanoarrow_cpp", "ArrowIterator")
@@ -38,9 +39,14 @@ extensions = None
 cmd_class = {}

-SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS = os.environ.get(
-    "SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS", "false"
-).lower() in ("y", "yes", "t", "true", "1", "on")
+_POSITIVE_VALUES = ("y", "yes", "t", "true", "1", "on")
+SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS = (
+    os.environ.get("SNOWFLAKE_DISABLE_COMPILE_ARROW_EXTENSIONS", "false").lower()
+    in _POSITIVE_VALUES
+)
+SNOWFLAKE_NO_BOTO = (
+    os.environ.get("SNOWFLAKE_NO_BOTO", "false").lower() in _POSITIVE_VALUES
+)

 try:
     from Cython.Build import cythonize
@@ -88,7 +94,7 @@ def build_extension(self, ext):
             ext.sources += [
                 os.path.join(
                     NANOARROW_ARROW_ITERATOR_SRC_DIR,
-                    *((file,) if isinstance(file, str) else file)
+                    *((file,) if isinstance(file, str) else file),
                 )
                 for file in {
"ArrayConverter.cpp", @@ -174,6 +180,22 @@ def new__compile(obj, src: str, ext, cc_args, extra_postargs, pp_opts): cmd_class = {"build_ext": MyBuildExt} + +class SetDefaultInstallationExtras(egg_info): + """Adds AWS extra unless SNOWFLAKE_NO_BOTO is specified.""" + + def finalize_options(self): + super().finalize_options() + + # if not explicitly excluded, add boto dependencies to install_requires + if not SNOWFLAKE_NO_BOTO: + boto_extras = self.distribution.extras_require.get("boto", []) + self.distribution.install_requires += boto_extras + + +# Update command classes +cmd_class["egg_info"] = SetDefaultInstallationExtras + setup( version=version, ext_modules=extensions, diff --git a/src/snowflake/connector/aio/_wif_util.py b/src/snowflake/connector/aio/_wif_util.py index 553e8e6309..f794f87b0e 100644 --- a/src/snowflake/connector/aio/_wif_util.py +++ b/src/snowflake/connector/aio/_wif_util.py @@ -7,8 +7,7 @@ import aioboto3 from aiobotocore.utils import AioInstanceMetadataRegionFetcher -from botocore.auth import SigV4Auth -from botocore.awsrequest import AWSRequest +from snowflake.connector.options import botocore from ..errorcode import ER_WIF_CREDENTIALS_NOT_FOUND from ..errors import ProgrammingError @@ -57,7 +56,7 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation: region = await get_aws_region() partition = session.get_partition_for_region(region) sts_hostname = get_aws_sts_hostname(region, partition) - request = AWSRequest( + request = botocore.awsrequest.AWSRequest( method="POST", url=f"https://{sts_hostname}/?Action=GetCallerIdentity&Version=2011-06-15", headers={ @@ -66,7 +65,7 @@ async def create_aws_attestation() -> WorkloadIdentityAttestation: }, ) - SigV4Auth(aws_creds, "sts", region).add_auth(request) + botocore.auth.SigV4Auth(aws_creds, "sts", region).add_auth(request) assertion_dict = { "url": request.url, diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp index b853e4a9f7..dc0c169e74 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.cpp @@ -68,6 +68,7 @@ void CArrowTableIterator::convertIfNeeded(ArrowSchema* columnSchema, case SnowflakeType::Type::DATE: case SnowflakeType::Type::REAL: case SnowflakeType::Type::TEXT: + case SnowflakeType::Type::INTERVAL_YEAR_MONTH: case SnowflakeType::Type::VARIANT: case SnowflakeType::Type::VECTOR: { // Do not need to convert @@ -174,6 +175,24 @@ void CArrowTableIterator::convertIfNeeded(ArrowSchema* columnSchema, break; } + case SnowflakeType::Type::INTERVAL_DAY_TIME: { + int scale = 9; + if (metadata != nullptr) { + struct ArrowStringView scaleString = ArrowCharView(nullptr); + returnCode = ArrowMetadataGetValue(metadata, ArrowCharView("scale"), + &scaleString); + SF_CHECK_ARROW_RC(returnCode, + "[Snowflake Exception] error getting 'scale' " + "from Arrow metadata, error code: %d", + returnCode); + scale = + std::stoi(std::string(scaleString.data, scaleString.size_bytes)); + } + convertIntervalDayTimeColumn_nanoarrow(&columnSchemaView, columnArray, + scale); + break; + } + case SnowflakeType::Type::TIME: { int scale = 9; if (metadata != nullptr) { @@ -503,6 +522,76 @@ void CArrowTableIterator:: ArrowArrayMove(newArray, columnArray->array); } +void CArrowTableIterator::convertIntervalDayTimeColumn_nanoarrow( + ArrowSchemaView* field, ArrowArrayView* columnArray, const int scale) { + int 
returnCode = 0; + nanoarrow::UniqueSchema newUniqueField; + nanoarrow::UniqueArray newUniqueArray; + ArrowSchema* newSchema = newUniqueField.get(); + ArrowArray* newArray = newUniqueArray.get(); + ArrowError error; + + // create new schema + ArrowSchemaInit(newSchema); + newSchema->flags &= + (field->schema->flags & ARROW_FLAG_NULLABLE); // map to nullable() + + returnCode = ArrowSchemaSetTypeDateTime(newSchema, NANOARROW_TYPE_DURATION, + NANOARROW_TIME_UNIT_NANO, NULL); + SF_CHECK_ARROW_RC(returnCode, + "[Snowflake Exception] error setting arrow schema type " + "DateTime, error code: %d", + returnCode); + + returnCode = ArrowSchemaSetName(newSchema, field->schema->name); + SF_CHECK_ARROW_RC( + returnCode, + "[Snowflake Exception] error setting schema name, error code: %d", + returnCode); + + returnCode = ArrowArrayInitFromSchema(newArray, newSchema, &error); + SF_CHECK_ARROW_RC(returnCode, + "[Snowflake Exception] error initializing ArrowArrayView " + "from schema : %s, error code: %d", + ArrowErrorMessage(&error), returnCode); + + returnCode = ArrowArrayStartAppending(newArray); + SF_CHECK_ARROW_RC( + returnCode, + "[Snowflake Exception] error appending arrow array, error code: %d", + returnCode); + + for (int64_t rowIdx = 0; rowIdx < columnArray->array->length; rowIdx++) { + if (ArrowArrayViewIsNull(columnArray, rowIdx)) { + returnCode = ArrowArrayAppendNull(newArray, 1); + SF_CHECK_ARROW_RC(returnCode, + "[Snowflake Exception] error appending null to arrow " + "array, error code: %d", + returnCode); + } else { + ArrowDecimal arrowDecimal; + ArrowDecimalInit(&arrowDecimal, 128, 38, 0); + ArrowArrayViewGetDecimalUnsafe(columnArray, rowIdx, &arrowDecimal); + auto originalVal = ArrowDecimalGetIntUnsafe(&arrowDecimal); + returnCode = ArrowArrayAppendInt(newArray, originalVal); + SF_CHECK_ARROW_RC(returnCode, + "[Snowflake Exception] error appending int to arrow " + "array, error code: %d", + returnCode); + } + } + + returnCode = ArrowArrayFinishBuildingDefault(newArray, &error); + SF_CHECK_ARROW_RC(returnCode, + "[Snowflake Exception] error finishing building arrow " + "array: %s, error code: %d", + ArrowErrorMessage(&error), returnCode); + field->schema->release(field->schema); + ArrowSchemaMove(newSchema, field->schema); + columnArray->array->release(columnArray->array); + ArrowArrayMove(newArray, columnArray->array); +} + void CArrowTableIterator::convertTimeColumn_nanoarrow( ArrowSchemaView* field, ArrowArrayView* columnArray, const int scale) { int returnCode = 0; diff --git a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp index 7615ed264d..759e832108 100644 --- a/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp +++ b/src/snowflake/connector/nanoarrow_cpp/ArrowIterator/CArrowTableIterator.hpp @@ -94,6 +94,14 @@ class CArrowTableIterator : public CArrowIterator { ArrowArrayView* columnArray, const int scale); + /** + * convert Snowflake Interval Day-Time column (Arrow int64/decimal128) to + * Arrow Duration column + */ + void convertIntervalDayTimeColumn_nanoarrow(ArrowSchemaView* field, + ArrowArrayView* columnArray, + const int scale); + /** * convert Snowflake TimestampNTZ/TimestampLTZ column to Arrow Timestamp * column diff --git a/src/snowflake/connector/options.py b/src/snowflake/connector/options.py index 8454ab1699..2eeff7c449 100644 --- a/src/snowflake/connector/options.py +++ b/src/snowflake/connector/options.py @@ -48,6 +48,18 @@ class 
MissingKeyring(MissingOptionalDependency):
     _dep_name = "keyring"
 
 
+class MissingBotocore(MissingOptionalDependency):
+    """This class is specifically for the botocore optional dependency."""
+
+    _dep_name = "botocore"
+
+
+class MissingBoto3(MissingOptionalDependency):
+    """This class is specifically for the boto3 optional dependency."""
+
+    _dep_name = "boto3"
+
+
 ModuleLikeObject = Union[ModuleType, MissingOptionalDependency]
 
 
@@ -126,6 +138,17 @@ def _import_or_missing_keyring_option() -> tuple[ModuleLikeObject, bool]:
         return MissingKeyring(), False
 
 
+def _import_or_missing_boto_option() -> tuple[ModuleLikeObject, ModuleLikeObject, bool]:
+    """This function tries importing the following packages: botocore and boto3."""
+    try:
+        botocore = importlib.import_module("botocore")
+        boto3 = importlib.import_module("boto3")
+        return botocore, boto3, True
+    except ImportError:
+        return MissingBotocore(), MissingBoto3(), False
+
+
 # Create actual constants to be imported from this file
 pandas, pyarrow, installed_pandas = _import_or_missing_pandas_option()
 keyring, installed_keyring = _import_or_missing_keyring_option()
+botocore, boto3, installed_boto = _import_or_missing_boto_option()
diff --git a/src/snowflake/connector/platform_detection.py b/src/snowflake/connector/platform_detection.py
index ec615be24d..2c6f0eeff1 100644
--- a/src/snowflake/connector/platform_detection.py
+++ b/src/snowflake/connector/platform_detection.py
@@ -7,9 +7,11 @@ from enum import Enum
 from functools import cache
 
-import boto3
-from botocore.config import Config
-from botocore.utils import IMDSFetcher
+from .options import boto3, botocore, installed_boto
+
+if installed_boto:
+    Config = botocore.config.Config
+    IMDSFetcher = botocore.utils.IMDSFetcher
 
 from .session_manager import SessionManager
 from .vendored.requests import RequestException, Timeout
@@ -40,6 +42,10 @@ def is_ec2_instance(platform_detection_timeout_seconds: float):
     Returns:
         _DetectionState: DETECTED if running on EC2, NOT_DETECTED otherwise.
     """
+    if not installed_boto:
+        logger.debug("boto3 is not installed, skipping EC2 instance detection")
+        return _DetectionState.NOT_DETECTED
+
     try:
         fetcher = IMDSFetcher(
             timeout=platform_detection_timeout_seconds, num_attempts=1
@@ -105,6 +111,10 @@ def has_aws_identity(platform_detection_timeout_seconds: float):
     Returns:
         _DetectionState: DETECTED if valid AWS identity exists, NOT_DETECTED otherwise.
     """
+    if not installed_boto:
+        logger.debug("boto3 is not installed, skipping AWS identity detection")
+        return _DetectionState.NOT_DETECTED
+
     try:
         config = Config(
             connect_timeout=platform_detection_timeout_seconds,
diff --git a/src/snowflake/connector/wif_util.py b/src/snowflake/connector/wif_util.py
index 406ee12725..2d7302426f 100644
--- a/src/snowflake/connector/wif_util.py
+++ b/src/snowflake/connector/wif_util.py
@@ -7,14 +7,17 @@ from dataclasses import dataclass
 from enum import Enum, unique
 
-import boto3
 import jwt
-from botocore.auth import SigV4Auth
-from botocore.awsrequest import AWSRequest
-from botocore.utils import InstanceMetadataRegionFetcher
+
+from .options import boto3, botocore, installed_boto
+
+if installed_boto:
+    SigV4Auth = botocore.auth.SigV4Auth
+    AWSRequest = botocore.awsrequest.AWSRequest
+    InstanceMetadataRegionFetcher = botocore.utils.InstanceMetadataRegionFetcher
 
 from .errorcode import ER_INVALID_WIF_SETTINGS, ER_WIF_CREDENTIALS_NOT_FOUND
-from .errors import ProgrammingError
+from .errors import MissingDependencyError, ProgrammingError
 from .session_manager import SessionManager
 
 logger = logging.getLogger(__name__)
@@ -149,6 +152,12 @@ def create_aws_attestation(
 
     If the application isn't running on AWS or no credentials were found, raises an error.
     """
+    if not installed_boto:
+        raise MissingDependencyError(
+            msg="AWS Workload Identity Federation can't be used because the optional boto3 and botocore dependencies are not installed. Install them, for example via the [boto] extra.",
+            errno=ER_WIF_CREDENTIALS_NOT_FOUND,
+        )
+    # TODO: SNOW-2223669 Investigate whether our adapters - which carry the HTTP traffic settings - should be passed here as boto's urllib3 session. These requests go to local servers, so in theory they need neither proxy setup nor header customization, but we may want all traffic to go through one class (e.g. an Adapter or mixin).
session = boto3.session.Session() aws_creds = session.get_credentials() diff --git a/test/integ/aio_it/pandas_it/test_arrow_pandas_async.py b/test/integ/aio_it/pandas_it/test_arrow_pandas_async.py index 557cdc2907..25c7a746e0 100644 --- a/test/integ/aio_it/pandas_it/test_arrow_pandas_async.py +++ b/test/integ/aio_it/pandas_it/test_arrow_pandas_async.py @@ -662,6 +662,26 @@ async def test_vector(conn_cnx, is_public_test): await finish(conn, table) +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +async def test_interval_year_month(conn_cnx): + cases = ["1-2", "-1-3", "999999999-11", "-999999999-11"] + table = "test_arrow_year_month_interval" + values = "(" + "),(".join([f"'{c}'" for c in cases]) + ")" + async with conn_cnx() as conn: + cursor = conn.cursor() + await cursor.execute("alter session set feature_interval_types=enabled") + await cursor.execute( + f"create or replace table {table} (a interval year to month)" + ) + await cursor.execute(f"insert into {table} values {values}") + sql_text = f"select a from {table}" + await validate_pandas(conn, sql_text, cases, 1, "one", "interval_year_month") + await finish(conn, table) + + async def validate_pandas( cnx_table, sql, @@ -740,6 +760,23 @@ async def validate_pandas( c_case = Decimal(cases[i]) elif data_type == "date": c_case = datetime.strptime(cases[i], "%Y-%m-%d").date() + elif data_type == "interval_year_month": + year_month_list = cases[i].split("-") + if len(year_month_list) == 2: + c_case = int(year_month_list[0]) * 12 + int( + year_month_list[1] + ) + else: + # negative value + c_case = -( + int(year_month_list[1]) * 12 + int(year_month_list[2]) + ) + elif data_type == "interval_day_time": + timedelta_split_days = cases[i].split(" ") + pandas_timedelta_str = ( + timedelta_split_days[0] + " days " + timedelta_split_days[1] + ) + c_case = pandas.to_timedelta(pandas_timedelta_str) elif data_type == "time": time_str_len = 8 if scale == 0 else 9 + scale c_case = cases[i].strip()[:time_str_len] diff --git a/test/integ/pandas_it/test_arrow_pandas.py b/test/integ/pandas_it/test_arrow_pandas.py index bc954e7d6f..cab84cb05f 100644 --- a/test/integ/pandas_it/test_arrow_pandas.py +++ b/test/integ/pandas_it/test_arrow_pandas.py @@ -62,6 +62,42 @@ def test_num_one(conn_cnx): fetch_pandas(conn_cnx, sql_exec, row_count, col_count, "one") +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +def test_interval_year_month(conn_cnx): + cases = ["1-2", "-1-3", "999999999-11", "-999999999-11"] + table = "test_arrow_year_month_interval" + values = "(" + "),(".join([f"'{c}'" for c in cases]) + ")" + with conn_cnx() as conn: + cursor = conn.cursor() + cursor.execute("alter session set feature_interval_types=enabled") + cursor.execute(f"create or replace table {table} (a interval year to month)") + cursor.execute(f"insert into {table} values {values}") + sql_text = f"select a from {table}" + validate_pandas(conn, sql_text, cases, 1, "one", "interval_year_month") + finish(conn, table) + + +@pytest.mark.skipif( + not installed_pandas or no_arrow_iterator_ext, + reason="arrow_iterator extension is not built, or pandas is missing.", +) +def test_interval_day_time(conn_cnx): + cases = ["106751 23:47:16.854775807", "0 0:0:0.0", "-5 0:0:0.0"] + table = "test_arrow_day_time_interval" + values = "(" + "),(".join([f"'{c}'" for c in cases]) + ")" + with conn_cnx() as conn: + 
cursor = conn.cursor() + cursor.execute("alter session set feature_interval_types=enabled") + cursor.execute(f"create or replace table {table} (a interval day to second)") + cursor.execute(f"insert into {table} values {values}") + sql_text = f"select a from {table}" + validate_pandas(conn, sql_text, cases, 1, "one", "interval_day_time") + finish(conn, table) + + @pytest.mark.skipif( not installed_pandas or no_arrow_iterator_ext, reason="arrow_iterator extension is not built, or pandas is missing.", @@ -734,6 +770,23 @@ def validate_pandas( c_case = Decimal(cases[i]) elif data_type == "date": c_case = datetime.strptime(cases[i], "%Y-%m-%d").date() + elif data_type == "interval_year_month": + year_month_list = cases[i].split("-") + if len(year_month_list) == 2: + c_case = int(year_month_list[0]) * 12 + int( + year_month_list[1] + ) + else: + # negative value + c_case = -( + int(year_month_list[1]) * 12 + int(year_month_list[2]) + ) + elif data_type == "interval_day_time": + timedelta_split_days = cases[i].split(" ") + pandas_timedelta_str = ( + timedelta_split_days[0] + " days " + timedelta_split_days[1] + ) + c_case = pandas.to_timedelta(pandas_timedelta_str) elif data_type == "time": time_str_len = 8 if scale == 0 else 9 + scale c_case = cases[i].strip()[:time_str_len]
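Taken together, these changes enforce one pattern for every boto-dependent code path in the connector. A minimal, hypothetical sketch of a conforming module under src/snowflake/connector/ (the helper name and error text are illustrative; the `.options` aliases and SigV4 usage mirror wif_util.py above):

from .options import boto3, botocore, installed_boto

if installed_boto:
    # Bind convenience aliases only when the optional packages are present.
    SigV4Auth = botocore.auth.SigV4Auth
    AWSRequest = botocore.awsrequest.AWSRequest


def sign_sts_request(region: str, url: str) -> dict:
    """Hypothetical helper: SigV4-sign a POST request to AWS STS."""
    if not installed_boto:
        # Fail loudly at call time instead of raising ImportError at import time.
        raise RuntimeError("boto3/botocore are not installed; install the [boto] extra")
    creds = boto3.session.Session().get_credentials()
    request = AWSRequest(method="POST", url=url)
    SigV4Auth(creds, "sts", region).add_auth(request)
    return dict(request.headers)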