diff --git a/arctic_training/__init__.py b/arctic_training/__init__.py index c9ebc8e2..a425f61d 100644 --- a/arctic_training/__init__.py +++ b/arctic_training/__init__.py @@ -37,6 +37,7 @@ from arctic_training.data.hf_instruct_source import HFDataSourceInstruct from arctic_training.data.hf_source import HFDataSource from arctic_training.data.sft_factory import SFTDataFactory +from arctic_training.data.snowflake_source import SnowflakeDataSource from arctic_training.data.source import DataSource from arctic_training.logging import logger from arctic_training.model.factory import ModelFactory diff --git a/arctic_training/data/snowflake_source.py b/arctic_training/data/snowflake_source.py new file mode 100644 index 00000000..02b93470 --- /dev/null +++ b/arctic_training/data/snowflake_source.py @@ -0,0 +1,233 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import Optional + +from pydantic import Field +from pydantic import model_validator +from typing_extensions import Self + +from arctic_training.config.data import DataSourceConfig +from arctic_training.data.source import DataSource +from arctic_training.data.utils import DatasetType + +if TYPE_CHECKING: + from snowflake.snowpark import Session + +_DATASET_URI_PATTERN = re.compile(r"^snow://dataset/([^/]+)/versions/([^/]+)$") + + +def _check_snowflake_ml_installed() -> None: + """Check if snowflake-ml-python is installed.""" + try: + import snowflake.ml # noqa: F401 + except ImportError: + raise ImportError( + "snowflake-ml-python is required for Snowflake data sources. " + "Install with: pip install 'arctic_training[snowflake]'" + ) + + +def get_default_snowflake_session() -> "Session": + """ + Get or create a default Snowflake Session. + + This function attempts to get an active Snowpark session. If none exists, + it creates a new session using default connection parameters. + + The session can be configured via: + - Environment variables (SNOWFLAKE_ACCOUNT, SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, etc.) + - A Snowflake connection configuration file (~/.snowflake/connections.toml) + - The SNOWFLAKE_DEFAULT_CONNECTION_NAME environment variable + + Returns: + A Snowpark Session object. + + Raises: + ImportError: If snowflake-snowpark-python is not installed. + Exception: If session creation fails due to missing or invalid credentials. 
+ """ + _check_snowflake_ml_installed() + + from snowflake.snowpark import Session + + try: + # Get an existing active session or create a new one using default connection + # This will use environment variables or ~/.snowflake/connections.toml + return Session.builder.getOrCreate() + except Exception: + from snowflake.ml._internal.utils.connection_params import SnowflakeLoginOptions + + # Fall back to SnowML's connection parameters + config = SnowflakeLoginOptions() + return Session.builder.configs(config).getOrCreate() # noqa: F841 + + +class SnowflakeSourceConfig(DataSourceConfig): + """Configuration for Snowflake data sources. + + Supports three mutually exclusive modes: + - sql: Execute a raw SQL query + - table_name: Load all data from a table (generates SELECT * FROM table_name) + - dataset_uri: Load from a versioned Snowflake Dataset + """ + + sql: Optional[str] = None + """ + SQL query to execute against Snowflake. + Example: 'SELECT col1, col2 FROM my_db.my_schema.my_table WHERE created_at > "2024-01-01"' + """ + + table_name: Optional[str] = None + """ + Snowflake table reference in format [[db.]schema.]table_name. + Examples: 'my_table', 'my_schema.my_table', 'my_db.my_schema.my_table' + """ + + dataset_uri: Optional[str] = None + """ + Snowflake Dataset URI in format snow://dataset//versions/. + Where is in format [[db.]schema.]dataset_name. + Examples: 'snow://dataset/my_training_set/versions/v1', 'snow://dataset/my_schema.my_dataset/versions/v1' + """ + + column_mapping: Dict[str, str] = Field(default_factory=dict) + """ + Optional mapping from source column names to target column names. + If empty, data passes through unchanged. + Example: {'source_col': 'target_col'} renames 'source_col' to 'target_col'. + """ + + limit: Optional[int] = None + """Maximum number of rows to load. If None, loads all rows.""" + + batch_size: int = 1024 + """Batch size for internal data retrieval.""" + + @model_validator(mode="after") + def validate_exactly_one_source(self) -> Self: + """Ensure exactly one of sql, table_name, or dataset_uri is specified.""" + sources = [self.sql, self.table_name, self.dataset_uri] + specified = sum(1 for s in sources if s is not None) + if specified != 1: + raise ValueError("Exactly one of 'sql', 'table_name', or 'dataset_uri' must be specified") + + # Auto-generate sql from table_name + if self.table_name: + self.sql = f"SELECT * FROM {self.table_name}" + + # Validate dataset_uri format if specified + if self.dataset_uri: + match = _DATASET_URI_PATTERN.match(self.dataset_uri) + if not match: + raise ValueError( + f"Invalid dataset_uri format: '{self.dataset_uri}'. " + "Expected format: 'snow://dataset//versions/'" + ) + + # Validate the dataset_name component using Snowflake's identifier parser + from snowflake.ml._internal.utils.identifier import parse_schema_level_object_identifier + + dataset_name = match.group(1) + try: + parse_schema_level_object_identifier(dataset_name) + except ValueError as e: + raise ValueError(f"Invalid dataset_name format in URI: {e}") + + return self + + +class SnowflakeDataSource(DataSource): + """DataSource for loading data from Snowflake. + + Supports three modes: + - SQL query: Execute arbitrary SQL and load results + - Table: Load all data from a Snowflake table + - Dataset: Load from a versioned Snowflake Dataset + """ + + name = "snowflake" + config: SnowflakeSourceConfig + + session: Optional[Any] = None + """ + Optional Snowpark Session to use for connecting to Snowflake. 
+ If None, a default session will be created using get_default_snowflake_session(). + """ + + def load(self, config: SnowflakeSourceConfig, split: str) -> DatasetType: + """Load data from Snowflake. + + Routes to the appropriate loading method based on config. + """ + _check_snowflake_ml_installed() + + session = self.session or get_default_snowflake_session() + + if config.dataset_uri: + return self._load_from_dataset(config, session=session) + else: + return self._load_from_sql(config, session=session) + + def _load_from_sql(self, config: SnowflakeSourceConfig, *, session: "Session") -> DatasetType: + """Load data using a SQL query.""" + from snowflake.ml.data.data_connector import DataConnector + + # Create connector from SQL query + connector = DataConnector.from_sql(config.sql, session=session) + + # Convert to HuggingFace dataset + dataset = connector.to_huggingface_dataset( + streaming=False, + limit=config.limit, + batch_size=config.batch_size, + ) + + return dataset + + def _load_from_dataset(self, config: SnowflakeSourceConfig, *, session: "Session") -> DatasetType: + """Load data from a Snowflake Dataset.""" + from snowflake.ml.data.data_connector import DataConnector + from snowflake.ml.dataset import load_dataset + + # Parse URI and load the Snowflake Dataset object + assert config.dataset_uri is not None + match = _DATASET_URI_PATTERN.match(config.dataset_uri) + if not match: + raise ValueError(f"Invalid dataset_uri format: '{config.dataset_uri}'") + dataset_name, dataset_version = match.group(1), match.group(2) + snow_dataset = load_dataset(session, dataset_name, dataset_version) + + # Create connector from the Dataset object + connector = DataConnector.from_dataset(snow_dataset) + + # Convert to HuggingFace dataset + dataset = connector.to_huggingface_dataset( + streaming=False, + limit=config.limit, + batch_size=config.batch_size, + ) + + return dataset + + def post_load_callback(self, dataset: DatasetType) -> DatasetType: + """Apply column mapping if provided.""" + if self.config.column_mapping: + dataset = dataset.rename_columns(self.config.column_mapping) + return dataset diff --git a/projects/causal_snowflake/README.md b/projects/causal_snowflake/README.md new file mode 100644 index 00000000..a1d098f2 --- /dev/null +++ b/projects/causal_snowflake/README.md @@ -0,0 +1,200 @@ +# Causal Training with Snowflake Data Sources + +This project demonstrates causal language model training using data stored in Snowflake. It includes examples of all three Snowflake data source modes supported by Arctic Training. + +## Snowflake Data Source + +The unified `snowflake` data source type supports three mutually exclusive modes: + +| Mode | Config Key | Description | +|------|------------|-------------| +| SQL Query | `sql` | Execute arbitrary SQL queries against Snowflake | +| Table | `table_name` | Load data directly from a Snowflake table | +| Dataset | `dataset_uri` | Load data from a versioned Snowflake Dataset | + +**Note:** Exactly one of `sql`, `table_name`, or `dataset_uri` must be specified. + +## Prerequisites + +### 1. Install Dependencies + +Install Arctic Training with Snowflake support: + +```bash +pip install 'arctic_training[snowflake]' +``` + +### 2. 
Configure Snowflake Credentials
+
+The Snowflake data sources use `Session.builder.getOrCreate()` which supports multiple authentication methods:
+
+**Option A: Environment Variables**
+```bash
+export SNOWFLAKE_ACCOUNT="your_account"
+export SNOWFLAKE_USER="your_username"
+export SNOWFLAKE_PASSWORD="your_password"
+export SNOWFLAKE_WAREHOUSE="your_warehouse" # optional
+export SNOWFLAKE_DATABASE="your_database"   # optional
+export SNOWFLAKE_SCHEMA="your_schema"       # optional
+```
+
+**Option B: Connections Config File**
+
+Create `~/.snowflake/connections.toml`:
+```toml
+[default]
+account = "your_account"
+user = "your_username"
+password = "your_password"
+warehouse = "your_warehouse"
+```
+
+You can also specify a non-default connection:
+```bash
+export SNOWFLAKE_DEFAULT_CONNECTION_NAME="my_connection"
+```
+
+## Setting Up Snowflake with Training Data
+
+Before running training, you need to populate your Snowflake account with the training data.
+
+### Run the Setup Script
+
+```bash
+cd projects/causal_snowflake
+python setup_snowflake.py
+```
+
+This will:
+1. Download the `stas/gutenberg-100` dataset from HuggingFace
+2. Create the `ARCTIC_TRAINING.CAUSAL_DEMO` database and schema
+3. Upload the data to a `GUTENBERG_100` table
+4. Create a versioned `GUTENBERG_DATASET` Snowflake Dataset
+
+### Setup Script Options
+
+```bash
+python setup_snowflake.py --help
+
+Options:
+  --database TEXT         Snowflake database name (default: ARCTIC_TRAINING)
+  --schema TEXT           Snowflake schema name (default: CAUSAL_DEMO)
+  --table-name TEXT       Snowflake table name (default: GUTENBERG_100)
+  --dataset-name TEXT     Snowflake Dataset name (default: GUTENBERG_DATASET)
+  --dataset-version TEXT  Snowflake Dataset version (default: v1)
+  --hf-dataset TEXT       HuggingFace dataset to download (default: stas/gutenberg-100)
+  --sample-count INT      Number of samples to upload (default: 100)
+  --drop-existing         Drop existing table and dataset if they exist
+```
+
+### Expected Snowflake Resources
+
+After setup, you should have:
+
+| Resource | Full Name |
+|----------|-----------|
+| Database | `ARCTIC_TRAINING` |
+| Schema | `ARCTIC_TRAINING.CAUSAL_DEMO` |
+| Table | `ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_100` |
+| Dataset | `snow://dataset/ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_DATASET/versions/v1` |
+
+## Running Training
+
+### Using SQL Query Mode
+
+Load data via a custom SQL query:
+
+```bash
+arctic_training run-causal-snowflake-sql.yml
+```
+
+Config snippet:
+```yaml
+data:
+  sources:
+    - type: snowflake
+      sql: "SELECT TEXT FROM ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_100"
+      column_mapping: {"TEXT": "text"}
+```
+
+### Using Table Name Mode
+
+Load data directly from a table (auto-generates `SELECT * FROM table_name`):
+
+```bash
+arctic_training run-causal-snowflake-table.yml
+```
+
+Config snippet:
+```yaml
+data:
+  sources:
+    - type: snowflake
+      table_name: ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_100
+      column_mapping: {"TEXT": "text"}
+```
+
+### Using Dataset URI Mode
+
+Load data from a versioned Snowflake Dataset:
+
+```bash
+arctic_training run-causal-snowflake-dataset.yml
+```
+
+Config snippet:
+```yaml
+data:
+  sources:
+    - type: snowflake
+      dataset_uri: "snow://dataset/ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_DATASET/versions/v1"
+      column_mapping: {"TEXT": "text"}
+```
+
+## Configuration Options
+
+All modes support these common options:
+
+| Option | Type | Default | Description |
+|--------|------|---------|-------------|
+| `column_mapping` | dict | `{}` | Rename columns (e.g., `{"SRC": "dst"}`) |
+| `limit` | int | None | Maximum rows to load |
+| `batch_size` | int | 1024 | Batch size for data retrieval |
+
+### Mode-Specific Options
+
+Exactly one of the following must be specified:
+
+| Option | Description |
+|--------|-------------|
+| `sql` | SQL query to execute |
+| `table_name` | Table reference as `[[db.]schema.]table_name` |
+| `dataset_uri` | Dataset URI as `snow://dataset/<dataset_name>/versions/<version>` |
+
+## Troubleshooting
+
+### Connection Issues
+
+If you get authentication errors:
+1. Verify your credentials are correct
+2. Check that your account identifier is in the correct format (e.g., `orgname-accountname`)
+3. Ensure your IP is allowlisted if network policies are enabled
+
+### Missing snowflake-ml-python
+
+If you see `ImportError: snowflake-ml-python is required`:
+```bash
+pip install 'arctic_training[snowflake]'
+```
+
+### Table/Dataset Not Found
+
+Ensure you've run the setup script first:
+```bash
+python setup_snowflake.py
+```
+
+Or verify the resources exist in Snowflake:
+```sql
+SHOW TABLES IN ARCTIC_TRAINING.CAUSAL_DEMO;
+```
diff --git a/projects/causal_snowflake/run-causal-snowflake-dataset.yml b/projects/causal_snowflake/run-causal-snowflake-dataset.yml
new file mode 100644
index 00000000..22ff1bed
--- /dev/null
+++ b/projects/causal_snowflake/run-causal-snowflake-dataset.yml
@@ -0,0 +1,57 @@
+# Causal training using Snowflake Dataset data source
+#
+# This config demonstrates loading training data from a Snowflake Dataset.
+# Snowflake Datasets provide versioned, managed datasets that are ideal for
+# reproducible ML training pipelines.
+#
+# Prerequisites:
+# 1. Run setup_snowflake.py to create the required Snowflake resources
+# 2. Configure Snowflake credentials (env vars or ~/.snowflake/connections.toml)
+#
+# Usage:
+# arctic_training run-causal-snowflake-dataset.yml
+
+type: causal
+micro_batch_size: 1
+exit_iteration: 10
+min_iterations: 10
+
+deepspeed:
+  zero_optimization:
+    stage: 3
+
+optimizer:
+  learning_rate: 1e-5
+
+model:
+  name_or_path: hf-internal-testing/tiny-random-LlamaForCausalLM
+  attn_implementation: flash_attention_2
+  dtype: bf16
+
+data:
+  sources:
+    - type: snowflake
+      # Snowflake Dataset URI format: snow://dataset/<dataset_name>/versions/<version>
+      # Where <dataset_name> can be [[db.]schema.]dataset_name
+      dataset_uri: "snow://dataset/ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_DATASET/versions/v1"
+      # Map Snowflake column names (uppercase) to expected lowercase
+      column_mapping: {"TEXT": "text"}
+      # Optional: limit number of rows for testing
+      # limit: 50
+      # Optional: batch size for data retrieval
+      # batch_size: 1024
+
+  cache_dir: /tmp/data-cache
+  num_proc: 16
+  dl_num_workers: 1
+  max_length: 2048
+
+logger:
+  level: WARNING
+  output_dir: "logs"
+  print_output_ranks: [0,1,2,3,4,5,6,7]
+
+checkpoint:
+  - type: huggingface
+    save_every_n_steps: 300
+    output_dir: /tmp/ft-model
diff --git a/projects/causal_snowflake/run-causal-snowflake-sql.yml b/projects/causal_snowflake/run-causal-snowflake-sql.yml
new file mode 100644
index 00000000..01a4b45f
--- /dev/null
+++ b/projects/causal_snowflake/run-causal-snowflake-sql.yml
@@ -0,0 +1,56 @@
+# Causal training using Snowflake SQL data source
+#
+# This config demonstrates loading training data from Snowflake using a raw SQL query.
+# The SQL query is executed directly against Snowflake and results are converted to
+# a HuggingFace dataset for training.
+#
+# Prerequisites:
+# 1. Run setup_snowflake.py to create the required Snowflake resources
+# 2.
Configure Snowflake credentials (env vars or ~/.snowflake/connections.toml) +# +# Usage: +# arctic_training run-causal-snowflake-sql.yml + +type: causal +micro_batch_size: 1 +exit_iteration: 10 +min_iterations: 10 + +deepspeed: + zero_optimization: + stage: 3 + +optimizer: + learning_rate: 1e-5 + +model: + name_or_path: hf-internal-testing/tiny-random-LlamaForCausalLM + attn_implementation: flash_attention_2 + dtype: bf16 + +data: + sources: + - type: snowflake + # Raw SQL query to fetch training data + sql: "SELECT TEXT FROM ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_100" + # Map Snowflake column names (uppercase) to expected lowercase + column_mapping: {"TEXT": "text"} + # Optional: limit number of rows for testing + # limit: 50 + # Optional: batch size for data retrieval + # batch_size: 1024 + + cache_dir: /tmp/data-cache + num_proc: 16 + dl_num_workers: 1 + max_length: 2048 + +logger: + level: WARNING + output_dir: "logs" + print_output_ranks: [0,1,2,3,4,5,6,7] + +checkpoint: + - type: huggingface + save_every_n_steps: 300 + output_dir: /tmp/ft-model diff --git a/projects/causal_snowflake/run-causal-snowflake-table.yml b/projects/causal_snowflake/run-causal-snowflake-table.yml new file mode 100644 index 00000000..64e1e434 --- /dev/null +++ b/projects/causal_snowflake/run-causal-snowflake-table.yml @@ -0,0 +1,55 @@ +# Causal training using Snowflake Table data source +# +# This config demonstrates loading training data from a Snowflake table directly. +# The table name is automatically converted to a SELECT * query internally. +# +# Prerequisites: +# 1. Run setup_snowflake.py to create the required Snowflake resources +# 2. Configure Snowflake credentials (env vars or ~/.snowflake/connections.toml) +# +# Usage: +# arctic_training run-causal-snowflake-table.yml + +type: causal +micro_batch_size: 1 +exit_iteration: 10 +min_iterations: 10 + +deepspeed: + zero_optimization: + stage: 3 + +optimizer: + learning_rate: 1e-5 + +model: + name_or_path: hf-internal-testing/tiny-random-LlamaForCausalLM + attn_implementation: flash_attention_2 + dtype: bf16 + +data: + sources: + - type: snowflake + # Fully qualified table name: [[db.]schema.]table_name + table_name: ARCTIC_TRAINING.CAUSAL_DEMO.GUTENBERG_100 + # Map Snowflake column names (uppercase) to expected lowercase + column_mapping: {"TEXT": "text"} + # Optional: limit number of rows for testing + # limit: 50 + # Optional: batch size for data retrieval + # batch_size: 1024 + + cache_dir: /tmp/data-cache + num_proc: 16 + dl_num_workers: 1 + max_length: 2048 + +logger: + level: WARNING + output_dir: "logs" + print_output_ranks: [0,1,2,3,4,5,6,7] + +checkpoint: + - type: huggingface + save_every_n_steps: 300 + output_dir: /tmp/ft-model diff --git a/projects/causal_snowflake/setup_snowflake.py b/projects/causal_snowflake/setup_snowflake.py new file mode 100644 index 00000000..a1e0505e --- /dev/null +++ b/projects/causal_snowflake/setup_snowflake.py @@ -0,0 +1,220 @@ +# Copyright 2025 Snowflake Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Set up script to populate a Snowflake account with training data from HuggingFace. + +This script: +1. Downloads the stas/gutenberg-100 dataset from HuggingFace +2. Creates a Snowflake database and schema +3. Uploads the data as a Snowflake table +4. Creates a Snowflake Dataset from that table + +Prerequisites: +- Install arctic_training with snowflake extras: pip install 'arctic_training[snowflake]' +- Configure Snowflake credentials via: + - Environment variables (SNOWFLAKE_ACCOUNT, SNOWFLAKE_USER, SNOWFLAKE_PASSWORD) + - Config file (~/.snowflake/connections.toml) + - SNOWFLAKE_DEFAULT_CONNECTION_NAME environment variable + +Usage: + python setup_snowflake.py [--database DATABASE] [--schema SCHEMA] [--sample-count N] +""" + +import argparse + +import pandas as pd +from datasets import load_dataset +from snowflake.ml.dataset import create_from_dataframe +from snowflake.snowpark import Session + + +def get_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Set up Snowflake with training data from HuggingFace") + parser.add_argument( + "--database", + type=str, + default="ARCTIC_TRAINING", + help="Snowflake database name (default: ARCTIC_TRAINING)", + ) + parser.add_argument( + "--schema", + type=str, + default="CAUSAL_DEMO", + help="Snowflake schema name (default: CAUSAL_DEMO)", + ) + parser.add_argument( + "--table-name", + type=str, + default="GUTENBERG_100", + help="Snowflake table name (default: GUTENBERG_100)", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="GUTENBERG_DATASET", + help="Snowflake Dataset name (default: GUTENBERG_DATASET)", + ) + parser.add_argument( + "--dataset-version", + type=str, + default="v1", + help="Snowflake Dataset version (default: v1)", + ) + parser.add_argument( + "--hf-dataset", + type=str, + default="stas/gutenberg-100", + help="HuggingFace dataset to download (default: stas/gutenberg-100)", + ) + parser.add_argument( + "--sample-count", + type=int, + default=100, + help="Number of samples to upload (default: 100)", + ) + parser.add_argument( + "--drop-existing", + action="store_true", + help="Drop existing table and dataset if they exist", + ) + return parser.parse_args() + + +def download_hf_dataset(dataset_name: str, sample_count: int) -> pd.DataFrame: + """Download dataset from HuggingFace and convert to pandas DataFrame.""" + print(f"Downloading HuggingFace dataset: {dataset_name}") + + # Load the dataset + hf_dataset = load_dataset(dataset_name, split=f"train[:{sample_count}]") + + # Convert to pandas DataFrame + df = hf_dataset.to_pandas() + + # Ensure consistent column naming (uppercase for Snowflake) + df.columns = [col.upper() for col in df.columns] + + print(f"Downloaded {len(df)} samples with columns: {list(df.columns)}") + return df + + +def create_snowflake_resources( + session: Session, + database: str, + schema: str, + table_name: str, + dataset_name: str, + dataset_version: str, + df: pd.DataFrame, + drop_existing: bool = False, +) -> None: + """Create Snowflake database, schema, table, and dataset.""" + + # Create database and schema + print(f"Creating database: {database}") + session.sql(f"CREATE DATABASE IF NOT EXISTS {database}").collect() + + print(f"Creating schema: {database}.{schema}") + session.sql(f"CREATE SCHEMA IF NOT EXISTS {database}.{schema}").collect() + + # Set the context + session.use_database(database) + session.use_schema(schema) + + full_table_name = 
f"{database}.{schema}.{table_name}" + full_dataset_name = f"{database}.{schema}.{dataset_name}" + + if drop_existing: + print(f"Dropping existing table (if exists): {full_table_name}") + session.sql(f"DROP TABLE IF EXISTS {full_table_name}").collect() + + # Note: Snowflake Datasets cannot be dropped via SQL, they need to be + # deleted via the Dataset API or will be overwritten + + # Create table from pandas DataFrame + print(f"Creating table: {full_table_name}") + snowpark_df = session.create_dataframe(df) + snowpark_df.write.mode("overwrite").save_as_table(full_table_name) + + # Verify table creation + row_count = session.sql(f"SELECT COUNT(*) FROM {full_table_name}").collect()[0][0] + print(f"Table created with {row_count} rows") + + # Create Snowflake Dataset from the table + print(f"Creating Snowflake Dataset: {full_dataset_name} (version: {dataset_version})") + + # Read the table as a Snowpark DataFrame for dataset creation + table_df = session.table(full_table_name) + + # Create the dataset + snow_dataset = create_from_dataframe( + session=session, + name=full_dataset_name, + version=dataset_version, + input_dataframe=table_df, + ) + + print(f"Dataset created: {snow_dataset.fully_qualified_name}") + print(f"Dataset URI: snow://dataset/{full_dataset_name}/versions/{dataset_version}") + + +def main() -> None: + args = get_args() + + print("=" * 60) + print("Snowflake Set Up Script for Arctic Training") + print("=" * 60) + + # Download HuggingFace dataset + df = download_hf_dataset(args.hf_dataset, args.sample_count) + + # Connect to Snowflake + print("\nConnecting to Snowflake...") + session = Session.builder.getOrCreate() + print(f"Connected to account: {session.get_current_account()}") + + try: + # Create Snowflake resources + create_snowflake_resources( + session=session, + database=args.database, + schema=args.schema, + table_name=args.table_name, + dataset_name=args.dataset_name, + dataset_version=args.dataset_version, + df=df, + drop_existing=args.drop_existing, + ) + + print("\n" + "=" * 60) + print("Set up completed successfully!") + print("=" * 60) + print("\nCreated resources:") + print(f" - Table: {args.database}.{args.schema}.{args.table_name}") + print( + " - Dataset:" + f" snow://dataset/{args.database}.{args.schema}.{args.dataset_name}/versions/{args.dataset_version}" + ) + print("\nYou can now run the training configs:") + print(" - run-causal-snowflake-sql.yml") + print(" - run-causal-snowflake-table.yml") + print(" - run-causal-snowflake-dataset.yml") + + finally: + session.close() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c004aac4..0f46690b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ all = [ "arctic_training[testing]", "arctic_training[vllm]", "arctic_training[cortex]", + "arctic_training[snowflake]", ] dev = [ @@ -91,6 +92,7 @@ testing = [ "pytest-instafail", "parameterized", "pytest-xdist", + "arctic_training[snowflake]", ] formatting = [ @@ -115,6 +117,11 @@ cortex = [ "snowflake-connector-python==3.12.3", ] +snowflake = [ + "snowflake-ml-python>=1.21.0", + "numba>=0.63", # Mitigate uv dependency resolution issue with snowflake-ml-python +] + [project.scripts] arctic_training = "arctic_training_cli:main" arctic_training_run = "arctic_training.entrypoint:launch" diff --git a/tests/data/test_snowflake_source.py b/tests/data/test_snowflake_source.py new file mode 100644 index 00000000..9a01af4f --- /dev/null +++ b/tests/data/test_snowflake_source.py @@ -0,0 +1,372 @@ +# Copyright 2025 Snowflake 
Inc. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from datasets import Dataset +from pydantic import ValidationError +from snowflake.ml.data.data_connector import DataConnector + +from arctic_training.data.snowflake_source import SnowflakeDataSource +from arctic_training.data.snowflake_source import SnowflakeSourceConfig + + +class TestSnowflakeSourceConfig: + """Tests for SnowflakeSourceConfig validation.""" + + # ===== SQL Mode Tests ===== + + def test_valid_sql_query(self): + """Test that valid SQL queries are accepted.""" + sql = "SELECT col1, col2 FROM my_db.my_schema.my_table WHERE id > 100" + config = SnowflakeSourceConfig(type="snowflake", sql=sql) + assert config.sql == sql + + def test_sql_with_custom_options(self): + """Test that custom options are preserved with sql mode.""" + config = SnowflakeSourceConfig( + type="snowflake", + sql="SELECT * FROM my_table", + column_mapping={"old_col": "new_col"}, + limit=1000, + batch_size=512, + ) + assert config.sql == "SELECT * FROM my_table" + assert config.column_mapping == {"old_col": "new_col"} + assert config.limit == 1000 + assert config.batch_size == 512 + + # ===== Table Name Mode Tests ===== + + @pytest.mark.parametrize( + "table_name", + [ + "my_table", + "my_schema.my_table", + "my_db.my_schema.my_table", + ], + ) + def test_valid_table_name(self, table_name: str): + """Test that [[db.]schema.]table format is accepted.""" + config = SnowflakeSourceConfig(type="snowflake", table_name=table_name) + assert config.table_name == table_name + + @pytest.mark.parametrize( + ("table_name", "expected_sql"), + [ + ("my_table", "SELECT * FROM my_table"), + ("my_schema.my_table", "SELECT * FROM my_schema.my_table"), + ("my_db.my_schema.my_table", "SELECT * FROM my_db.my_schema.my_table"), + ], + ) + def test_sql_generated_from_table_name(self, table_name: str, expected_sql: str): + """Test that sql field is auto-populated from table_name.""" + config = SnowflakeSourceConfig(type="snowflake", table_name=table_name) + assert config.sql == expected_sql + + def test_table_name_with_custom_options(self): + """Test that custom options are preserved with table_name mode.""" + config = SnowflakeSourceConfig( + type="snowflake", + table_name="my_table", + column_mapping={"old_col": "new_col"}, + limit=1000, + batch_size=512, + ) + assert config.column_mapping == {"old_col": "new_col"} + assert config.limit == 1000 + assert config.batch_size == 512 + + # ===== Dataset URI Mode Tests ===== + + @pytest.mark.parametrize( + "dataset_uri", + [ + "snow://dataset/my_dataset/versions/v1", + 'snow://dataset/"my-training_set"/versions/v1.0', + # Dataset names can also be qualified as [[db.]schema.]dataset_name + "snow://dataset/my_schema.my_dataset/versions/v1", + "snow://dataset/my_db.my_schema.my_dataset/versions/v2", + # Quoted identifiers can be used to allow special characters (e.g. 
hyphens) + 'snow://dataset/"my_db"."my_schema"."my-training_set"/versions/v1', + ], + ) + def test_valid_dataset_uri(self, dataset_uri): + """Test that valid dataset URIs are accepted.""" + config = SnowflakeSourceConfig(type="snowflake", dataset_uri=dataset_uri) + assert config.dataset_uri == dataset_uri + + @pytest.mark.parametrize( + "dataset_uri", + [ + # Missing snow:// prefix + "dataset/my_dataset/versions/v1", + # Wrong base path + "snow://my_dataset/v1", + # Missing version segment + "snow://dataset/my_dataset", + # Too many qualifiers: db.schema.dataset.extra + "snow://dataset/a.b.c.d/versions/v1", + # Empty identifier segments + "snow://dataset/.a/versions/v1", + "snow://dataset/a../versions/v1", + ], + ) + def test_invalid_dataset_uri(self, dataset_uri: str): + """Test that invalid dataset URIs are rejected.""" + with pytest.raises(ValidationError) as exc_info: + SnowflakeSourceConfig(type="snowflake", dataset_uri=dataset_uri) + assert "Invalid dataset_uri format" in str(exc_info.value) or "dataset_name format" in str(exc_info.value) + + def test_dataset_uri_with_custom_options(self): + """Test that custom options are preserved with dataset_uri mode.""" + config = SnowflakeSourceConfig( + type="snowflake", + dataset_uri="snow://dataset/my_dataset/versions/v1", + column_mapping={"chat": "messages"}, + limit=500, + batch_size=256, + ) + assert config.column_mapping == {"chat": "messages"} + assert config.limit == 500 + assert config.batch_size == 256 + + # ===== Default Values Tests ===== + + def test_default_values(self): + """Test that default values are set correctly.""" + config = SnowflakeSourceConfig(type="snowflake", sql="SELECT * FROM my_table") + assert config.column_mapping == {} + assert config.limit is None + assert config.batch_size == 1024 + + # ===== One-of Validation Tests ===== + + def test_error_when_no_source_specified(self): + """Test that error is raised when no source is specified.""" + with pytest.raises(ValidationError) as exc_info: + SnowflakeSourceConfig(type="snowflake") + assert "Exactly one of 'sql', 'table_name', or 'dataset_uri' must be specified" in str(exc_info.value) + + def test_error_when_multiple_sources_specified(self): + """Test that error is raised when multiple sources are specified.""" + # sql + table_name + with pytest.raises(ValidationError) as exc_info: + SnowflakeSourceConfig(type="snowflake", sql="SELECT * FROM t", table_name="my_table") + assert "Exactly one of 'sql', 'table_name', or 'dataset_uri' must be specified" in str(exc_info.value) + + # sql + dataset_uri + with pytest.raises(ValidationError) as exc_info: + SnowflakeSourceConfig( + type="snowflake", sql="SELECT * FROM t", dataset_uri="snow://dataset/my_dataset/versions/v1" + ) + assert "Exactly one of 'sql', 'table_name', or 'dataset_uri' must be specified" in str(exc_info.value) + + # table_name + dataset_uri + with pytest.raises(ValidationError) as exc_info: + SnowflakeSourceConfig( + type="snowflake", table_name="my_table", dataset_uri="snow://dataset/my_dataset/versions/v1" + ) + assert "Exactly one of 'sql', 'table_name', or 'dataset_uri' must be specified" in str(exc_info.value) + + # all three + with pytest.raises(ValidationError) as exc_info: + SnowflakeSourceConfig( + type="snowflake", + sql="SELECT * FROM t", + table_name="my_table", + dataset_uri="snow://dataset/my_dataset/versions/v1", + ) + assert "Exactly one of 'sql', 'table_name', or 'dataset_uri' must be specified" in str(exc_info.value) + + +class TestSnowflakeDataSource: + """Tests for 
SnowflakeDataSource.""" + + # ===== SQL Mode Tests ===== + + @patch.object(DataConnector, "from_sql") + def test_load_with_sql(self, mock_from_sql): + """Test that load() calls DataConnector.from_sql() with the provided SQL.""" + # Setup mocks + mock_session = MagicMock() + + mock_dataset = Dataset.from_dict({"col1": ["a", "b"], "col2": [1, 2]}) + mock_connector_instance = MagicMock() + mock_connector_instance.to_huggingface_dataset.return_value = mock_dataset + mock_from_sql.return_value = mock_connector_instance + + # Create config and data source + sql = "SELECT col1, col2 FROM my_db.my_schema.my_table WHERE id > 100" + config = SnowflakeSourceConfig( + type="snowflake", + sql=sql, + limit=100, + batch_size=512, + ) + + mock_data_factory = MagicMock() + data_source = SnowflakeDataSource(data_factory=mock_data_factory, config=config) + data_source.session = mock_session + + result = data_source.load(config, split="train") + + # Verify SQL query was passed correctly + mock_from_sql.assert_called_once_with(sql, session=mock_session) + mock_connector_instance.to_huggingface_dataset.assert_called_once_with( + streaming=False, + limit=100, + batch_size=512, + ) + assert result == mock_dataset + + # ===== Table Name Mode Tests ===== + + @patch.object(DataConnector, "from_sql") + def test_load_with_table_name(self, mock_from_sql): + """Test that load() with table_name generates correct SQL.""" + # Setup mocks + mock_session = MagicMock() + + mock_dataset = Dataset.from_dict({"text": ["hello", "world"]}) + mock_connector_instance = MagicMock() + mock_connector_instance.to_huggingface_dataset.return_value = mock_dataset + mock_from_sql.return_value = mock_connector_instance + + # Create config and data source + config = SnowflakeSourceConfig( + type="snowflake", + table_name="my_db.my_schema.my_table", + limit=100, + batch_size=512, + ) + + mock_data_factory = MagicMock() + data_source = SnowflakeDataSource(data_factory=mock_data_factory, config=config) + data_source.session = mock_session + + result = data_source.load(config, split="train") + + # Verify SQL query was constructed correctly from table_name + mock_from_sql.assert_called_once_with("SELECT * FROM my_db.my_schema.my_table", session=mock_session) + mock_connector_instance.to_huggingface_dataset.assert_called_once_with( + streaming=False, + limit=100, + batch_size=512, + ) + assert result == mock_dataset + + # ===== Dataset URI Mode Tests ===== + + @pytest.mark.parametrize( + ("dataset_uri", "expected_name", "expected_version"), + [ + ("snow://dataset/my_dataset/versions/v1", "my_dataset", "v1"), + # Dataset names can also be qualified as [[db.]schema.]dataset_name + ("snow://dataset/my_schema.my_dataset/versions/v1", "my_schema.my_dataset", "v1"), + ("snow://dataset/my_db.my_schema.my_dataset/versions/v2", "my_db.my_schema.my_dataset", "v2"), + # Quoted identifiers can be used to allow special characters (e.g. 
hyphens) + ('snow://dataset/"my-training_set"/versions/v1.0', '"my-training_set"', "v1.0"), + ( + 'snow://dataset/"my_db"."my_schema"."my-training_set"/versions/v1', + '"my_db"."my_schema"."my-training_set"', + "v1", + ), + ], + ) + @patch("snowflake.ml.dataset.load_dataset") + @patch.object(DataConnector, "from_dataset") + def test_load_with_dataset_uri( + self, mock_from_dataset, mock_load_dataset, dataset_uri, expected_name, expected_version + ): + """Test that load() with dataset_uri calls DataConnector.from_dataset() correctly.""" + # Setup mocks + mock_session = MagicMock() + + mock_snow_dataset = MagicMock() + mock_load_dataset.return_value = mock_snow_dataset + + mock_hf_dataset = Dataset.from_dict({"messages": [["msg1"], ["msg2"]]}) + mock_connector_instance = MagicMock() + mock_connector_instance.to_huggingface_dataset.return_value = mock_hf_dataset + mock_from_dataset.return_value = mock_connector_instance + + # Create config and data source + config = SnowflakeSourceConfig( + type="snowflake", + dataset_uri=dataset_uri, + limit=500, + batch_size=256, + ) + + mock_data_factory = MagicMock() + data_source = SnowflakeDataSource(data_factory=mock_data_factory, config=config) + data_source.session = mock_session + + result = data_source.load(config, split="train") + + # Verify load_dataset was called with correct arguments + mock_load_dataset.assert_called_once_with(mock_session, expected_name, expected_version) + # Verify DataConnector.from_dataset was called with the loaded dataset + mock_from_dataset.assert_called_once_with(mock_snow_dataset) + mock_connector_instance.to_huggingface_dataset.assert_called_once_with( + streaming=False, + limit=500, + batch_size=256, + ) + assert result == mock_hf_dataset + + # ===== Post Load Callback Tests ===== + + def test_post_load_callback_applies_column_mapping(self): + """Test that post_load_callback applies column mapping.""" + mock_dataset = Dataset.from_dict({"old_col": ["a", "b"], "other": [1, 2]}) + mock_renamed_dataset = Dataset.from_dict({"new_col": ["a", "b"], "other": [1, 2]}) + mock_dataset.rename_columns = MagicMock(return_value=mock_renamed_dataset) + + config = SnowflakeSourceConfig( + type="snowflake", + sql="SELECT * FROM my_table", + column_mapping={"old_col": "new_col"}, + ) + + mock_data_factory = MagicMock() + data_source = SnowflakeDataSource(data_factory=mock_data_factory, config=config) + + result = data_source.post_load_callback(mock_dataset) + + mock_dataset.rename_columns.assert_called_once_with({"old_col": "new_col"}) + assert result == mock_renamed_dataset + + def test_post_load_callback_passthrough_without_mapping(self): + """Test that post_load_callback passes through unchanged without mapping.""" + mock_dataset = Dataset.from_dict({"col": ["a", "b"]}) + mock_dataset.rename_columns = MagicMock() + + config = SnowflakeSourceConfig( + type="snowflake", + sql="SELECT * FROM my_table", + column_mapping={}, # Empty mapping + ) + + mock_data_factory = MagicMock() + data_source = SnowflakeDataSource(data_factory=mock_data_factory, config=config) + + result = data_source.post_load_callback(mock_dataset) + + mock_dataset.rename_columns.assert_not_called() + assert result == mock_dataset
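
For quick reference, here is a minimal usage sketch (not part of the diff above) showing the config-level behavior that the new tests exercise. It assumes `arctic_training` is installed with the `snowflake` extra; the snippet only touches the pure-Python validation paths of `SnowflakeSourceConfig`, so no Snowflake connection is required.

```python
# Illustrative sketch, not part of the diff: SnowflakeSourceConfig mode handling.
from pydantic import ValidationError

from arctic_training.data.snowflake_source import SnowflakeSourceConfig

# table_name mode auto-generates the SQL statement the data source will execute.
cfg = SnowflakeSourceConfig(type="snowflake", table_name="my_db.my_schema.my_table")
assert cfg.sql == "SELECT * FROM my_db.my_schema.my_table"

# dataset_uri must match snow://dataset/<dataset_name>/versions/<version>.
try:
    SnowflakeSourceConfig(type="snowflake", dataset_uri="snow://dataset/my_dataset")
except ValidationError as err:
    print(err)  # reports the invalid dataset_uri format

# Exactly one of sql, table_name, or dataset_uri may be set.
try:
    SnowflakeSourceConfig(type="snowflake", sql="SELECT 1", table_name="my_table")
except ValidationError as err:
    print(err)  # reports that exactly one source must be specified
```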