|
32 | 32 | from snowflake.cli.api.sql_execution import SqlExecutionMixin
|
33 | 33 | from snowflake.connector import ProgrammingError
|
34 | 34 | from snowflake.connector.cursor import DictCursor
|
| 35 | +from snowflake.core._root import Root |
| 36 | +from snowflake.core.cortex.inference_service import CortexInferenceService |
| 37 | +from snowflake.core.cortex.inference_service._generated.models import CompleteRequest |
| 38 | +from snowflake.core.cortex.inference_service._generated.models.complete_request_messages_inner import ( |
| 39 | + CompleteRequestMessagesInner, |
| 40 | +) |
35 | 41 |
|
36 | 42 | log = logging.getLogger(__name__)
|
37 | 43 |
|
38 | 44 |
|
class ResponseParseError(Exception):
    """Raised when the server response cannot be parsed."""
| 50 | + |
class MidStreamError(Exception):
    """Raised for mid-stream failures in an SSE (Server-sent Event) stream.

    The stream can carry error messages in the middle of the stream using
    the "error" event type; this exception wraps such a payload.
    """

    def __init__(self, reason: Optional[str] = None) -> None:
        # No reason supplied -> empty message, matching the base contract.
        super().__init__(reason if reason is not None else "")
| 64 | + |
39 | 65 | class CortexManager(SqlExecutionMixin):
|
40 |
    def complete(
        self,
        text: Text,
        model: Model,
        is_file_input: bool = False,
    ) -> str:
        """Run SNOWFLAKE.CORTEX.COMPLETE via SQL and return the completion text.

        Args:
            text: The prompt itself, or — when ``is_file_input`` is True — a
                JSON conversation payload (presumably read from a file by the
                caller; confirm at call sites).
            model: Name of the Cortex model to invoke.
            is_file_input: When True, ``text`` is passed through PARSE_JSON
                with an (empty) options object, and the answer is extracted
                from the JSON result.

        Returns:
            The completion text from the CORTEX_RESULT column.
        """
        if not is_file_input:
            # Plain-prompt form: COMPLETE returns the completion text directly.
            query = f"""\
            SELECT SNOWFLAKE.CORTEX.COMPLETE(
                '{model}',
                '{self._escape_input(text)}'
            ) AS CORTEX_RESULT;"""
            return self._query_cortex_result_str(query)
        else:
            # Conversation form: COMPLETE returns a JSON document; the answer
            # text lives under choices[0].messages.
            query = f"""\
            SELECT SNOWFLAKE.CORTEX.COMPLETE(
                '{model}',
                PARSE_JSON('{self._escape_input(text)}'),
                {{}}
            ) AS CORTEX_RESULT;"""
            raw_result = self._query_cortex_result_str(query)
            json_result = json.loads(raw_result)
            return self._extract_text_result_from_json_result(
                lambda: json_result["choices"][0]["messages"]
            )
51 | 91 |
|
52 |
| - def complete_for_conversation( |
| 92 | + def make_rest_complete_request( |
53 | 93 | self,
|
54 |
| - conversation_json_file: SecurePath, |
55 | 94 | model: Model,
|
56 |
| - ) -> str: |
57 |
| - json_content = conversation_json_file.read_text( |
58 |
| - file_size_limit_mb=DEFAULT_SIZE_LIMIT_MB |
59 |
| - ) |
60 |
| - query = f"""\ |
61 |
| - SELECT SNOWFLAKE.CORTEX.COMPLETE( |
62 |
| - '{model}', |
63 |
| - PARSE_JSON('{self._escape_input(json_content)}'), |
64 |
| - {{}} |
65 |
| - ) AS CORTEX_RESULT;""" |
66 |
| - raw_result = self._query_cortex_result_str(query) |
67 |
| - json_result = json.loads(raw_result) |
68 |
| - return self._extract_text_result_from_json_result( |
69 |
| - lambda: json_result["choices"][0]["messages"] |
| 95 | + prompt: Text, |
| 96 | + ) -> CompleteRequest: |
| 97 | + return CompleteRequest( |
| 98 | + model=str(model), |
| 99 | + messages=[CompleteRequestMessagesInner(content=str(prompt))], |
| 100 | + stream=True, |
70 | 101 | )
|
71 | 102 |
|
| 103 | + def rest_complete( |
| 104 | + self, |
| 105 | + text: Text, |
| 106 | + model: Model, |
| 107 | + root: "Root", |
| 108 | + ) -> str: |
| 109 | + complete_request = self.make_rest_complete_request(model=model, prompt=text) |
| 110 | + cortex_inference_service = CortexInferenceService(root=root) |
| 111 | + try: |
| 112 | + raw_resp = cortex_inference_service.complete( |
| 113 | + complete_request=complete_request |
| 114 | + ) |
| 115 | + except Exception as e: |
| 116 | + raise |
| 117 | + result = "" |
| 118 | + for event in raw_resp.events(): |
| 119 | + try: |
| 120 | + parsed_resp = json.loads(event.data) |
| 121 | + except json.JSONDecodeError: |
| 122 | + raise ResponseParseError("Server response cannot be parsed") |
| 123 | + try: |
| 124 | + result += parsed_resp["choices"][0]["delta"]["content"] |
| 125 | + except (json.JSONDecodeError, KeyError, IndexError): |
| 126 | + if parsed_resp.get("error"): |
| 127 | + raise MidStreamError(reason=event.data) |
| 128 | + else: |
| 129 | + pass |
| 130 | + return result |
| 131 | + |
72 | 132 | def extract_answer_from_source_document(
|
73 | 133 | self,
|
74 | 134 | source_document: SourceDocument,
|
@@ -170,7 +230,7 @@ def _escape_input(plain_input: str):
|
170 | 230 |
|
171 | 231 | @staticmethod
|
172 | 232 | def _extract_text_result_from_json_result(
|
173 |
| - extract_function: Callable[[], str] |
| 233 | + extract_function: Callable[[], str], |
174 | 234 | ) -> str:
|
175 | 235 | try:
|
176 | 236 | return extract_function()
|
|
0 commit comments