|
45 | 45 | },
|
46 | 46 | }]
|
47 | 47 |
|
# OpenAI-style tool schema for a product-lookup function. Both parameters are
# required: an integer product ID and a boolean flag, so the tests can verify
# that the server emits native JSON int/bool argument values.
PRODUCT_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_product_info",
            "description": ("Get detailed information of a product based on its "
                            "product ID."),
            "parameters": {
                "type": "object",
                "properties": {
                    "inserted": {
                        "type": "boolean",
                        "description": "inserted.",
                    },
                    "product_id": {
                        "type": "integer",
                        "description": "The product ID of the product.",
                    },
                },
                "required": ["product_id", "inserted"],
            },
        },
    },
]
| 70 | + |
48 | 71 | MESSAGES = [{"role": "user", "content": "What's the weather like in Boston?"}]
|
49 | 72 |
|
# Single-turn conversation expected to trigger a get_product_info tool call
# carrying product_id=7355608 (int) and inserted=true (bool).
PRODUCT_MESSAGES = [
    {
        "role": "user",
        "content": ("Hi! Do you have any detailed information about the "
                    "product id 7355608 and inserted true?"),
    }
]
| 80 | + |
50 | 81 |
|
51 | 82 | @pytest.mark.asyncio
|
52 | 83 | async def test_non_streaming_tool_call():
|
@@ -127,3 +158,103 @@ async def test_streaming_tool_call():
|
127 | 158 | print("\n[Streaming Test Passed]")
|
128 | 159 | print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
129 | 160 | print(f"Reconstructed Arguments: {arguments}")
|
| 161 | + |
| 162 | + |
@pytest.mark.asyncio
async def test_non_streaming_product_tool_call():
    """Test tool call integer and boolean parameters in non-streaming mode."""
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()

        completion = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=PRODUCT_MESSAGES,
            tools=PRODUCT_TOOLS,
            tool_choice="auto",
            temperature=0.66,
        )

        # The model must answer with a tool call rather than plain text.
        assert completion.choices
        first_choice = completion.choices[0]
        assert first_choice.finish_reason == "tool_calls"
        assert first_choice.message.tool_calls is not None

        call = first_choice.message.tool_calls[0]
        assert call.type == "function"
        assert call.function.name == "get_product_info"

        parsed_args = json.loads(call.function.arguments)
        assert "product_id" in parsed_args
        assert "inserted" in parsed_args

        # Decoded JSON must yield native int/bool values, not strings.
        assert isinstance(parsed_args.get("product_id"), int)
        assert parsed_args.get("product_id") == 7355608
        assert isinstance(parsed_args.get("inserted"), bool)
        assert parsed_args.get("inserted") is True

        print("\n[Non-Streaming Product Test Passed]")
        print(f"Tool Call: {call.function.name}")
        print(f"Arguments: {parsed_args}")
| 203 | + |
| 204 | + |
@pytest.mark.asyncio
async def test_streaming_product_tool_call():
    """Test tool call integer and boolean parameters in streaming mode."""
    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
        client = server.get_async_client()

        stream = await client.chat.completions.create(
            model=LORA_MODEL,
            messages=PRODUCT_MESSAGES,
            tools=PRODUCT_TOOLS,
            tool_choice="auto",
            temperature=0.66,
            stream=True,
        )

        # Reassemble the tool call from streamed deltas, keyed by call index;
        # name and argument text arrive as incremental fragments.
        accumulated = {}
        async for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if not delta or not delta.tool_calls:
                continue
            for fragment in delta.tool_calls:
                entry = accumulated.setdefault(fragment.index, {
                    "name": "",
                    "arguments": ""
                })
                if fragment.function.name:
                    entry["name"] += fragment.function.name
                if fragment.function.arguments:
                    entry["arguments"] += fragment.function.arguments

        # Exactly one tool call should have been streamed.
        assert len(accumulated) == 1
        reconstructed = accumulated[0]
        assert reconstructed["name"] == "get_product_info"

        parsed_args = json.loads(reconstructed["arguments"])
        assert "product_id" in parsed_args
        assert "inserted" in parsed_args

        # The reconstructed JSON must decode to native int/bool values.
        assert isinstance(parsed_args.get("product_id"), int)
        assert parsed_args.get("product_id") == 7355608
        assert isinstance(parsed_args.get("inserted"), bool)
        assert parsed_args.get("inserted") is True

        print("\n[Streaming Product Test Passed]")
        print(f"Reconstructed Tool Call: {reconstructed['name']}")
        print(f"Reconstructed Arguments: {parsed_args}")
0 commit comments