1
- import os
2
1
import json
3
- import re
4
2
import time
5
3
import logging
6
4
from typing import Any , List , Dict , Optional
7
5
8
6
import httpx
9
7
10
8
from config import (
11
- MCP_TRANSPORT ,
12
- SNOWFLAKE_BASE_URL ,
13
- SNOWFLAKE_DATABASE ,
9
+ SNOWFLAKE_BASE_URL ,
10
+ SNOWFLAKE_DATABASE ,
14
11
SNOWFLAKE_SCHEMA ,
15
12
SNOWFLAKE_TOKEN
16
13
)
17
14
from metrics import track_snowflake_query
18
15
19
16
# Module-level logger, named after this module via __name__.
logger = logging .getLogger (__name__ )
20
17
18
+
21
19
def sanitize_sql_value (value : str ) -> str :
22
20
"""Sanitize a SQL value to prevent injection attacks"""
23
21
if not isinstance (value , str ):
@@ -26,37 +24,38 @@ def sanitize_sql_value(value: str) -> str:
26
24
# For string values, we'll escape single quotes by doubling them
27
25
return value .replace ("'" , "''" )
28
26
27
+
29
28
async def make_snowflake_request (
30
- endpoint : str ,
31
- method : str = "POST" ,
29
+ endpoint : str ,
30
+ method : str = "POST" ,
32
31
data : dict [str , Any ] = None ,
33
32
snowflake_token : Optional [str ] = None
34
33
) -> dict [str , Any ] | None :
35
34
"""Make a request to Snowflake API"""
36
35
# Use provided token or fall back to config
37
36
token = snowflake_token or SNOWFLAKE_TOKEN
38
-
37
+
39
38
if not token :
40
39
logger .error ("SNOWFLAKE_TOKEN environment variable is required but not set" )
41
40
return None
42
-
41
+
43
42
headers = {
44
43
"Authorization" : f"Bearer { token } " ,
45
44
"Accept" : "application/json" ,
46
45
"Content-Type" : "application/json"
47
46
}
48
-
47
+
49
48
url = f"{ SNOWFLAKE_BASE_URL } /{ endpoint } "
50
-
49
+
51
50
try :
52
51
async with httpx .AsyncClient (timeout = 30.0 ) as client :
53
52
if method .upper () == "GET" :
54
53
response = await client .request (method , url , headers = headers , params = data )
55
54
else :
56
55
response = await client .request (method , url , headers = headers , json = data )
57
-
56
+
58
57
response .raise_for_status ()
59
-
58
+
60
59
# Try to parse JSON, but handle cases where response is not valid JSON
61
60
try :
62
61
return response .json ()
@@ -65,19 +64,20 @@ async def make_snowflake_request(
65
64
logger .error (f"Response content: { response .text [:500 ]} ..." ) # Log first 500 chars
66
65
# Return None to indicate error, which will be handled by calling functions
67
66
return None
68
-
67
+
69
68
except httpx .HTTPStatusError as http_error :
70
69
logger .error (f"HTTP error from Snowflake API: { http_error .response .status_code } - { http_error .response .text } " )
71
70
return None
72
71
except Exception as e :
73
72
logger .error (f"Unexpected error in Snowflake API request: { str (e )} " )
74
73
return None
75
74
75
+
76
76
async def execute_snowflake_query (sql : str , snowflake_token : Optional [str ] = None ) -> List [Dict [str , Any ]]:
77
77
"""Execute a SQL query against Snowflake and return results"""
78
78
start_time = time .time ()
79
79
success = False
80
-
80
+
81
81
try :
82
82
# Use the statements endpoint to execute SQL
83
83
endpoint = "statements"
@@ -87,16 +87,16 @@ async def execute_snowflake_query(sql: str, snowflake_token: Optional[str] = Non
87
87
"database" : SNOWFLAKE_DATABASE ,
88
88
"schema" : SNOWFLAKE_SCHEMA
89
89
}
90
-
90
+
91
91
logger .info (f"Executing Snowflake query: { sql [:100 ]} ..." ) # Log first 100 chars of query
92
-
92
+
93
93
response = await make_snowflake_request (endpoint , "POST" , payload , snowflake_token )
94
-
94
+
95
95
# Check if response is None (indicating an error in API request or JSON parsing)
96
96
if response is None :
97
97
logger .error ("Failed to get valid response from Snowflake API" )
98
98
return []
99
-
99
+
100
100
# Parse the response to extract data
101
101
if response and "data" in response :
102
102
logger .info (f"Successfully got { len (response ['data' ])} rows from Snowflake" )
@@ -109,109 +109,112 @@ async def execute_snowflake_query(sql: str, snowflake_token: Optional[str] = Non
109
109
logger .info (f"Successfully got { len (result_set ['data' ])} rows from Snowflake (resultSet format)" )
110
110
success = True
111
111
return result_set ["data" ]
112
-
112
+
113
113
logger .warning ("No data found in Snowflake response" )
114
114
success = True # No data is still a successful query
115
115
return []
116
-
116
+
117
117
except Exception as e :
118
118
logger .error (f"Error executing Snowflake query: { str (e )} " )
119
119
logger .error (f"Query that failed: { sql } " )
120
120
return []
121
121
finally :
122
122
track_snowflake_query (start_time , success )
123
123
124
+
124
125
def format_snowflake_row(row_data: List[Any], columns: List[str]) -> Dict[str, Any]:
    """Pair each column name with the value at the same position in a row.

    Args:
        row_data: The values of one result row, in column order.
        columns: The column names, in the same order as ``row_data``.

    Returns:
        A dict mapping column name to value, or an empty dict when the row
        and the column list differ in length.
    """
    if len(row_data) == len(columns):
        # Lengths match, so zip pairs every column with exactly one value.
        return dict(zip(columns, row_data))
    return {}
130
131
132
+
131
133
async def get_issue_labels(issue_ids: List[str], snowflake_token: Optional[str] = None) -> Dict[str, List[str]]:
    """Get labels for given issue IDs from Snowflake.

    Args:
        issue_ids: Issue IDs to look up. Only ``str``/``int`` values whose
            text consists entirely of ASCII digits are used; everything else
            is dropped. Duplicates are queried once.
        snowflake_token: Optional token passed through to
            ``execute_snowflake_query``.

    Returns:
        A dict mapping each issue ID (as a string) to its labels, in the
        order the rows were returned. Empty when no usable IDs were given.
        Any exception raised while querying or grouping is logged and the
        labels collected so far are returned; nothing is raised to the caller.
    """
    if not issue_ids:
        return {}

    labels_data: Dict[str, List[str]] = {}

    try:
        # The IDs are interpolated into the SQL text below, so accept only
        # plain ASCII digits. str.isdigit() alone also returns True for
        # characters such as superscripts and non-Latin digits.
        sanitized_ids: List[str] = []
        seen = set()
        for issue_id in issue_ids:
            if not isinstance(issue_id, (str, int)):
                continue
            candidate = str(issue_id)
            if candidate in seen:
                continue
            if candidate.isascii() and candidate.isdigit():
                seen.add(candidate)
                sanitized_ids.append(candidate)

        if not sanitized_ids:
            return {}

        # Create comma-separated, single-quoted list for the IN clause
        ids_str = "'" + "','".join(sanitized_ids) + "'"

        sql = f"""
        SELECT ISSUE, LABEL
        FROM JIRA_LABEL_RHAI
        WHERE ISSUE IN ({ids_str}) AND LABEL IS NOT NULL
        """

        rows = await execute_snowflake_query(sql, snowflake_token)
        columns = ["ISSUE", "LABEL"]

        for row in rows:
            row_dict = format_snowflake_row(row, columns)
            issue = row_dict.get("ISSUE")
            label = row_dict.get("LABEL")

            # Check the raw value before stringifying: str(None) is the
            # truthy string "None" and would otherwise become a bogus key.
            if issue is None or not label:
                continue
            issue_key = str(issue)
            if not issue_key:
                continue

            labels_data.setdefault(issue_key, []).append(label)

    except Exception as e:
        logger.error(f"Error fetching labels: {str(e)}")

    return labels_data
175
177
178
+
176
179
async def get_issue_comments (issue_ids : List [str ], snowflake_token : Optional [str ] = None ) -> Dict [str , List [Dict [str , Any ]]]:
177
180
"""Get comments for given issue IDs from Snowflake"""
178
181
if not issue_ids :
179
182
return {}
180
-
183
+
181
184
comments_data = {}
182
-
185
+
183
186
try :
184
187
# Sanitize and validate issue IDs (should be numeric)
185
188
sanitized_ids = []
186
189
for issue_id in issue_ids :
187
190
# Ensure issue IDs are numeric to prevent injection
188
191
if isinstance (issue_id , (str , int )) and str (issue_id ).isdigit ():
189
192
sanitized_ids .append (str (issue_id ))
190
-
193
+
191
194
if not sanitized_ids :
192
195
return {}
193
-
196
+
194
197
# Create comma-separated list for IN clause
195
198
ids_str = "'" + "','" .join (sanitized_ids ) + "'"
196
-
199
+
197
200
sql = f"""
198
201
SELECT ID, ISSUEID, ROLELEVEL, BODY, CREATED, UPDATED
199
- FROM JIRA_COMMENT_NON_PII
202
+ FROM JIRA_COMMENT_NON_PII
200
203
WHERE ISSUEID IN ({ ids_str } ) AND BODY IS NOT NULL
201
204
ORDER BY ISSUEID, CREATED ASC
202
205
"""
203
-
206
+
204
207
rows = await execute_snowflake_query (sql , snowflake_token )
205
208
columns = ["ID" , "ISSUEID" , "ROLELEVEL" , "BODY" , "CREATED" , "UPDATED" ]
206
-
209
+
207
210
for row in rows :
208
211
row_dict = format_snowflake_row (row , columns )
209
212
issue_id = str (row_dict .get ("ISSUEID" ))
210
-
213
+
211
214
if issue_id :
212
215
if issue_id not in comments_data :
213
216
comments_data [issue_id ] = []
214
-
217
+
215
218
comment = {
216
219
"id" : row_dict .get ("ID" ),
217
220
"role_level" : row_dict .get ("ROLELEVEL" ),
@@ -220,8 +223,8 @@ async def get_issue_comments(issue_ids: List[str], snowflake_token: Optional[str
220
223
"updated" : row_dict .get ("UPDATED" )
221
224
}
222
225
comments_data [issue_id ].append (comment )
223
-
226
+
224
227
except Exception as e :
225
228
logger .error (f"Error fetching comments: { str (e )} " )
226
-
227
- return comments_data
229
+
230
+ return comments_data
0 commit comments