|
1 | | -from restack_ai.function import function, log, FunctionFailure |
2 | | -from openai import OpenAI |
3 | | -from dataclasses import dataclass |
4 | 1 | import os |
| 2 | +from dataclasses import dataclass |
| 3 | + |
5 | 4 | from dotenv import load_dotenv |
| 5 | +from openai import OpenAI |
| 6 | +from restack_ai.function import FunctionFailure, function, log |
6 | 7 |
|
7 | 8 | load_dotenv() |
8 | 9 |
|
| 10 | + |
9 | 11 | @dataclass |
10 | 12 | class FunctionInputParams: |
11 | 13 | user_content: str |
12 | 14 | system_content: str | None = None |
13 | 15 | model: str | None = None |
14 | 16 |
|
| 17 | + |
def raise_exception(message: str) -> None:
    """Log *message* at error level, then raise it as a generic ``Exception``.

    Small helper so call sites can report-and-raise in one statement.
    Never returns normally.
    """
    log.error(message)
    raise Exception(message)
| 21 | + |
| 22 | + |
@function.defn()
async def llm(function_input: FunctionInputParams) -> str:
    """Run a chat completion against the Restack-proxied OpenAI API.

    Builds the message list from ``function_input`` (optional system turn,
    then the user turn) and returns the first choice's message content.

    Args:
        function_input: User content, optional system content, and an
            optional model override (defaults to ``gpt-4o-mini``).

    Returns:
        The assistant's reply text from the first completion choice.

    Raises:
        FunctionFailure: Non-retryable, either when ``RESTACK_API_KEY`` is
            unset or when the underlying API call fails.
    """
    try:
        log.info("llm function started", input=function_input)

        # Hoist the env lookup so it is read once and reused for the client.
        api_key = os.environ.get("RESTACK_API_KEY")
        if api_key is None:
            # Configuration error: raise the specific, non-retryable failure
            # directly so its message reaches the caller instead of being
            # re-wrapped into the generic "llm function failed" below.
            raise FunctionFailure("RESTACK_API_KEY is not set", non_retryable=True)

        client = OpenAI(base_url="https://ai.restack.io", api_key=api_key)

        messages = []
        if function_input.system_content:
            messages.append(
                {"role": "system", "content": function_input.system_content}
            )
        messages.append({"role": "user", "content": function_input.user_content})

        response = client.chat.completions.create(
            model=function_input.model or "gpt-4o-mini", messages=messages
        )
        log.info("llm function completed", response=response)
        return response.choices[0].message.content
    except FunctionFailure:
        # Already a well-formed failure (e.g. the missing-key case above):
        # propagate as-is; re-wrapping would discard its message.
        raise
    except Exception as e:
        # Log before wrapping so the underlying cause is visible in the
        # function logs, then surface a non-retryable failure with the
        # original exception chained for debugging.
        log.error("llm function failed", error=e)
        error_message = "llm function failed"
        raise FunctionFailure(error_message, non_retryable=True) from e
0 commit comments