Skip to content

Commit 633a820

Browse files
authored
Merge pull request #18 from redhat-ai-tools/fix-remote-exec
Fix remote setup
2 parents f1ff9d2 + 5573f89 commit 633a820

File tree

3 files changed

+66
-29
lines changed

3 files changed

+66
-29
lines changed

src/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
SNOWFLAKE_DATABASE = os.environ.get("SNOWFLAKE_DATABASE")
1717
SNOWFLAKE_SCHEMA = os.environ.get("SNOWFLAKE_SCHEMA")
1818

19+
# Snowflake token handling - for stdio transport, get from environment
20+
# For other transports, it will be retrieved from request context in tools layer
21+
if MCP_TRANSPORT == "stdio":
22+
SNOWFLAKE_TOKEN = os.environ.get("SNOWFLAKE_TOKEN")
23+
else:
24+
# For non-stdio transports, token will be passed from tools layer
25+
SNOWFLAKE_TOKEN = None
26+
1927
# Prometheus metrics configuration
2028
ENABLE_METRICS = os.environ.get("ENABLE_METRICS", "false").lower() == "true"
2129
METRICS_PORT = int(os.environ.get("METRICS_PORT", "8000"))

src/database.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33
import re
44
import time
55
import logging
6-
from typing import Any, List, Dict
6+
from typing import Any, List, Dict, Optional
77

88
import httpx
99

1010
from config import (
1111
MCP_TRANSPORT,
1212
SNOWFLAKE_BASE_URL,
1313
SNOWFLAKE_DATABASE,
14-
SNOWFLAKE_SCHEMA
14+
SNOWFLAKE_SCHEMA,
15+
SNOWFLAKE_TOKEN
1516
)
1617
from metrics import track_snowflake_query
1718

@@ -28,22 +29,19 @@ def sanitize_sql_value(value: str) -> str:
2829
async def make_snowflake_request(
2930
endpoint: str,
3031
method: str = "POST",
31-
data: dict[str, Any] = None
32+
data: dict[str, Any] = None,
33+
snowflake_token: Optional[str] = None
3234
) -> dict[str, Any] | None:
3335
"""Make a request to Snowflake API"""
34-
# Get token based on transport type
35-
if MCP_TRANSPORT == "stdio":
36-
snowflake_token = os.environ.get("SNOWFLAKE_TOKEN")
37-
else:
38-
# This would need to be passed in or handled differently in non-stdio mode
39-
snowflake_token = os.environ.get("SNOWFLAKE_TOKEN")
36+
# Use provided token or fall back to config
37+
token = snowflake_token or SNOWFLAKE_TOKEN
4038

41-
if not snowflake_token:
39+
if not token:
4240
logger.error("SNOWFLAKE_TOKEN environment variable is required but not set")
4341
return None
4442

4543
headers = {
46-
"Authorization": f"Bearer {snowflake_token}",
44+
"Authorization": f"Bearer {token}",
4745
"Accept": "application/json",
4846
"Content-Type": "application/json"
4947
}
@@ -75,7 +73,7 @@ async def make_snowflake_request(
7573
logger.error(f"Unexpected error in Snowflake API request: {str(e)}")
7674
return None
7775

78-
async def execute_snowflake_query(sql: str) -> List[Dict[str, Any]]:
76+
async def execute_snowflake_query(sql: str, snowflake_token: Optional[str] = None) -> List[Dict[str, Any]]:
7977
"""Execute a SQL query against Snowflake and return results"""
8078
start_time = time.time()
8179
success = False
@@ -92,7 +90,7 @@ async def execute_snowflake_query(sql: str) -> List[Dict[str, Any]]:
9290

9391
logger.info(f"Executing Snowflake query: {sql[:100]}...") # Log first 100 chars of query
9492

95-
response = await make_snowflake_request(endpoint, "POST", payload)
93+
response = await make_snowflake_request(endpoint, "POST", payload, snowflake_token)
9694

9795
# Check if response is None (indicating an error in API request or JSON parsing)
9896
if response is None:
@@ -130,7 +128,7 @@ def format_snowflake_row(row_data: List[Any], columns: List[str]) -> Dict[str, A
130128

131129
return {columns[i]: row_data[i] for i in range(len(columns))}
132130

133-
async def get_issue_labels(issue_ids: List[str]) -> Dict[str, List[str]]:
131+
async def get_issue_labels(issue_ids: List[str], snowflake_token: Optional[str] = None) -> Dict[str, List[str]]:
134132
"""Get labels for given issue IDs from Snowflake"""
135133
if not issue_ids:
136134
return {}
@@ -157,7 +155,7 @@ async def get_issue_labels(issue_ids: List[str]) -> Dict[str, List[str]]:
157155
WHERE ISSUE IN ({ids_str}) AND LABEL IS NOT NULL
158156
"""
159157

160-
rows = await execute_snowflake_query(sql)
158+
rows = await execute_snowflake_query(sql, snowflake_token)
161159
columns = ["ISSUE", "LABEL"]
162160

163161
for row in rows:

src/tools.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from mcp.server.fastmcp import FastMCP
55

6+
from config import MCP_TRANSPORT, SNOWFLAKE_TOKEN
67
from database import (
78
execute_snowflake_query,
89
format_snowflake_row,
@@ -13,12 +14,27 @@
1314

1415
logger = logging.getLogger(__name__)
1516

17+
def get_snowflake_token(mcp: FastMCP) -> Optional[str]:
18+
"""Get Snowflake token from either config (stdio) or request headers (non-stdio)"""
19+
if MCP_TRANSPORT == "stdio":
20+
return SNOWFLAKE_TOKEN
21+
else:
22+
try:
23+
# Get token from request headers for non-stdio transports
24+
context = mcp.get_context()
25+
if context and hasattr(context, 'request_context') and context.request_context:
26+
headers = context.request_context.request.headers
27+
return headers.get("X-Snowflake-Token")
28+
except Exception as e:
29+
logger.error(f"Error getting token from request context: {e}")
30+
return None
31+
1632
def register_tools(mcp: FastMCP) -> None:
1733
"""Register all MCP tools"""
1834

1935
@mcp.tool()
20-
@track_tool_usage("list_issues")
21-
async def list_issues(
36+
@track_tool_usage("list_jira_issues")
37+
async def list_jira_issues(
2238
project: Optional[str] = None,
2339
issue_type: Optional[str] = None,
2440
status: Optional[str] = None,
@@ -41,6 +57,11 @@ async def list_issues(
4157
Dictionary containing issues list and metadata
4258
"""
4359
try:
60+
# Get the Snowflake token
61+
snowflake_token = get_snowflake_token(mcp)
62+
if not snowflake_token:
63+
return {"error": "Snowflake token not available", "issues": []}
64+
4465
# Build SQL query with filters
4566
sql_conditions = []
4667

@@ -77,7 +98,7 @@ async def list_issues(
7798
LIMIT {limit}
7899
"""
79100

80-
rows = await execute_snowflake_query(sql)
101+
rows = await execute_snowflake_query(sql, snowflake_token)
81102

82103
issues = []
83104
issue_ids = []
@@ -121,7 +142,7 @@ async def list_issues(
121142
issue_ids.append(str(row_dict.get("ID")))
122143

123144
# Get labels for enrichment
124-
labels_data = await get_issue_labels(issue_ids)
145+
labels_data = await get_issue_labels(issue_ids, snowflake_token)
125146

126147
# Enrich issues with labels
127148
for issue in issues:
@@ -145,8 +166,8 @@ async def list_issues(
145166
return {"error": f"Error reading issues from Snowflake: {str(e)}", "issues": []}
146167

147168
@mcp.tool()
148-
@track_tool_usage("get_issue_details")
149-
async def get_issue_details(issue_key: str) -> Dict[str, Any]:
169+
@track_tool_usage("get_jira_issue_details")
170+
async def get_jira_issue_details(issue_key: str) -> Dict[str, Any]:
150171
"""
151172
Get detailed information for a specific JIRA issue by its key from Snowflake.
152173
@@ -157,6 +178,11 @@ async def get_issue_details(issue_key: str) -> Dict[str, Any]:
157178
Dictionary containing detailed issue information
158179
"""
159180
try:
181+
# Get the Snowflake token
182+
snowflake_token = get_snowflake_token(mcp)
183+
if not snowflake_token:
184+
return {"error": "Snowflake token not available"}
185+
160186
sql = f"""
161187
SELECT
162188
ID, ISSUE_KEY, PROJECT, ISSUENUM, ISSUETYPE, SUMMARY, DESCRIPTION,
@@ -169,7 +195,7 @@ async def get_issue_details(issue_key: str) -> Dict[str, Any]:
169195
LIMIT 1
170196
"""
171197

172-
rows = await execute_snowflake_query(sql)
198+
rows = await execute_snowflake_query(sql, snowflake_token)
173199

174200
if not rows:
175201
return {"error": f"Issue with key '{issue_key}' not found"}
@@ -215,24 +241,29 @@ async def get_issue_details(issue_key: str) -> Dict[str, Any]:
215241
}
216242

217243
# Get labels for this issue
218-
labels_data = await get_issue_labels([str(issue['id'])])
244+
labels_data = await get_issue_labels([str(issue['id'])], snowflake_token)
219245
issue['labels'] = labels_data.get(str(issue['id']), [])
220246

221-
return {"issue": issue}
247+
return issue
222248

223249
except Exception as e:
224-
return {"error": f"Error retrieving issue details from Snowflake: {str(e)}"}
250+
return {"error": f"Error reading issue details from Snowflake: {str(e)}"}
225251

226252
@mcp.tool()
227-
@track_tool_usage("get_project_summary")
228-
async def get_project_summary() -> Dict[str, Any]:
253+
@track_tool_usage("get_jira_project_summary")
254+
async def get_jira_project_summary() -> Dict[str, Any]:
229255
"""
230256
Get a summary of all projects in the JIRA data from Snowflake.
231257
232258
Returns:
233259
Dictionary containing project statistics
234260
"""
235261
try:
262+
# Get the Snowflake token
263+
snowflake_token = get_snowflake_token(mcp)
264+
if not snowflake_token:
265+
return {"error": "Snowflake token not available"}
266+
236267
sql = """
237268
SELECT
238269
PROJECT,
@@ -244,7 +275,7 @@ async def get_project_summary() -> Dict[str, Any]:
244275
ORDER BY PROJECT, ISSUESTATUS, PRIORITY
245276
"""
246277

247-
rows = await execute_snowflake_query(sql)
278+
rows = await execute_snowflake_query(sql, snowflake_token)
248279
columns = ["PROJECT", "ISSUESTATUS", "PRIORITY", "COUNT"]
249280

250281
project_stats = {}
@@ -256,7 +287,7 @@ async def get_project_summary() -> Dict[str, Any]:
256287
project = row_dict.get("PROJECT", "Unknown")
257288
status = row_dict.get("ISSUESTATUS", "Unknown")
258289
priority = row_dict.get("PRIORITY", "Unknown")
259-
count = row_dict.get("COUNT", 0)
290+
count = int(row_dict.get("COUNT", 0)) if row_dict.get("COUNT") is not None else 0
260291

261292
if project not in project_stats:
262293
project_stats[project] = {

0 commit comments

Comments
 (0)