Commit 93ecd2f

SNOW-1994541: Change cortex complete to use snowflake.core APIs (#2260)
Co-authored-by: Patryk Czajka <[email protected]>
1 parent c9045c9 commit 93ecd2f

5 files changed: +130 -33 lines changed

src/snowflake/cli/_plugins/cortex/commands.py

Lines changed: 34 additions & 8 deletions

@@ -15,13 +15,17 @@
 from __future__ import annotations
 
 import sys
+from enum import Enum
 from pathlib import Path
 from typing import List, Optional
 
 import click
 import typer
 from click import UsageError
-from snowflake.cli._plugins.cortex.constants import DEFAULT_MODEL
+from snowflake.cli._plugins.cortex.constants import (
+    DEFAULT_BACKEND,
+    DEFAULT_MODEL,
+)
 from snowflake.cli._plugins.cortex.manager import CortexManager
 from snowflake.cli._plugins.cortex.types import (
     Language,
@@ -36,7 +40,7 @@
     OverrideableOption,
 )
 from snowflake.cli.api.commands.snow_typer import SnowTyperFactory
-from snowflake.cli.api.constants import PYTHON_3_12
+from snowflake.cli.api.constants import DEFAULT_SIZE_LIMIT_MB, PYTHON_3_12
 from snowflake.cli.api.output.types import (
     CollectionResult,
     CommandResult,
@@ -115,6 +119,11 @@ def search(
     return CollectionResult(response.results)
 
 
+class Backend(Enum):
+    SQL = "sql"
+    REST = "rest"
+
+
 @app.command(
     name="complete",
     requires_connection=True,
@@ -130,6 +139,11 @@ def complete(
         "--model",
         help="String specifying the model to be used.",
     ),
+    backend: Optional[Backend] = typer.Option(
+        DEFAULT_BACKEND,
+        "--backend",
+        help="String specifying whether to use sql or rest backend.",
+    ),
     file: Optional[Path] = ExclusiveReadableFileOption(
         help="JSON file containing conversation history to be used to generate a completion. Cannot be combined with TEXT argument.",
     ),
@@ -143,18 +157,30 @@ def complete(
 
     manager = CortexManager()
 
+    is_file_input: bool = False
     if text:
-        result_text = manager.complete_for_prompt(
-            text=Text(text),
+        prompt = text
+    elif file:
+        prompt = SecurePath(file).read_text(file_size_limit_mb=DEFAULT_SIZE_LIMIT_MB)
+        is_file_input = True
+    else:
+        raise UsageError("Either --file option or TEXT argument has to be provided.")
+
+    if backend == Backend.SQL:
+        result_text = manager.complete(
+            text=Text(prompt),
             model=Model(model),
+            is_file_input=is_file_input,
         )
-    elif file:
-        result_text = manager.complete_for_conversation(
-            conversation_json_file=SecurePath(file),
+    elif backend == Backend.REST:
+        root = get_cli_context().snow_api_root
+        result_text = manager.rest_complete(
+            text=Text(prompt),
             model=Model(model),
+            root=root,
         )
     else:
-        raise UsageError("Either --file option or TEXT argument has to be provided.")
+        raise UsageError("--backend option should be either rest or sql.")
 
     return MessageResult(result_text.strip())
 
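
The new --backend option is an Enum-typed typer option, so the CLI accepts only the values "sql" and "rest". Below is a minimal, standalone sketch (a hypothetical toy module, not part of this commit) illustrating how typer turns the Backend enum plus a string default into the "--backend [sql|rest]" choice that the generated docs snapshot shows.

from enum import Enum
from typing import Optional

import typer

app = typer.Typer()

DEFAULT_BACKEND = "rest"  # string default, mirroring constants.py below


class Backend(Enum):
    SQL = "sql"
    REST = "rest"


@app.command()
def complete(
    text: str,
    backend: Optional[Backend] = typer.Option(
        DEFAULT_BACKEND,
        "--backend",
        help="String specifying whether to use sql or rest backend.",
    ),
):
    # typer/click render the option as "--backend [sql|rest]" and coerce the
    # string default "rest" into Backend.REST before the body runs.
    typer.echo(f"backend={backend.value}, text={text}")


if __name__ == "__main__":
    app()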

src/snowflake/cli/_plugins/cortex/constants.py

Lines changed: 2 additions & 1 deletion

@@ -14,4 +14,5 @@
 
 from snowflake.cli._plugins.cortex.types import Model
 
-DEFAULT_MODEL: Model = Model("snowflake-arctic")
+DEFAULT_MODEL: Model = Model("llama3.1-70b")
+DEFAULT_BACKEND = "rest"

src/snowflake/cli/_plugins/cortex/manager.py

Lines changed: 81 additions & 21 deletions

@@ -32,43 +32,103 @@
 from snowflake.cli.api.sql_execution import SqlExecutionMixin
 from snowflake.connector import ProgrammingError
 from snowflake.connector.cursor import DictCursor
+from snowflake.core._root import Root
+from snowflake.core.cortex.inference_service import CortexInferenceService
+from snowflake.core.cortex.inference_service._generated.models import CompleteRequest
+from snowflake.core.cortex.inference_service._generated.models.complete_request_messages_inner import (
+    CompleteRequestMessagesInner,
+)
 
 log = logging.getLogger(__name__)
 
 
+class ResponseParseError(Exception):
+    """This exception is raised when the server response cannot be parsed."""
+
+    pass
+
+
+class MidStreamError(Exception):
+    """The SSE (Server-sent Event) stream can contain error messages in the middle of the stream,
+    using the “error” event type. This exception is raised when there is such a mid-stream error."""
+
+    def __init__(
+        self,
+        reason: Optional[str] = None,
+    ) -> None:
+        message = ""
+        if reason is not None:
+            message = reason
+        super().__init__(message)
+
+
 class CortexManager(SqlExecutionMixin):
-    def complete_for_prompt(
+    def complete(
         self,
         text: Text,
         model: Model,
+        is_file_input: bool = False,
     ) -> str:
-        query = f"""\
+        if not is_file_input:
+            query = f"""\
+            SELECT SNOWFLAKE.CORTEX.COMPLETE(
+                '{model}',
+                '{self._escape_input(text)}'
+            ) AS CORTEX_RESULT;"""
+            return self._query_cortex_result_str(query)
+        else:
+            query = f"""\
             SELECT SNOWFLAKE.CORTEX.COMPLETE(
                 '{model}',
-                '{self._escape_input(text)}'
+                PARSE_JSON('{self._escape_input(text)}'),
+                {{}}
             ) AS CORTEX_RESULT;"""
-        return self._query_cortex_result_str(query)
+            raw_result = self._query_cortex_result_str(query)
+            json_result = json.loads(raw_result)
+            return self._extract_text_result_from_json_result(
+                lambda: json_result["choices"][0]["messages"]
+            )
 
-    def complete_for_conversation(
+    def make_rest_complete_request(
         self,
-        conversation_json_file: SecurePath,
         model: Model,
-    ) -> str:
-        json_content = conversation_json_file.read_text(
-            file_size_limit_mb=DEFAULT_SIZE_LIMIT_MB
-        )
-        query = f"""\
-            SELECT SNOWFLAKE.CORTEX.COMPLETE(
-                '{model}',
-                PARSE_JSON('{self._escape_input(json_content)}'),
-                {{}}
-            ) AS CORTEX_RESULT;"""
-        raw_result = self._query_cortex_result_str(query)
-        json_result = json.loads(raw_result)
-        return self._extract_text_result_from_json_result(
-            lambda: json_result["choices"][0]["messages"]
+        prompt: Text,
+    ) -> CompleteRequest:
+        return CompleteRequest(
+            model=str(model),
+            messages=[CompleteRequestMessagesInner(content=str(prompt))],
+            stream=True,
         )
 
+    def rest_complete(
+        self,
+        text: Text,
+        model: Model,
+        root: "Root",
+    ) -> str:
+        complete_request = self.make_rest_complete_request(model=model, prompt=text)
+        cortex_inference_service = CortexInferenceService(root=root)
+        try:
+            raw_resp = cortex_inference_service.complete(
+                complete_request=complete_request
+            )
+        except Exception as e:
+            raise
+        result = ""
+        for event in raw_resp.events():
+            try:
+                parsed_resp = json.loads(event.data)
+            except json.JSONDecodeError:
+                raise ResponseParseError("Server response cannot be parsed")
+            try:
+                result += parsed_resp["choices"][0]["delta"]["content"]
+            except (json.JSONDecodeError, KeyError, IndexError):
+                if parsed_resp.get("error"):
+                    raise MidStreamError(reason=event.data)
+                else:
+                    pass
+        return result
+
     def extract_answer_from_source_document(
         self,
         source_document: SourceDocument,
@@ -170,7 +230,7 @@ def _escape_input(plain_input: str):
 
     @staticmethod
    def _extract_text_result_from_json_result(
-        extract_function: Callable[[], str]
+        extract_function: Callable[[], str],
     ) -> str:
         try:
             return extract_function()
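
The new rest_complete() path builds a CompleteRequest, streams Server-sent Events from the Cortex inference service, and concatenates the "delta" chunks into one answer. The sketch below mirrors those same snowflake.core calls outside the CLI; the connection parameters are placeholders and the error handling (ResponseParseError, MidStreamError) is omitted, so treat it as an illustration rather than a drop-in script.

import json

from snowflake.connector import connect
from snowflake.core import Root
from snowflake.core.cortex.inference_service import CortexInferenceService
from snowflake.core.cortex.inference_service._generated.models import CompleteRequest
from snowflake.core.cortex.inference_service._generated.models.complete_request_messages_inner import (
    CompleteRequestMessagesInner,
)

# Placeholder credentials; in the CLI this Root comes from get_cli_context().snow_api_root.
conn = connect(account="<account>", user="<user>", password="<password>")
root = Root(conn)

request = CompleteRequest(
    model="llama3.1-70b",
    messages=[CompleteRequestMessagesInner(content="Is 5 more than 4?")],
    stream=True,
)

raw_resp = CortexInferenceService(root=root).complete(complete_request=request)

result = ""
for event in raw_resp.events():  # SSE stream; each event carries one JSON chunk
    chunk = json.loads(event.data)
    result += chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")

print(result)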

tests/__snapshots__/test_docs_generation_output.ambr

Lines changed: 5 additions & 1 deletion

@@ -11,6 +11,7 @@
   snow cortex complete
     <text>
     --model <model>
+    --backend <backend>
     --file <file>
     --connection <connection>
     --host <host>
@@ -56,7 +57,10 @@
 ===============================================================================
 
 :samp:`--model {TEXT}`
-  String specifying the model to be used. Default: snowflake-arctic.
+  String specifying the model to be used. Default: llama3.1-70b.
+
+:samp:`--backend [sql|rest]`
+  String specifying whether to use sql or rest backend. Default: rest.
 
 :samp:`--file {FILE}`
   JSON file containing conversation history to be used to generate a completion. Cannot be combined with TEXT argument.

tests/cortex/test_cortex_commands.py

Lines changed: 8 additions & 2 deletions

@@ -51,13 +51,15 @@ def _mock(raw_result: Any, expected_query: Optional[str] = None):
 def test_cortex_complete_for_prompt_with_default_model(_mock_cortex_result, runner):
     with _mock_cortex_result(
         raw_result="Yes",
-        expected_query="SELECT SNOWFLAKE.CORTEX.COMPLETE( 'snowflake-arctic', 'Is 5 more than 4? Please answer using one word without a period.' ) AS CORTEX_RESULT;",
+        expected_query="SELECT SNOWFLAKE.CORTEX.COMPLETE( 'llama3.1-70b', 'Is 5 more than 4? Please answer using one word without a period.' ) AS CORTEX_RESULT;",
     ):
         result = runner.invoke(
             [
                 "cortex",
                 "complete",
                 "Is 5 more than 4? Please answer using one word without a period.",
+                "--backend",
+                "sql",
             ]
         )
         assert_successful_result_message(result, expected_msg="Yes")
@@ -113,6 +115,8 @@ def test_cortex_complete_for_prompt_with_chosen_model(_mock_cortex_result, runne
                 "Is 5 more than 4? Please answer using one word without a period.",
                 "--model",
                 "reka-flash",
+                "--backend",
+                "sql",
             ]
         )
         assert_successful_result_message(result, expected_msg="Yes")
@@ -121,14 +125,16 @@ def test_cortex_complete_for_prompt_with_chosen_model(_mock_cortex_result, runne
 def test_cortex_complete_for_file(_mock_cortex_result, runner):
     with _mock_cortex_result(
         raw_result="""{"choices": [{"messages": "No, I'm not"}]}""",
-        expected_query="""SELECT SNOWFLAKE.CORTEX.COMPLETE( 'snowflake-arctic', PARSE_JSON('[ { "role": "user", "content": "how does a \\\\"snowflake\\\\" get its \\'unique\\' pattern?" }, { "role": "system", "content": "I don\\'t know" }, { "role": "user", "content": "I thought \\\\"you\\\\" are smarter" } ] '), {} ) AS CORTEX_RESULT;""",
+        expected_query="""SELECT SNOWFLAKE.CORTEX.COMPLETE( 'llama3.1-70b', PARSE_JSON('[ { "role": "user", "content": "how does a \\\\"snowflake\\\\" get its \\'unique\\' pattern?" }, { "role": "system", "content": "I don\\'t know" }, { "role": "user", "content": "I thought \\\\"you\\\\" are smarter" } ] '), {} ) AS CORTEX_RESULT;""",
     ):
         result = runner.invoke(
             [
                 "cortex",
                 "complete",
                 "--file",
                 str(TEST_DIR / "test_data/cortex/conversation.json"),
+                "--backend",
+                "sql",
             ]
         )
         assert_successful_result_message(result, expected_msg="No, I'm not")
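
The tests above pin the existing SQL behaviour by passing "--backend sql" explicitly. A companion unit test for the new REST path could stub out CortexInferenceService entirely, as in the hypothetical sketch below (not part of this commit); it assumes Text and Model are importable from snowflake.cli._plugins.cortex.types and that CortexManager() takes no constructor arguments, as the command code suggests.

import json
from types import SimpleNamespace
from unittest import mock

from snowflake.cli._plugins.cortex.manager import CortexManager
from snowflake.cli._plugins.cortex.types import Model, Text


class _FakeStream:
    """Stands in for the SSE response: two events whose deltas spell "Yes"."""

    def events(self):
        for part in ("Ye", "s"):
            data = json.dumps({"choices": [{"delta": {"content": part}}]})
            yield SimpleNamespace(data=data)


@mock.patch("snowflake.cli._plugins.cortex.manager.CortexInferenceService")
def test_rest_complete_concatenates_deltas(mock_service_cls):
    mock_service_cls.return_value.complete.return_value = _FakeStream()
    result = CortexManager().rest_complete(
        text=Text("Is 5 more than 4?"),
        model=Model("llama3.1-70b"),
        root=mock.MagicMock(),  # never dereferenced: the service class is mocked out
    )
    assert result == "Yes"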
