forked from livekit/agents
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrestaurant_agent.py
More file actions
347 lines (283 loc) · 12.5 KB
/
restaurant_agent.py
File metadata and controls
347 lines (283 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import logging
from dataclasses import dataclass, field
from typing import Annotated
import yaml
from dotenv import load_dotenv
from pydantic import Field
from livekit.agents import AgentServer, JobContext, cli
from livekit.agents.llm import function_tool
from livekit.agents.voice import Agent, AgentSession, RunContext
from livekit.plugins import cartesia, deepgram, openai, silero
# from livekit.plugins import noise_cancellation
# This example demonstrates a multi-agent system where tasks are delegated to sub-agents
# based on the user's request.
#
# The user is initially connected to a greeter, and depending on their need, the call is
# handed off to other agents that could help with the more specific tasks.
# This helps to keep each agent focused on the task at hand, and also reduces costs
# since only a subset of the tools are used at any given time.
# Module-level logger for this example, set to INFO so agent transitions are logged.
logger = logging.getLogger("restaurant-example")
logger.setLevel(logging.INFO)

# Load environment variables from a local .env file (via python-dotenv).
load_dotenv()

# Cartesia voice ID for each agent; keys match the agent names registered in
# UserData.agents by the entrypoint.
voices = dict(
    greeter="694f9389-aac1-45b6-b726-9d9369183238",
    reservation="156fb8d2-335b-4950-9cb3-a2d33befec77",
    takeaway="6f84f4b8-58a2-430c-8c79-688dad597532",
    checkout="39b376fc-488e-4d0c-8b37-e00b72059fdd",
)
@dataclass
class UserData:
    """Per-session state shared by every agent in the restaurant flow."""

    # Contact details, set by the update_name / update_phone tools.
    customer_name: str | None = None
    customer_phone: str | None = None
    # Reservation slot as passed to Reservation.update_reservation_time.
    reservation_time: str | None = None
    # Full list of ordered items, replaced wholesale by Takeaway.update_order.
    order: list[str] | None = None
    # Payment details, set together by Checkout.update_credit_card.
    customer_credit_card: str | None = None
    customer_credit_card_expiry: str | None = None
    customer_credit_card_cvv: str | None = None
    # Order total, set by Checkout.confirm_expense.
    expense: float | None = None
    # Set to True by Checkout.confirm_checkout once payment details are complete.
    checked_out: bool | None = None
    # Agent instances keyed by name ("greeter", "reservation", "takeaway", "checkout").
    agents: dict[str, Agent] = field(default_factory=dict)
    # The agent that was active before the most recent transfer.
    prev_agent: Agent | None = None

    def summarize(self) -> str:
        """Render the collected data as a YAML string for the agents' system message.

        Unset (falsy) fields are shown as "unknown" (``checked_out`` as False);
        the credit-card section is None until a card number has been recorded.
        """
        fallback = "unknown"

        card = None
        if self.customer_credit_card:
            card = {
                "number": self.customer_credit_card,
                "expiry": self.customer_credit_card_expiry or fallback,
                "cvv": self.customer_credit_card_cvv or fallback,
            }
        # NOTE(review): the full card number and CVV end up in the LLM prompt;
        # consider masking them if that exposure is not intended.

        payload = {
            "customer_name": self.customer_name or fallback,
            "customer_phone": self.customer_phone or fallback,
            "reservation_time": self.reservation_time or fallback,
            "order": self.order or fallback,
            "credit_card": card,
            "expense": self.expense or fallback,
            "checked_out": self.checked_out or False,
        }
        # YAML is used rather than JSON because the model handles it better.
        return yaml.dump(payload)
# RunContext specialised to this example's session state, so `context.userdata`
# is typed as UserData; used to annotate the `context` parameter of every tool.
RunContext_T = RunContext[UserData]
# Common tools shared by several agents (attached through their `tools=` lists).
# Shared tool: store the caller's name in the session's UserData.
# The docstring below is the tool description shown to the LLM; keep it verbatim.
@function_tool()
async def update_name(
    name: Annotated[str, Field(description="The customer's name")],
    context: RunContext_T,
) -> str:
    """Called when the user provides their name.
    Confirm the spelling with the user before calling the function."""
    context.userdata.customer_name = name
    # The returned text is the tool result fed back to the model.
    return f"The name is updated to {name}"
# Shared tool: store the caller's phone number in the session's UserData.
# The docstring below is the tool description shown to the LLM; keep it verbatim.
@function_tool()
async def update_phone(
    phone: Annotated[str, Field(description="The customer's phone number")],
    context: RunContext_T,
) -> str:
    """Called when the user provides their phone number.
    Confirm the spelling with the user before calling the function."""
    context.userdata.customer_phone = phone
    # The returned text is the tool result fed back to the model.
    return f"The phone number is updated to {phone}"
# Shared tool: hand the call back to the greeter from whichever agent is active.
# Returns the (next_agent, message) pair produced by BaseAgent._transfer_to_agent,
# matching the annotation used by the class-level transfer tools.
@function_tool()
async def to_greeter(context: RunContext_T) -> tuple[Agent, str]:
    """Called when user asks any unrelated questions or requests
    any other services not in your job description."""
    # Every agent in this file derives from BaseAgent, which provides the helper.
    curr_agent: BaseAgent = context.session.current_agent
    return await curr_agent._transfer_to_agent("greeter", context)
class BaseAgent(Agent):
    """Common base for the restaurant agents.

    On entry it carries over recent history from the previous agent and adds a
    system message with the current UserData summary. It also provides the
    hand-off helper used by the transfer tools.
    """

    async def on_enter(self) -> None:
        agent_name = self.__class__.__name__
        logger.info(f"entering task {agent_name}")

        userdata: UserData = self.session.userdata
        chat_ctx = self.chat_ctx.copy()

        previous = userdata.prev_agent
        if isinstance(previous, Agent):
            # Take the last 6 items of the previous agent's history (keeping
            # function calls, dropping its instructions, handoff and config-update
            # items) and append those whose ids are not already present here.
            recent = previous.chat_ctx.copy(
                exclude_instructions=True,
                exclude_function_call=False,
                exclude_handoff=True,
                exclude_config_update=True,
            ).truncate(max_items=6)
            known_ids = {entry.id for entry in chat_ctx.items}
            fresh = [entry for entry in recent.items if entry.id not in known_ids]
            chat_ctx.items.extend(fresh)

        # Identify this agent and expose the collected user data to the model.
        chat_ctx.add_message(
            role="system",  # role=system works for OpenAI's LLM and Realtime API
            content=f"You are {agent_name} agent. Current user data is {userdata.summarize()}",
        )
        await self.update_chat_ctx(chat_ctx)
        # Produce a reply right away, with tool calls disabled for this turn.
        self.session.generate_reply(tool_choice="none")

    async def _transfer_to_agent(self, name: str, context: RunContext_T) -> tuple[Agent, str]:
        """Prepare a hand-off to the agent registered under *name*.

        Looks the target up first (raising KeyError for an unknown name before
        any state changes), records the currently active agent as
        ``prev_agent`` so the next agent's on_enter can import its history, and
        returns ``(next_agent, "Transferring to <name>.")``.
        """
        session_data = context.userdata
        leaving = context.session.current_agent
        target = session_data.agents[name]
        session_data.prev_agent = leaving
        return target, f"Transferring to {name}."
class Greeter(BaseAgent):
    """Entry agent: greets the caller and routes them to reservation or takeaway."""

    def __init__(self, menu: str) -> None:
        prompt = (
            f"You are a friendly restaurant receptionist. The menu is: {menu}\n"
            "Your jobs are to greet the caller and understand if they want to "
            "make a reservation or order takeaway. Guide them to the right agent using tools."
        )
        super().__init__(
            instructions=prompt,
            # Parallel tool calls disabled, presumably so at most one routing
            # tool fires per turn.
            llm=openai.LLM(parallel_tool_calls=False),
            tts=cartesia.TTS(voice=voices["greeter"]),
        )
        self.menu = menu

    # The docstrings of the tools below are their LLM-facing descriptions.
    @function_tool()
    async def to_reservation(self, context: RunContext_T) -> tuple[Agent, str]:
        """Called when user wants to make or update a reservation.
        This function handles transitioning to the reservation agent
        who will collect the necessary details like reservation time,
        customer name and phone number."""
        return await self._transfer_to_agent("reservation", context)

    @function_tool()
    async def to_takeaway(self, context: RunContext_T) -> tuple[Agent, str]:
        """Called when the user wants to place a takeaway order.
        This includes handling orders for pickup, delivery, or when the user wants to
        proceed to checkout with their existing order."""
        return await self._transfer_to_agent("takeaway", context)
class Reservation(BaseAgent):
    """Collects the reservation time, name and phone number, then returns to the greeter."""

    def __init__(self) -> None:
        instructions = (
            "You are a reservation agent at a restaurant. Your jobs are to ask for "
            "the reservation time, then customer's name, and phone number. Then "
            "confirm the reservation details with the customer."
        )
        super().__init__(
            instructions=instructions,
            # Shared module-level tools for contact details and going back to the greeter.
            tools=[update_name, update_phone, to_greeter],
            tts=cartesia.TTS(voice=voices["reservation"]),
        )

    # The docstrings of the tools below are their LLM-facing descriptions.
    @function_tool()
    async def update_reservation_time(
        self,
        time: Annotated[str, Field(description="The reservation time")],
        context: RunContext_T,
    ) -> str:
        """Called when the user provides their reservation time.
        Confirm the time with the user before calling the function."""
        context.userdata.reservation_time = time
        return f"The reservation time is updated to {time}"

    @function_tool()
    async def confirm_reservation(self, context: RunContext_T) -> str | tuple[Agent, str]:
        """Called when the user confirms the reservation."""
        data = context.userdata
        # Both contact fields must be non-empty before anything else is checked.
        has_contact = bool(data.customer_name) and bool(data.customer_phone)
        if not has_contact:
            return "Please provide your name and phone number first."
        if not data.reservation_time:
            return "Please provide reservation time first."
        # Everything collected: hand the caller back to the greeter.
        return await self._transfer_to_agent("greeter", context)
class Takeaway(BaseAgent):
    """Takes the caller's takeaway order and hands off to checkout once one exists."""

    def __init__(self, menu: str) -> None:
        super().__init__(
            # System prompt for this agent; the menu is interpolated into it.
            instructions=(
                "You are a takeaway agent that takes orders from the customer. "
                f"Our menu is: {menu}\n"
                "Clarify special requests and confirm the order with the customer."
            ),
            tools=[to_greeter],
            tts=cartesia.TTS(voice=voices["takeaway"]),
        )

    # The docstrings of the tools below are their LLM-facing descriptions.
    @function_tool()
    async def update_order(
        self,
        items: Annotated[list[str], Field(description="The items of the full order")],
        context: RunContext_T,
    ) -> str:
        """Called when the user creates or updates their order."""
        # Replaces (does not append to) the stored order: the tool expects the full order.
        userdata = context.userdata
        userdata.order = items
        return f"The order is updated to {items}"

    @function_tool()
    async def to_checkout(self, context: RunContext_T) -> str | tuple[Agent, str]:
        """Called when the user confirms the order."""
        userdata = context.userdata
        # Refuse to hand off while the order is missing or empty.
        if not userdata.order:
            return "No takeaway order found. Please make an order first."
        return await self._transfer_to_agent("checkout", context)
class Checkout(BaseAgent):
    """Confirms the order total and payment details, then returns to the greeter."""

    def __init__(self, menu: str) -> None:
        super().__init__(
            # System prompt for this agent; the menu is interpolated into it.
            instructions=(
                f"You are a checkout agent at a restaurant. The menu is: {menu}\n"
                "You are responsible for confirming the expense of the "
                "order and then collecting customer's name, phone number and credit card "
                "information, including the card number, expiry date, and CVV step by step."
            ),
            tools=[update_name, update_phone, to_greeter],
            tts=cartesia.TTS(voice=voices["checkout"]),
        )

    # The docstrings of the tools below are their LLM-facing descriptions.
    @function_tool()
    async def confirm_expense(
        self,
        expense: Annotated[float, Field(description="The expense of the order")],
        context: RunContext_T,
    ) -> str:
        """Called when the user confirms the expense."""
        userdata = context.userdata
        userdata.expense = expense
        return f"The expense is confirmed to be {expense}"

    @function_tool()
    async def update_credit_card(
        self,
        number: Annotated[str, Field(description="The credit card number")],
        expiry: Annotated[str, Field(description="The expiry date of the credit card")],
        cvv: Annotated[str, Field(description="The CVV of the credit card")],
        context: RunContext_T,
    ) -> str:
        """Called when the user provides their credit card number, expiry date, and CVV.
        Confirm the spelling with the user before calling the function."""
        userdata = context.userdata
        userdata.customer_credit_card = number
        userdata.customer_credit_card_expiry = expiry
        userdata.customer_credit_card_cvv = cvv
        return f"The credit card number is updated to {number}"

    @function_tool()
    async def confirm_checkout(self, context: RunContext_T) -> str | tuple[Agent, str]:
        """Called when the user confirms the checkout."""
        userdata = context.userdata
        # A falsy expense (None or 0) is treated as not yet confirmed.
        if not userdata.expense:
            return "Please confirm the expense first."
        if (
            not userdata.customer_credit_card
            or not userdata.customer_credit_card_expiry
            or not userdata.customer_credit_card_cvv
        ):
            return "Please provide the credit card information first."
        userdata.checked_out = True
        # Hand back to the greeter through this agent's own helper (as
        # Reservation.confirm_reservation does), rather than invoking the
        # decorated module-level `to_greeter` tool as a plain coroutine.
        return await self._transfer_to_agent("greeter", context)

    @function_tool()
    async def to_takeaway(self, context: RunContext_T) -> tuple[Agent, str]:
        """Called when the user wants to update their order."""
        return await self._transfer_to_agent("takeaway", context)
server = AgentServer()


@server.rtc_session()
async def entrypoint(ctx: JobContext):
    """Build the shared UserData and agent registry, then start the voice
    session in the job's room with the greeter as the first agent."""
    menu = "Pizza: $10, Salad: $5, Ice Cream: $3, Coffee: $2"

    # All agents are created up front; transfers look them up by these names.
    roster = {
        "greeter": Greeter(menu),
        "reservation": Reservation(),
        "takeaway": Takeaway(menu),
        "checkout": Checkout(menu),
    }
    userdata = UserData(agents=roster)

    session = AgentSession[UserData](
        userdata=userdata,
        # Session-wide defaults; individual agents may pass their own llm/tts.
        stt=deepgram.STT(),
        llm=openai.LLM(),
        tts=cartesia.TTS(),
        vad=silero.VAD.load(),
        # Presumably caps consecutive tool-call steps per turn — confirm in LiveKit docs.
        max_tool_steps=5,
        # To use a realtime model instead, replace stt, llm, tts and vad with:
        # llm=openai.realtime.RealtimeModel(voice="alloy"),
    )

    # The greeter's BaseAgent.on_enter produces the opening reply.
    await session.start(agent=userdata.agents["greeter"], room=ctx.room)
# When executed as a script, hand the AgentServer (with its registered
# rtc_session entrypoint) to the LiveKit CLI runner.
if __name__ == "__main__":
    cli.run_app(server)