Skip to content

Commit 8350f8b

Browse files
SNOW-2114096-extending-probing-capabilities (#2348)
1 parent c53aad7 commit 8350f8b

File tree

4 files changed

+283
-9
lines changed

4 files changed

+283
-9
lines changed

prober/probes/login.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def perform_login(connection_parameters: dict):
5858
cursor = connection.cursor()
5959
cursor.execute("SELECT 1;")
6060
result = cursor.fetchone()
61-
logger.info(result)
61+
logger.error(f"Logging: {result}")
6262
assert result == (1,)
63-
logger.info({f"success_login={True}"})
63+
print({"success_login": True})
6464
except Exception as e:
65-
logger.info({f"success_login={False}"})
65+
print({"success_login": False})
6666
logger.error(f"Error during login: {e}")

prober/probes/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
import logging
33

4-
from probes import login # noqa
4+
from probes import login, put_fetch_get # noqa
55
from probes.logging_config import initialize_logger
66
from probes.registry import PROBES_FUNCTIONS
77

@@ -20,7 +20,7 @@ def main():
2020
parser.add_argument("--account", required=True, help="Account")
2121
parser.add_argument("--schema", required=True, help="Schema")
2222
parser.add_argument("--warehouse", required=True, help="Warehouse")
23-
parser.add_argument("--database", required=True, help="Datanase")
23+
parser.add_argument("--database", required=True, help="Database")
2424
parser.add_argument("--user", required=True, help="Username")
2525
parser.add_argument(
2626
"--auth", required=True, help="Authenticator (e.g., KEY_PAIR_AUTHENTICATOR)"

prober/probes/put_fetch_get.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
import csv
2+
import os
3+
import random
4+
5+
from faker import Faker
6+
from probes.logging_config import initialize_logger
7+
from probes.login import connect
8+
from probes.registry import prober_function # noqa
9+
10+
import snowflake.connector
11+
from snowflake.connector.util_text import random_string
12+
13+
# Initialize logger
14+
logger = initialize_logger(__name__)
15+
16+
17+
def generate_random_data(num_records: int, file_path: str) -> str:
    """
    Generates random CSV data with the specified number of rows.

    Args:
        num_records (int): Number of rows to generate.
        file_path (str): Destination path for the CSV file.

    Returns:
        str: File path to CSV file.

    Raises:
        ValueError: If the written file does not round-trip with exactly
            ``num_records`` data rows.
    """
    fake = Faker()
    with open(file_path, mode="w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile, quoting=csv.QUOTE_ALL)
        writer.writerow(["id", "name", "email", "address"])
        for i in range(1, num_records + 1):
            writer.writerow([i, fake.name(), fake.email(), fake.address()])

    # Read the file back to verify it contains the expected number of rows.
    with open(file_path, newline="", encoding="utf-8") as csvfile:
        rows = list(csv.reader(csvfile))
    # Subtract 1 for the header row
    actual_records = len(rows) - 1
    if actual_records != num_records:
        message = f"Expected {num_records} records, but found {actual_records}."
        logger.error(message)
        # The original used `assert ..., logger.error(...)`: the message
        # expression evaluates to None (useless AssertionError text) and the
        # whole check disappears under `python -O`. Raise explicitly instead.
        raise ValueError(message)
    return file_path
42+
43+
44+
def create_data_table(cursor: snowflake.connector.cursor.SnowflakeCursor) -> str:
    """
    Creates a randomly named data table in Snowflake with the probe schema.

    Args:
        cursor (snowflake.connector.cursor.SnowflakeCursor): Cursor used to
            run the DDL statement.

    Returns:
        str: The name of the created table.
    """
    table_name = random_string(7, "test_data_")
    ddl = f"""
    CREATE OR REPLACE TABLE {table_name} (
        id INT,
        name STRING,
        email STRING,
        address STRING
    );
    """
    cursor.execute(ddl)
    # A non-empty result row indicates the DDL statement produced a status.
    created = bool(cursor.fetchone())
    print({"created_table": created})
    return table_name
66+
67+
68+
def create_data_stage(cursor: snowflake.connector.cursor.SnowflakeCursor) -> str:
    """
    Creates a randomly named stage in Snowflake for data upload.

    Args:
        cursor (snowflake.connector.cursor.SnowflakeCursor): Cursor used to
            run the DDL statement.

    Returns:
        str: The name of the created stage.
    """
    stage_name = random_string(7, "test_data_stage_")
    cursor.execute(f"CREATE OR REPLACE STAGE {stage_name};")
    # A non-empty result row indicates the DDL statement produced a status.
    created = bool(cursor.fetchone())
    print({"created_stage": created})
    return stage_name
84+
85+
86+
def copy_into_table_from_stage(
    table_name: str, stage_name: str, cur: snowflake.connector.cursor.SnowflakeCursor
):
    """
    Copies data from a stage into a specified table in Snowflake.

    Args:
        table_name (str): The name of the table where data will be copied.
        stage_name (str): The name of the stage from which data will be copied.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
    """
    copy_query = f"""
    COPY INTO {table_name}
    FROM @{stage_name}
    FILE_FORMAT = (TYPE = CSV FIELD_OPTIONALLY_ENCLOSED_BY = '"' SKIP_HEADER = 1);"""
    cur.execute(copy_query)

    # Check if the data was loaded successfully: the second column of the
    # first result row carries the per-file load status.
    rows = cur.fetchall()
    loaded = rows[0][1] == "LOADED"
    print({"copied_data_from_stage_into_table": loaded})
109+
110+
111+
def put_file_to_stage(
    file_name: str, stage_name: str, cur: snowflake.connector.cursor.SnowflakeCursor
):
    """
    Uploads a file to a specified stage in Snowflake.

    Args:
        file_name (str): The name of the file to upload.
        stage_name (str): The name of the stage where the file will be uploaded.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
    """
    result = cur.execute(
        f"PUT file://{file_name} @{stage_name} AUTO_COMPRESS=TRUE"
    ).fetchall()
    logger.error(result)

    # Column 6 of the first PUT result row reports the transfer status.
    uploaded = result[0][6] == "UPLOADED"
    print({"PUT_operation": uploaded})
131+
132+
133+
def count_data_from_table(
    table_name: str, num_records: int, cur: snowflake.connector.cursor.SnowflakeCursor
):
    """
    Checks that the table row count matches the expected number of records.

    Args:
        table_name (str): The name of the table to count rows in.
        num_records (int): The expected number of rows.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
    """
    row_count = cur.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
    print({"data_transferred_completely": row_count == num_records})
141+
142+
143+
def compare_fetched_data(
    table_name: str,
    file_name: str,
    cur: snowflake.connector.cursor.SnowflakeCursor,
    repetitions: int = 10,
    fetch_limit: int = 100,
):
    """
    Compares the data fetched from the table with the data in the CSV file.

    Prints a single ``{"data_integrity_check": bool}`` verdict.

    Args:
        table_name (str): The name of the table to fetch data from.
        file_name (str): The name of the CSV file to compare data against.
        cur (snowflake.connector.cursor.SnowflakeCursor): The cursor to execute the SQL command.
        repetitions (int): Number of times to repeat the comparison. Default is 10.
        fetch_limit (int): Number of rows to fetch from the table for comparison. Default is 100.
    """
    # ORDER BY id so fetched rows line up positionally with the CSV rows
    # (which are written in id order); without it SQL gives no ordering
    # guarantee and the positional comparison below is nondeterministic.
    fetched_data = cur.execute(
        f"SELECT * FROM {table_name} ORDER BY id LIMIT {fetch_limit}"
    ).fetchall()

    with open(file_name, newline="", encoding="utf-8") as csvfile:
        csv_data = list(csv.reader(csvfile))[1:]  # Skip header row

    # Only sample indices that exist in both sources; the original indexed
    # up to fetch_limit - 1 and could raise IndexError on a short table.
    max_index = min(fetch_limit, len(fetched_data), len(csv_data))
    if max_index == 0:
        print({"data_integrity_check": False})
        return

    ok = True
    for _ in range(repetitions):
        random_index = random.randint(0, max_index - 1)
        fetched_row = fetched_data[random_index]
        csv_row = csv_data[random_index]
        if any(str(f) != c for f, c in zip(fetched_row, csv_row)):
            ok = False
            break
    # The original `break` only exited the inner column loop, so it printed
    # {"data_integrity_check": True} even after reporting a mismatch; emit
    # exactly one verdict instead.
    print({"data_integrity_check": ok})
175+
176+
177+
def execute_get_command(stage_name: str, conn: snowflake.connector.SnowflakeConnection):
    """
    Downloads files from a specified stage in Snowflake to a local scratch
    directory and reports whether the GET operation produced any files.

    Args:
        stage_name (str): The name of the stage from which the file will be downloaded.
        conn (snowflake.connector.SnowflakeConnection): The connection object to execute the SQL command.
    """
    import shutil
    import tempfile

    # The original built `s3://{account}/{stage}` and passed it to
    # os.makedirs/os.listdir as if it were a *local* path, creating a bogus
    # `s3:` directory on disk. GET requires a real local target directory,
    # so use a dedicated temporary directory instead.
    download_dir = tempfile.mkdtemp(prefix=f"{stage_name}_")

    try:
        conn.cursor().execute(f"GET @{stage_name} file://{download_dir}/ ;")
        # Check if files are downloaded
        downloaded_files = os.listdir(download_dir)
        if downloaded_files:
            print({"GET_operation": True})
        else:
            print({"GET_operation": False})

    finally:
        try:
            # rmtree handles files and the directory itself in one call.
            shutil.rmtree(download_dir)
        except FileNotFoundError:
            logger.error(
                f"Error cleaning up directory {download_dir}. It may not exist or be empty."
            )
209+
210+
211+
def perform_put_fetch_get(connection_parameters: dict, num_records: int = 1000):
    """
    Performs a PUT, fetch and GET operation using the provided connection parameters.

    Args:
        connection_parameters (dict): A dictionary containing connection details such as
            host, port, user, password, account, schema, etc.
        num_records (int): Number of records to generate and PUT. Default is 1,000.
    """
    # Pre-bind resource names so the cleanup in `finally` can tell which
    # resources were actually created. The original referenced these names
    # unconditionally and raised NameError whenever an early step failed.
    stage_name = None
    table_name = None
    file_name = None
    try:
        with connect(connection_parameters) as conn:
            with conn.cursor() as cur:

                logger.error("Creating stage")
                stage_name = create_data_stage(cur)
                logger.error(f"Stage {stage_name} created")

                # Original log message wrongly said "Creating stage" here.
                logger.error("Creating table")
                table_name = create_data_table(cur)
                logger.error(f"Table {table_name} created")

                logger.error("Generating random data")
                file_name = generate_random_data(num_records, f"{table_name}.csv")
                logger.error(f"Random data generated in {file_name}")

                logger.error("PUT file to stage")
                put_file_to_stage(file_name, stage_name, cur)
                logger.error(f"File {file_name} uploaded to stage {stage_name}")

                logger.error("Copying data from stage to table")
                copy_into_table_from_stage(table_name, stage_name, cur)
                logger.error(
                    f"Data copied from stage {stage_name} to table {table_name}"
                )

                logger.error("Counting data in the table")
                count_data_from_table(table_name, num_records, cur)

                logger.error("Comparing fetched data with CSV file")
                compare_fetched_data(table_name, file_name, cur)

                logger.error("Performing GET operation")
                execute_get_command(stage_name, conn)
                logger.error("File downloaded from stage to local directory")

    except Exception as e:
        logger.error(f"Error during PUT/GET operation: {e}")

    finally:
        # Cleanup: remove data from the stage, delete the table and the
        # local CSV file (the original leaked the CSV on disk). Guarded so
        # a failed cleanup connection cannot raise out of `finally`.
        try:
            with connect(connection_parameters) as conn:
                with conn.cursor() as cur:
                    if stage_name is not None:
                        cur.execute(f"REMOVE @{stage_name}")
                    if table_name is not None:
                        cur.execute(f"DROP TABLE {table_name}")
        except Exception as cleanup_error:
            logger.error(f"Error during cleanup: {cleanup_error}")
        if file_name is not None and os.path.exists(file_name):
            os.remove(file_name)
265+
266+
267+
# Disabled in MVP, uncomment to run
# @prober_function
def perform_put_fetch_get_100_lines(connection_parameters: dict):
    """
    Performs a PUT, fetch and GET operation for 100 rows using the provided
    connection parameters.

    Args:
        connection_parameters (dict): A dictionary containing connection details such as
            host, port, user, password, account, schema, etc.
    """
    perform_put_fetch_get(connection_parameters, num_records=100)

prober/setup.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,7 @@
44
name="snowflake_prober",
55
version="1.0.0",
66
packages=find_packages(),
7-
install_requires=[
8-
"snowflake-connector-python",
9-
"requests",
10-
],
7+
install_requires=["snowflake-connector-python", "requests", "faker"],
118
entry_points={
129
"console_scripts": [
1310
"prober=probes.main:main",

0 commit comments

Comments
 (0)