Skip to content

Commit 1768ed7

Browse files
committed
function calling
1 parent e11f1da commit 1768ed7

File tree

10 files changed

+361
-17
lines changed

10 files changed

+361
-17
lines changed

community/gemini/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ python = ">=3.10,<4.0"
1313
restack-ai = "^0.0.52"
1414
google-genai = "0.5.0"
1515
watchfiles = "^1.0.0"
16+
pydantic = "^2.10.5"
1617

1718
[build-system]
1819
requires = ["poetry-core"]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from restack_ai.function import function, log
2+
from pydantic import BaseModel
3+
from google import genai
4+
from google.genai import types
5+
6+
import os
7+
8+
@function.defn()
9+
def get_current_weather(location: str) -> str:
10+
"""Returns the current weather.
11+
12+
Args:
13+
location: The city and state, e.g. San Francisco, CA
14+
"""
15+
log.info("get_current_weather function started", location=location)
16+
return 'sunny'
17+
18+
class FunctionInputParams(BaseModel):
19+
user_content: str
20+
21+
@function.defn()
22+
async def gemini_function_call(input: FunctionInputParams) -> str:
23+
try:
24+
log.info("gemini_function_call function started", input=input)
25+
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
26+
27+
response = client.models.generate_content(
28+
model='gemini-2.0-flash-exp',
29+
contents=input.user_content,
30+
config=types.GenerateContentConfig(tools=[get_current_weather])
31+
)
32+
log.info("gemini_function_call function completed", response=response.text)
33+
return response.text
34+
except Exception as e:
35+
log.error("gemini_function_call function failed", error=e)
36+
raise e

community/gemini/src/functions/function.py renamed to community/gemini/src/functions/generate_content.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
from restack_ai.function import function, log
2-
from dataclasses import dataclass
3-
import google.generativeai as genai
2+
from pydantic import BaseModel
3+
from google import genai
4+
from google.genai import types
45

56
import os
67

7-
@dataclass
8-
class FunctionInputParams:
8+
class FunctionInputParams(BaseModel):
99
user_content: str
1010

1111
@function.defn()
1212
async def gemini_generate_content(input: FunctionInputParams) -> str:
1313
try:
1414
log.info("gemini_generate_content function started", input=input)
15-
genai.configure(api_key=os.environ.get("GEMINI_API_KEY"))
16-
model = genai.GenerativeModel("gemini-2.0-flash-exp")
17-
18-
response = model.generate_content(input.user_content)
15+
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
16+
17+
response = client.models.generate_content(
18+
model='gemini-2.0-flash-exp',
19+
contents=input.user_content
20+
)
1921
log.info("gemini_generate_content function completed", response=response.text)
2022
return response.text
2123
except Exception as e:
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from restack_ai.function import function, log
2+
from pydantic import BaseModel
3+
from google import genai
4+
from google.genai import types
5+
6+
import os
7+
8+
@function.defn()
9+
def get_current_weather(location: str) -> str:
10+
"""Returns the current weather.
11+
12+
Args:
13+
location: The city and state, e.g. San Francisco, CA
14+
"""
15+
log.info("get_current_weather function started", location=location)
16+
return 'sunny'
17+
18+
@function.defn()
19+
def get_humidity(location: str) -> str:
20+
"""Returns the current humidity.
21+
22+
Args:
23+
location: The city and state, e.g. San Francisco, CA
24+
"""
25+
log.info("get_humidity function started", location=location)
26+
return '65%'
27+
28+
@function.defn()
29+
def get_air_quality(location: str) -> str:
30+
"""Returns the current air quality.
31+
32+
Args:
33+
location: The city and state, e.g. San Francisco, CA
34+
"""
35+
log.info("get_air_quality function started", location=location)
36+
return 'good'
37+
38+
class FunctionInputParams(BaseModel):
39+
user_content: str
40+
41+
@function.defn()
42+
async def gemini_multi_function_call(input: FunctionInputParams) -> str:
43+
try:
44+
log.info("gemini_multi_function_call function started", input=input)
45+
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
46+
47+
response = client.models.generate_content(
48+
model='gemini-2.0-flash-exp',
49+
contents=input.user_content,
50+
config=types.GenerateContentConfig(tools=[get_current_weather, get_humidity, get_air_quality])
51+
)
52+
log.info("gemini_multi_function_call function completed", response=response.text)
53+
return response.text
54+
except Exception as e:
55+
log.error("gemini_multi_function_call function failed", error=e)
56+
raise e
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from restack_ai.function import function, log
2+
from pydantic import BaseModel
3+
from google import genai
4+
from google.genai import types
5+
from typing import List, Optional
6+
7+
import os
8+
9+
class ChatMessage(BaseModel):
10+
role: str
11+
content: str
12+
13+
class FunctionInputParams(BaseModel):
14+
user_content: str
15+
chat_history: Optional[List[ChatMessage]] = None
16+
17+
class WeatherInput(BaseModel):
18+
location: str
19+
20+
class HumidityInput(BaseModel):
21+
location: str
22+
23+
class AirQualityInput(BaseModel):
24+
location: str
25+
26+
@function.defn()
27+
async def get_current_weather(input: WeatherInput) -> str:
28+
log.info("get_current_weather function started", location=input.location)
29+
return 'sunny'
30+
31+
@function.defn()
32+
async def get_humidity(input: HumidityInput) -> str:
33+
log.info("get_humidity function started", location=input.location)
34+
return '65%'
35+
36+
@function.defn()
37+
async def get_air_quality(input: AirQualityInput) -> str:
38+
log.info("get_air_quality function started", location=input.location)
39+
return 'good'
40+
41+
@function.defn()
42+
async def gemini_multi_function_call_advanced(input: FunctionInputParams) :
43+
try:
44+
log.info("gemini_multi_function_call_advanced function started", input=input)
45+
client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))
46+
47+
functions = [
48+
{
49+
"name": "get_current_weather",
50+
"description": "Get the current weather in a given location",
51+
"parameters": {
52+
"type": "OBJECT",
53+
"properties": {
54+
"location": {
55+
"type": "STRING",
56+
"description": "The city and state, e.g. San Francisco, CA",
57+
},
58+
},
59+
"required": ["location"],
60+
}
61+
},
62+
{
63+
"name": "get_humidity",
64+
"description": "Get the current humidity in a given location",
65+
"parameters": {
66+
"type": "OBJECT",
67+
"properties": {
68+
"location": {
69+
"type": "STRING",
70+
"description": "The city and state, e.g. San Francisco, CA",
71+
},
72+
},
73+
"required": ["location"],
74+
}
75+
},
76+
{
77+
"name": "get_air_quality",
78+
"description": "Get the current air quality in a given location",
79+
"parameters": {
80+
"type": "OBJECT",
81+
"properties": {
82+
"location": {
83+
"type": "STRING",
84+
"description": "The city and state, e.g. San Francisco, CA",
85+
},
86+
},
87+
"required": ["location"],
88+
}
89+
}
90+
]
91+
92+
tools = [types.Tool(function_declarations=functions)]
93+
94+
response = client.models.generate_content(
95+
model='gemini-2.0-flash-exp',
96+
contents=[input.user_content] + ([msg.content for msg in input.chat_history] if input.chat_history else []),
97+
config=types.GenerateContentConfig(
98+
tools=tools
99+
)
100+
)
101+
return response
102+
103+
except Exception as e:
104+
log.error("Error in gemini_multi_function_call_advanced", error=str(e))
105+
raise e

community/gemini/src/services.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,23 @@
44
import os
55

66
from src.client import client
7-
from src.functions.function import gemini_generate_content
8-
from src.workflows.gemini_generate_content import GeminiGenerateContentWorkflow
7+
from src.functions.generate_content import gemini_generate_content
8+
from src.workflows.generate_content import GeminiGenerateContentWorkflow
9+
10+
from src.workflows.function_call import GeminiFunctionCallWorkflow
11+
from src.functions.function_call import gemini_function_call
12+
13+
from src.workflows.multi_function_call import GeminiMultiFunctionCallWorkflow
14+
from src.functions.multi_function_call import gemini_multi_function_call
15+
16+
from src.workflows.multi_function_call_advanced import GeminiMultiFunctionCallAdvancedWorkflow
17+
from src.functions.multi_function_call_advanced import gemini_multi_function_call_advanced
18+
from src.functions.multi_function_call_advanced import get_current_weather, get_humidity, get_air_quality
19+
920
async def main():
1021
await client.start_service(
11-
workflows= [GeminiGenerateContentWorkflow],
12-
functions= [gemini_generate_content]
22+
workflows= [GeminiGenerateContentWorkflow, GeminiFunctionCallWorkflow, GeminiMultiFunctionCallWorkflow, GeminiMultiFunctionCallAdvancedWorkflow],
23+
functions= [gemini_generate_content, gemini_function_call, gemini_multi_function_call, gemini_multi_function_call_advanced, get_current_weather, get_humidity, get_air_quality]
1324
)
1425

1526
def run_services():
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from restack_ai.workflow import workflow, import_functions, log, RetryPolicy
2+
from pydantic import BaseModel
3+
from datetime import timedelta
4+
5+
with import_functions():
6+
from src.functions.function_call import gemini_function_call, FunctionInputParams
7+
8+
class WorkflowInputParams(BaseModel):
9+
user_content: str = "what's the weather in San Francisco?"
10+
11+
@workflow.defn()
12+
class GeminiFunctionCallWorkflow:
13+
@workflow.run
14+
async def run(self, input: WorkflowInputParams):
15+
log.info("GeminiFunctionCallWorkflow started", input=input)
16+
result = await workflow.step(
17+
gemini_function_call,
18+
FunctionInputParams(user_content=input.user_content),
19+
start_to_close_timeout=timedelta(seconds=120),
20+
retry_policy=RetryPolicy(
21+
maximum_attempts=1
22+
)
23+
)
24+
log.info("GeminiFunctionCallWorkflow completed", result=result)
25+
return result

community/gemini/src/workflows/gemini_generate_content.py renamed to community/gemini/src/workflows/generate_content.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from restack_ai.workflow import workflow, import_functions, log, RetryPolicy
2-
from dataclasses import dataclass
2+
from pydantic import BaseModel
33
from datetime import timedelta
44

55
with import_functions():
6-
from src.functions.function import gemini_generate_content, FunctionInputParams
6+
from src.functions.generate_content import gemini_generate_content, FunctionInputParams
77

8-
@dataclass
9-
class WorkflowInputParams:
10-
user_content: str
8+
class WorkflowInputParams(BaseModel):
9+
user_content: str = "what's the weather in San Francisco?"
1110

1211
@workflow.defn()
1312
class GeminiGenerateContentWorkflow:
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from restack_ai.workflow import workflow, import_functions, log, RetryPolicy
2+
from pydantic import BaseModel
3+
from datetime import timedelta
4+
5+
with import_functions():
6+
from src.functions.multi_function_call import gemini_multi_function_call, FunctionInputParams
7+
8+
class WorkflowInputParams(BaseModel):
9+
user_content: str = "what's the weather in San Francisco?"
10+
11+
@workflow.defn()
12+
class GeminiMultiFunctionCallWorkflow:
13+
@workflow.run
14+
async def run(self, input: WorkflowInputParams):
15+
log.info("GeminiMultiFunctionCallWorkflow started", input=input)
16+
result = await workflow.step(
17+
gemini_multi_function_call,
18+
FunctionInputParams(user_content=input.user_content),
19+
start_to_close_timeout=timedelta(seconds=120),
20+
retry_policy=RetryPolicy(
21+
maximum_attempts=1
22+
)
23+
)
24+
log.info("GeminiMultiFunctionCallWorkflow completed", result=result)
25+
return result

0 commit comments

Comments
 (0)