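"""
Snowflake probe that exercises a full data round trip: generate a random CSV
with Faker, PUT it to a stage, COPY it into a table, verify row counts and
spot-check data integrity, then GET the staged file back and clean up.
"""
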
import csv
import os
import random
import tempfile

from faker import Faker
from probes.logging_config import initialize_logger
from probes.login import connect
from probes.registry import prober_function  # noqa

import snowflake.connector
from snowflake.connector.util_text import random_string

# Initialize logger
logger = initialize_logger(__name__)


def generate_random_data(num_records: int, file_path: str) -> str:
    """
    Generates random CSV data with the specified number of rows.

    Args:
        num_records (int): Number of rows to generate.
        file_path (str): Path of the CSV file to write.

    Returns:
        str: File path to the CSV file.
    """
    fake = Faker()
    with open(file_path, mode="w", newline="", encoding="utf-8") as csvfile:
        # QUOTE_ALL keeps multi-line values such as fake.address() inside a
        # single CSV field.
        writer = csv.writer(csvfile, quoting=csv.QUOTE_ALL)
        writer.writerow(["id", "name", "email", "address"])
        for i in range(1, num_records + 1):
            writer.writerow([i, fake.name(), fake.email(), fake.address()])
    # Re-read the file to verify the expected number of rows was written.
    with open(file_path, newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        rows = list(reader)
    # Subtract 1 for the header row
    actual_records = len(rows) - 1
    assert (
        actual_records == num_records
    ), f"Expected {num_records} records, but found {actual_records}."
    return file_path


def create_data_table(cursor: snowflake.connector.cursor.SnowflakeCursor) -> str:
    """
    Creates a data table in Snowflake with the specified schema.

    Returns:
        str: The name of the created table.
    """
    table_name = random_string(7, "test_data_")
    create_table_query = f"""
    CREATE OR REPLACE TABLE {table_name} (
        id INT,
        name STRING,
        email STRING,
        address STRING
    );
    """
    cursor.execute(create_table_query)
    # DDL statements return a single status row, so a non-empty result
    # indicates the table was created.
    if cursor.fetchone():
        print({"created_table": True})
    else:
        print({"created_table": False})
    return table_name


def create_data_stage(cursor: snowflake.connector.cursor.SnowflakeCursor) -> str:
    """
    Creates a stage in Snowflake for data upload.

    Returns:
        str: The name of the created stage.
    """
    stage_name = random_string(7, "test_data_stage_")
    create_stage_query = f"CREATE OR REPLACE STAGE {stage_name};"

    cursor.execute(create_stage_query)
    if cursor.fetchone():
        print({"created_stage": True})
    else:
        print({"created_stage": False})
    return stage_name


def copy_into_table_from_stage(
    table_name: str, stage_name: str, cur: snowflake.connector.cursor.SnowflakeCursor
):
    """
    Copies data from a stage into a specified table in Snowflake.

    Args:
        table_name (str): The name of the table where data will be copied.
        stage_name (str): The name of the stage from which data will be copied.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
    """
    cur.execute(
        f"""
        COPY INTO {table_name}
        FROM @{stage_name}
        FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1);"""
    )

    # Check if the data was loaded successfully: the second column of a
    # COPY INTO result row is the load status.
    if cur.fetchall()[0][1] == "LOADED":
        print({"copied_data_from_stage_into_table": True})
    else:
        print({"copied_data_from_stage_into_table": False})


def put_file_to_stage(
    file_name: str, stage_name: str, cur: snowflake.connector.cursor.SnowflakeCursor
):
    """
    Uploads a file to a specified stage in Snowflake.

    Args:
        file_name (str): The name of the file to upload.
        stage_name (str): The name of the stage where the file will be uploaded.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
    """
    response = cur.execute(
        f"PUT file://{file_name} @{stage_name} AUTO_COMPRESS=TRUE"
    ).fetchall()
    logger.info(response)

    # The seventh column of a PUT result row is the upload status.
    if response[0][6] == "UPLOADED":
        print({"PUT_operation": True})
    else:
        print({"PUT_operation": False})


def count_data_from_table(
    table_name: str, num_records: int, cur: snowflake.connector.cursor.SnowflakeCursor
):
    """
    Checks that the table contains the expected number of records.

    Args:
        table_name (str): The name of the table to count rows in.
        num_records (int): The expected number of rows.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
    """
    count = cur.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
    if count == num_records:
        print({"data_transferred_completely": True})
    else:
        print({"data_transferred_completely": False})


def compare_fetched_data(
    table_name: str,
    file_name: str,
    cur: snowflake.connector.cursor.SnowflakeCursor,
    repetitions: int = 10,
    fetch_limit: int = 100,
):
    """
    Compares the data fetched from the table with the data in the CSV file.

    Args:
        table_name (str): The name of the table to fetch data from.
        file_name (str): The name of the CSV file to compare data against.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
        repetitions (int): Number of times to repeat the comparison. Default is 10.
        fetch_limit (int): Number of rows to fetch from the table for comparison. Default is 100.
    """

    # ORDER BY id so the fetched rows line up with the CSV rows, which are
    # written in id order.
    fetched_data = cur.execute(
        f"SELECT * FROM {table_name} ORDER BY id LIMIT {fetch_limit}"
    ).fetchall()

    with open(file_name, newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        csv_data = list(reader)[1:]  # Skip header row

    # Spot-check randomly chosen rows and report a single verdict after all
    # repetitions, rather than one per iteration.
    data_intact = True
    for _ in range(repetitions):
        random_index = random.randint(0, len(fetched_data) - 1)
        for y in range(len(fetched_data[0])):
            if str(fetched_data[random_index][y]) != csv_data[random_index][y]:
                data_intact = False
                break
        if not data_intact:
            break
    print({"data_integrity_check": data_intact})


def execute_get_command(stage_name: str, conn: snowflake.connector.SnowflakeConnection):
    """
    Downloads a file from a specified stage in Snowflake.

    Args:
        stage_name (str): The name of the stage from which the file will be downloaded.
        conn (snowflake.connector.SnowflakeConnection): The connection object to execute the SQL command.
    """
    # GET writes to a local path, so download into a per-stage directory
    # under the system temp directory.
    download_dir = os.path.join(tempfile.gettempdir(), stage_name)

    try:
        if not os.path.exists(download_dir):
            os.makedirs(download_dir)
        conn.cursor().execute(f"GET @{stage_name} file://{download_dir}/ ;")
        # Check if files are downloaded
        downloaded_files = os.listdir(download_dir)
        if downloaded_files:
            print({"GET_operation": True})
        else:
            print({"GET_operation": False})

    finally:
        # Remove the downloaded files and the temporary directory.
        try:
            for file in os.listdir(download_dir):
                file_path = os.path.join(download_dir, file)
                if os.path.isfile(file_path):
                    os.remove(file_path)
            os.rmdir(download_dir)
        except OSError:
            logger.error(
                f"Error cleaning up directory {download_dir}. It may not exist or may not be empty."
            )


def perform_put_fetch_get(connection_parameters: dict, num_records: int = 1000):
    """
    Performs a PUT, fetch and GET operation using the provided connection parameters.

    Args:
        connection_parameters (dict): A dictionary containing connection details such as
            host, port, user, password, account, schema, etc.
        num_records (int): Number of records to generate and PUT. Default is 1,000.
    """
    stage_name = table_name = file_name = None
    try:
        with connect(connection_parameters) as conn:
            with conn.cursor() as cur:

                logger.info("Creating stage")
                stage_name = create_data_stage(cur)
                logger.info(f"Stage {stage_name} created")

                logger.info("Creating table")
                table_name = create_data_table(cur)
                logger.info(f"Table {table_name} created")

                logger.info("Generating random data")
                file_name = generate_random_data(num_records, f"{table_name}.csv")
                logger.info(f"Random data generated in {file_name}")

                logger.info("PUT file to stage")
                put_file_to_stage(file_name, stage_name, cur)
                logger.info(f"File {file_name} uploaded to stage {stage_name}")

                logger.info("Copying data from stage to table")
                copy_into_table_from_stage(table_name, stage_name, cur)
                logger.info(
                    f"Data copied from stage {stage_name} to table {table_name}"
                )

                logger.info("Counting data in the table")
                count_data_from_table(table_name, num_records, cur)

                logger.info("Comparing fetched data with CSV file")
                compare_fetched_data(table_name, file_name, cur)

                logger.info("Performing GET operation")
                execute_get_command(stage_name, conn)
                logger.info("File downloaded from stage to local directory")

    except Exception as e:
        logger.error(f"Error during PUT/GET operation: {e}")

    finally:
        # Cleanup: remove staged files, drop the table, and delete the local
        # CSV. Guard each step so a failure earlier in the probe does not
        # raise NameError here.
        with connect(connection_parameters) as conn:
            with conn.cursor() as cur:
                if stage_name:
                    cur.execute(f"REMOVE @{stage_name}")
                if table_name:
                    cur.execute(f"DROP TABLE IF EXISTS {table_name}")
        if file_name and os.path.exists(file_name):
            os.remove(file_name)


# Disabled in MVP, uncomment to run
# @prober_function
def perform_put_fetch_get_100_lines(connection_parameters: dict):
    """
    Performs a PUT and GET operation for 100 rows using the provided connection parameters.

    Args:
        connection_parameters (dict): A dictionary containing connection details such as
            host, port, user, password, account, schema, etc.
    """
    perform_put_fetch_get(connection_parameters, num_records=100)
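

# Minimal local-run sketch (not registered as a prober). Assumes connection
# parameters are supplied via environment variables; the variable names below
# are illustrative, not part of the probe framework.
if __name__ == "__main__":
    connection_parameters = {
        "account": os.environ.get("SNOWFLAKE_ACCOUNT"),  # assumed env var
        "user": os.environ.get("SNOWFLAKE_USER"),  # assumed env var
        "password": os.environ.get("SNOWFLAKE_PASSWORD"),  # assumed env var
        "database": os.environ.get("SNOWFLAKE_DATABASE"),  # assumed env var
        "schema": os.environ.get("SNOWFLAKE_SCHEMA"),  # assumed env var
        "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),  # assumed env var
    }
    perform_put_fetch_get_100_lines(connection_parameters)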