2222
2323import org .junit .jupiter .api .Test ;
2424
25+ import org .springframework .ai .chat .model .ToolContext ;
2526import org .springframework .ai .tool .definition .DefaultToolDefinition ;
2627import org .springframework .ai .tool .definition .ToolDefinition ;
2728
2829import static org .assertj .core .api .Assertions .assertThat ;
30+ import static org .assertj .core .api .Assertions .assertThatThrownBy ;
2931
3032/**
3133 * Tests for {@link MethodToolCallback} with generic types.
@@ -137,6 +139,76 @@ void testNestedGenericType() throws Exception {
137139 assertThat (result ).isEqualTo ("2 maps processed: [{a=1, b=2}, {c=3, d=4}]" );
138140 }
139141
142+ @ Test
143+ void testToolContextType () throws Exception {
144+ // Create a test object with a method that takes a List<Map<String, Integer>>
145+ TestGenericClass testObject = new TestGenericClass ();
146+ Method method = TestGenericClass .class .getMethod ("processStringListInToolContext" , ToolContext .class );
147+
148+ // Create a tool definition
149+ ToolDefinition toolDefinition = DefaultToolDefinition .builder ()
150+ .name ("processToolContext" )
151+ .description ("Process tool context" )
152+ .inputSchema ("{}" )
153+ .build ();
154+
155+ // Create a MethodToolCallback
156+ MethodToolCallback callback = MethodToolCallback .builder ()
157+ .toolDefinition (toolDefinition )
158+ .toolMethod (method )
159+ .toolObject (testObject )
160+ .build ();
161+
162+ // Create an empty JSON input
163+ String toolInput = """
164+ {}
165+ """ ;
166+
167+ // Create a toolContext
168+ ToolContext toolContext = new ToolContext (Map .of ("foo" , "bar" ));
169+
170+ // Call the tool
171+ String result = callback .call (toolInput , toolContext );
172+
173+ // Verify the result
174+ assertThat (result ).isEqualTo ("1 entries processed {foo=bar}" );
175+ }
176+
177+ @ Test
178+ void testToolContextTypeWithNonToolContextArgs () throws Exception {
179+ // Create a test object with a method that takes a List<String>
180+ TestGenericClass testObject = new TestGenericClass ();
181+ Method method = TestGenericClass .class .getMethod ("processStringList" , List .class );
182+
183+ // Create a tool definition
184+ ToolDefinition toolDefinition = DefaultToolDefinition .builder ()
185+ .name ("processStringList" )
186+ .description ("Process a list of strings" )
187+ .inputSchema ("{}" )
188+ .build ();
189+
190+ // Create a MethodToolCallback
191+ MethodToolCallback callback = MethodToolCallback .builder ()
192+ .toolDefinition (toolDefinition )
193+ .toolMethod (method )
194+ .toolObject (testObject )
195+ .build ();
196+
197+ // Create a JSON input with a list of strings
198+ String toolInput = """
199+ {
200+ "strings": ["one", "two", "three"]
201+ }
202+ """ ;
203+
204+ // Create a toolContext
205+ ToolContext toolContext = new ToolContext (Map .of ("foo" , "bar" ));
206+
207+ // Call the tool and verify
208+ assertThatThrownBy (() -> callback .call (toolInput , toolContext )).isInstanceOf (IllegalArgumentException .class )
209+ .hasMessageContaining ("ToolContext is required by the method as an argument" );
210+ }
211+
140212 /**
141213 * Test class with methods that use generic types.
142214 */
@@ -154,6 +226,11 @@ public String processListOfMaps(List<Map<String, Integer>> listOfMaps) {
154226 return listOfMaps .size () + " maps processed: " + listOfMaps ;
155227 }
156228
229+ public String processStringListInToolContext (ToolContext toolContext ) {
230+ Map <String , Object > context = toolContext .getContext ();
231+ return context .size () + " entries processed " + context ;
232+ }
233+
157234 }
158235
159236}
0 commit comments