1
+ import asyncio
1
2
import unittest .mock
2
3
from unittest .mock import ANY , MagicMock , call
3
4
11
12
from tests .fixtures .mocked_model_provider import MockedModelProvider
12
13
13
14
15
+ @strands .tool
16
+ def normal_tool (agent : Agent ):
17
+ return f"Done with synchronous { agent .name } !"
18
+
19
+
20
+ @strands .tool
21
+ async def async_tool (agent : Agent ):
22
+ await asyncio .sleep (0.1 )
23
+ return f"Done with asynchronous { agent .name } !"
24
+
25
+
26
+ @strands .tool
27
+ async def streaming_tool ():
28
+ await asyncio .sleep (0.2 )
29
+ yield {"tool_streaming" : True }
30
+ yield "Final result"
31
+
32
+
14
33
@pytest .fixture
15
34
def mock_time ():
16
35
with unittest .mock .patch .object (strands .event_loop .event_loop , "time" ) as mock :
17
36
yield mock
18
37
19
38
20
- @pytest .mark .asyncio
21
- async def test_stream_async_e2e (alist , mock_time ):
22
- @strands .tool
23
- def fake_tool (agent : Agent ):
24
- return "Done!"
39
+ any_props = {
40
+ "agent" : ANY ,
41
+ "event_loop_cycle_id" : ANY ,
42
+ "event_loop_cycle_span" : ANY ,
43
+ "event_loop_cycle_trace" : ANY ,
44
+ "request_state" : {},
45
+ }
46
+
25
47
48
+ @pytest .mark .asyncio
49
+ async def test_stream_e2e_success (alist ):
26
50
mock_provider = MockedModelProvider (
27
51
[
28
- {"redactedUserContent" : "BLOCKED!" , "redactedAssistantContent" : "INPUT BLOCKED!" },
29
- {"role" : "assistant" , "content" : [{"text" : "Okay invoking tool!" }]},
30
52
{
31
53
"role" : "assistant" ,
32
- "content" : [{"toolUse" : {"name" : "fake_tool" , "toolUseId" : "123" , "input" : {}}}],
54
+ "content" : [
55
+ {"text" : "Okay invoking normal tool" },
56
+ {"toolUse" : {"name" : "normal_tool" , "toolUseId" : "123" , "input" : {}}},
57
+ ],
58
+ },
59
+ {
60
+ "role" : "assistant" ,
61
+ "content" : [
62
+ {"text" : "Invoking async tool" },
63
+ {"toolUse" : {"name" : "async_tool" , "toolUseId" : "1234" , "input" : {}}},
64
+ ],
65
+ },
66
+ {
67
+ "role" : "assistant" ,
68
+ "content" : [
69
+ {"text" : "Invoking streaming tool" },
70
+ {"toolUse" : {"name" : "streaming_tool" , "toolUseId" : "12345" , "input" : {}}},
71
+ ],
72
+ },
73
+ {
74
+ "role" : "assistant" ,
75
+ "content" : [
76
+ {"text" : "I invoked the tools!" },
77
+ ],
33
78
},
34
- {"role" : "assistant" , "content" : [{"text" : "I invoked a tool!" }]},
35
79
]
36
80
)
81
+
82
+ mock_callback = unittest .mock .Mock ()
83
+ agent = Agent (model = mock_provider , tools = [async_tool , normal_tool , streaming_tool ], callback_handler = mock_callback )
84
+
85
+ stream = agent .stream_async ("Do the stuff" , arg1 = 1013 )
86
+
87
+ tool_config = {
88
+ "toolChoice" : {"auto" : {}},
89
+ "tools" : [
90
+ {
91
+ "toolSpec" : {
92
+ "description" : "async_tool" ,
93
+ "inputSchema" : {"json" : {"properties" : {}, "required" : [], "type" : "object" }},
94
+ "name" : "async_tool" ,
95
+ }
96
+ },
97
+ {
98
+ "toolSpec" : {
99
+ "description" : "normal_tool" ,
100
+ "inputSchema" : {"json" : {"properties" : {}, "required" : [], "type" : "object" }},
101
+ "name" : "normal_tool" ,
102
+ }
103
+ },
104
+ {
105
+ "toolSpec" : {
106
+ "description" : "streaming_tool" ,
107
+ "inputSchema" : {"json" : {"properties" : {}, "required" : [], "type" : "object" }},
108
+ "name" : "streaming_tool" ,
109
+ }
110
+ },
111
+ ],
112
+ }
113
+
114
+ tru_events = await alist (stream )
115
+ exp_events = [
116
+ # Cycle 1: Initialize and invoke normal_tool
117
+ {"arg1" : 1013 , "init_event_loop" : True },
118
+ {"start" : True },
119
+ {"start_event_loop" : True },
120
+ {"event" : {"messageStart" : {"role" : "assistant" }}},
121
+ {"event" : {"contentBlockStart" : {"start" : {}}}},
122
+ {"event" : {"contentBlockDelta" : {"delta" : {"text" : "Okay invoking normal tool" }}}},
123
+ {
124
+ ** any_props ,
125
+ "arg1" : 1013 ,
126
+ "data" : "Okay invoking normal tool" ,
127
+ "delta" : {"text" : "Okay invoking normal tool" },
128
+ },
129
+ {"event" : {"contentBlockStop" : {}}},
130
+ {"event" : {"contentBlockStart" : {"start" : {"toolUse" : {"name" : "normal_tool" , "toolUseId" : "123" }}}}},
131
+ {"event" : {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : "{}" }}}}},
132
+ {
133
+ ** any_props ,
134
+ "arg1" : 1013 ,
135
+ "current_tool_use" : {"input" : {}, "name" : "normal_tool" , "toolUseId" : "123" },
136
+ "delta" : {"toolUse" : {"input" : "{}" }},
137
+ },
138
+ {"event" : {"contentBlockStop" : {}}},
139
+ {"event" : {"messageStop" : {"stopReason" : "tool_use" }}},
140
+ {
141
+ "message" : {
142
+ "content" : [
143
+ {"text" : "Okay invoking normal tool" },
144
+ {"toolUse" : {"input" : {}, "name" : "normal_tool" , "toolUseId" : "123" }},
145
+ ],
146
+ "role" : "assistant" ,
147
+ }
148
+ },
149
+ {
150
+ "message" : {
151
+ "content" : [
152
+ {
153
+ "toolResult" : {
154
+ "content" : [{"text" : "Done with synchronous Strands Agents!" }],
155
+ "status" : "success" ,
156
+ "toolUseId" : "123" ,
157
+ }
158
+ },
159
+ ],
160
+ "role" : "user" ,
161
+ }
162
+ },
163
+ # Cycle 2: Invoke async_tool
164
+ {"start" : True },
165
+ {"start" : True },
166
+ {"start_event_loop" : True },
167
+ {"event" : {"messageStart" : {"role" : "assistant" }}},
168
+ {"event" : {"contentBlockStart" : {"start" : {}}}},
169
+ {"event" : {"contentBlockDelta" : {"delta" : {"text" : "Invoking async tool" }}}},
170
+ {
171
+ ** any_props ,
172
+ "arg1" : 1013 ,
173
+ "data" : "Invoking async tool" ,
174
+ "delta" : {"text" : "Invoking async tool" },
175
+ "event_loop_parent_cycle_id" : ANY ,
176
+ "messages" : ANY ,
177
+ "model" : ANY ,
178
+ "system_prompt" : None ,
179
+ "tool_config" : tool_config ,
180
+ },
181
+ {"event" : {"contentBlockStop" : {}}},
182
+ {"event" : {"contentBlockStart" : {"start" : {"toolUse" : {"name" : "async_tool" , "toolUseId" : "1234" }}}}},
183
+ {"event" : {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : "{}" }}}}},
184
+ {
185
+ ** any_props ,
186
+ "arg1" : 1013 ,
187
+ "current_tool_use" : {"input" : {}, "name" : "async_tool" , "toolUseId" : "1234" },
188
+ "delta" : {"toolUse" : {"input" : "{}" }},
189
+ "event_loop_parent_cycle_id" : ANY ,
190
+ "messages" : ANY ,
191
+ "model" : ANY ,
192
+ "system_prompt" : None ,
193
+ "tool_config" : tool_config ,
194
+ },
195
+ {"event" : {"contentBlockStop" : {}}},
196
+ {"event" : {"messageStop" : {"stopReason" : "tool_use" }}},
197
+ {
198
+ "message" : {
199
+ "content" : [
200
+ {"text" : "Invoking async tool" },
201
+ {"toolUse" : {"input" : {}, "name" : "async_tool" , "toolUseId" : "1234" }},
202
+ ],
203
+ "role" : "assistant" ,
204
+ }
205
+ },
206
+ {
207
+ "message" : {
208
+ "content" : [
209
+ {
210
+ "toolResult" : {
211
+ "content" : [{"text" : "Done with asynchronous Strands Agents!" }],
212
+ "status" : "success" ,
213
+ "toolUseId" : "1234" ,
214
+ }
215
+ },
216
+ ],
217
+ "role" : "user" ,
218
+ }
219
+ },
220
+ # Cycle 3: Invoke streaming_tool
221
+ {"start" : True },
222
+ {"start" : True },
223
+ {"start_event_loop" : True },
224
+ {"event" : {"messageStart" : {"role" : "assistant" }}},
225
+ {"event" : {"contentBlockStart" : {"start" : {}}}},
226
+ {"event" : {"contentBlockDelta" : {"delta" : {"text" : "Invoking streaming tool" }}}},
227
+ {
228
+ ** any_props ,
229
+ "arg1" : 1013 ,
230
+ "data" : "Invoking streaming tool" ,
231
+ "delta" : {"text" : "Invoking streaming tool" },
232
+ "event_loop_parent_cycle_id" : ANY ,
233
+ "messages" : ANY ,
234
+ "model" : ANY ,
235
+ "system_prompt" : None ,
236
+ "tool_config" : tool_config ,
237
+ },
238
+ {"event" : {"contentBlockStop" : {}}},
239
+ {"event" : {"contentBlockStart" : {"start" : {"toolUse" : {"name" : "streaming_tool" , "toolUseId" : "12345" }}}}},
240
+ {"event" : {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : "{}" }}}}},
241
+ {
242
+ ** any_props ,
243
+ "arg1" : 1013 ,
244
+ "current_tool_use" : {"input" : {}, "name" : "streaming_tool" , "toolUseId" : "12345" },
245
+ "delta" : {"toolUse" : {"input" : "{}" }},
246
+ "event_loop_parent_cycle_id" : ANY ,
247
+ "messages" : ANY ,
248
+ "model" : ANY ,
249
+ "system_prompt" : None ,
250
+ "tool_config" : tool_config ,
251
+ },
252
+ {"event" : {"contentBlockStop" : {}}},
253
+ {"event" : {"messageStop" : {"stopReason" : "tool_use" }}},
254
+ {
255
+ "message" : {
256
+ "content" : [
257
+ {"text" : "Invoking streaming tool" },
258
+ {"toolUse" : {"input" : {}, "name" : "streaming_tool" , "toolUseId" : "12345" }},
259
+ ],
260
+ "role" : "assistant" ,
261
+ }
262
+ },
263
+ {
264
+ "message" : {
265
+ "content" : [
266
+ {
267
+ "toolResult" : {
268
+ # TODO update this text when we get tool streaming implemented; right now this
269
+ # TODO is of the form '<async_generator object streaming_tool at 0x107d18a00>'
270
+ "content" : [{"text" : ANY }],
271
+ "status" : "success" ,
272
+ "toolUseId" : "12345" ,
273
+ }
274
+ },
275
+ ],
276
+ "role" : "user" ,
277
+ }
278
+ },
279
+ # Cycle 4: Final response
280
+ {"start" : True },
281
+ {"start" : True },
282
+ {"start_event_loop" : True },
283
+ {"event" : {"messageStart" : {"role" : "assistant" }}},
284
+ {"event" : {"contentBlockStart" : {"start" : {}}}},
285
+ {"event" : {"contentBlockDelta" : {"delta" : {"text" : "I invoked the tools!" }}}},
286
+ {
287
+ ** any_props ,
288
+ "arg1" : 1013 ,
289
+ "data" : "I invoked the tools!" ,
290
+ "delta" : {"text" : "I invoked the tools!" },
291
+ "event_loop_parent_cycle_id" : ANY ,
292
+ "messages" : ANY ,
293
+ "model" : ANY ,
294
+ "system_prompt" : None ,
295
+ "tool_config" : tool_config ,
296
+ },
297
+ {"event" : {"contentBlockStop" : {}}},
298
+ {"event" : {"messageStop" : {"stopReason" : "end_turn" }}},
299
+ {"message" : {"content" : [{"text" : "I invoked the tools!" }], "role" : "assistant" }},
300
+ {
301
+ "result" : AgentResult (
302
+ stop_reason = "end_turn" ,
303
+ message = {"content" : [{"text" : "I invoked the tools!" }], "role" : "assistant" },
304
+ metrics = ANY ,
305
+ state = {},
306
+ )
307
+ },
308
+ ]
309
+ assert tru_events == exp_events
310
+
311
+ exp_calls = [call (** event ) for event in exp_events ]
312
+ act_calls = mock_callback .call_args_list
313
+ assert act_calls == exp_calls
314
+
315
+ # Ensure that all events coming out of the agent are *not* typed events
316
+ typed_events = [event for event in tru_events if isinstance (event , TypedEvent )]
317
+ assert typed_events == []
318
+
319
+
320
+ @pytest .mark .asyncio
321
+ async def test_stream_e2e_throttle_and_redact (alist , mock_time ):
37
322
model = MagicMock ()
38
323
model .stream .side_effect = [
39
324
ModelThrottledException ("ThrottlingException | ConverseStream" ),
40
325
ModelThrottledException ("ThrottlingException | ConverseStream" ),
41
- mock_provider .stream ([]),
326
+ MockedModelProvider (
327
+ [
328
+ {"redactedUserContent" : "BLOCKED!" , "redactedAssistantContent" : "INPUT BLOCKED!" },
329
+ ]
330
+ ).stream ([]),
42
331
]
43
332
44
333
mock_callback = unittest .mock .Mock ()
45
- agent = Agent (model = model , tools = [fake_tool ], callback_handler = mock_callback )
334
+ agent = Agent (model = model , tools = [normal_tool ], callback_handler = mock_callback )
46
335
47
336
stream = agent .stream_async ("Do the stuff" , arg1 = 1013 )
48
337
49
338
# Base object with common properties
50
339
throttle_props = {
51
- "agent" : ANY ,
52
- "event_loop_cycle_id" : ANY ,
53
- "event_loop_cycle_span" : ANY ,
54
- "event_loop_cycle_trace" : ANY ,
340
+ ** any_props ,
55
341
"arg1" : 1013 ,
56
- "request_state" : {},
57
342
}
58
343
59
344
tru_events = await alist (stream )
@@ -68,14 +353,10 @@ def fake_tool(agent: Agent):
68
353
{"event" : {"contentBlockStart" : {"start" : {}}}},
69
354
{"event" : {"contentBlockDelta" : {"delta" : {"text" : "INPUT BLOCKED!" }}}},
70
355
{
71
- "agent" : ANY ,
356
+ ** any_props ,
72
357
"arg1" : 1013 ,
73
358
"data" : "INPUT BLOCKED!" ,
74
359
"delta" : {"text" : "INPUT BLOCKED!" },
75
- "event_loop_cycle_id" : ANY ,
76
- "event_loop_cycle_span" : ANY ,
77
- "event_loop_cycle_trace" : ANY ,
78
- "request_state" : {},
79
360
},
80
361
{"event" : {"contentBlockStop" : {}}},
81
362
{"event" : {"messageStop" : {"stopReason" : "guardrail_intervened" }}},
@@ -128,12 +409,8 @@ async def test_event_loop_cycle_text_response_throttling_early_end(
128
409
129
410
# Base object with common properties
130
411
common_props = {
131
- "agent" : ANY ,
132
- "event_loop_cycle_id" : ANY ,
133
- "event_loop_cycle_span" : ANY ,
134
- "event_loop_cycle_trace" : ANY ,
412
+ ** any_props ,
135
413
"arg1" : 1013 ,
136
- "request_state" : {},
137
414
}
138
415
139
416
exp_events = [
0 commit comments