Skip to content

Commit 2996387

Browse files
committed
Add streaming and thoughtSignature support to Google LLM
1 parent 5693b79 commit 2996387

File tree

9 files changed

+288
-16
lines changed

9 files changed

+288
-16
lines changed

src/agent/src/agent.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,7 @@ function agent:step(prompt_builder: any, runtime_options: any): (table?, string?
587587
arguments = tool_call.arguments,
588588
registry_id = tool_call.registry_id,
589589
context = tool_call.context,
590+
provider_metadata = tool_call.provider_metadata,
590591
agent_id = tool_info.agent_id
591592
}
592593

src/agent/src/tools/caller.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ function tool_caller:validate(tool_calls: {ToolCall}?): (any, string?)
126126
registry_id = registry_id,
127127
meta = meta,
128128
context = tool_call.context, -- Preserve tool context
129+
provider_metadata = tool_call.provider_metadata,
129130
valid = true
130131
}
131132

src/llm/src/google/_index.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ entries:
113113
modules:
114114
- json
115115
- http_client
116+
imports:
117+
output: wippy.llm:output
116118

117119
# wippy.llm.google:client_test
118120
- name: client_test

src/llm/src/google/client.lua

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,27 @@
11
local json = require("json")
22
local http_client = require("http_client")
3+
local output = require("output")
4+
5+
type StreamCallbacks = {
6+
on_content: ((text: string) -> ())?,
7+
on_tool_call: ((part: any) -> ())?,
8+
on_thinking: ((text: string) -> ())?,
9+
on_error: ((error_info: any) -> ())?,
10+
on_done: ((result: StreamResult) -> ())?,
11+
}
12+
13+
type StreamInput = {
14+
stream: any,
15+
metadata: table?,
16+
}
17+
18+
type StreamResult = {
19+
content: string,
20+
tool_calls: {any},
21+
finish_reason: string?,
22+
usage: any?,
23+
metadata: table,
24+
}
325

426
local client = {
527
_http_client = http_client
@@ -41,9 +63,203 @@ local function parse_error_response(http_response)
4163
return error_info
4264
end
4365

66+
function client.process_stream(stream_response: StreamInput, callbacks: StreamCallbacks?): (string?, string?, StreamResult?)
67+
if not stream_response or not stream_response.stream then
68+
return nil, "Invalid stream response"
69+
end
70+
71+
callbacks = callbacks or {}
72+
local on_content = callbacks.on_content or function() end
73+
local on_tool_call = callbacks.on_tool_call or function() end
74+
local on_thinking = callbacks.on_thinking or function() end
75+
local on_error = callbacks.on_error or function() end
76+
local on_done = callbacks.on_done or function() end
77+
78+
local full_content = ""
79+
local tool_calls = {}
80+
local finish_reason = nil
81+
local usage = nil
82+
local metadata = stream_response.metadata or {}
83+
84+
while true do
85+
local chunk, err = stream_response.stream:read()
86+
87+
if err then
88+
on_error({ message = err })
89+
return nil, err
90+
end
91+
92+
if not chunk then
93+
break
94+
end
95+
96+
if chunk == "" then
97+
goto continue
98+
end
99+
100+
for data_line in chunk:gmatch('data:%s*(.-)%s*\n') do
101+
if data_line == "" then
102+
goto continue_line
103+
end
104+
105+
local parsed, parse_err = json.decode(data_line)
106+
if parse_err then
107+
goto continue_line
108+
end
109+
110+
if parsed.error then
111+
local error_info = {
112+
message = parsed.error.message,
113+
code = parsed.error.code,
114+
status = parsed.error.status
115+
}
116+
on_error(error_info)
117+
return nil, error_info.message, { error = error_info }
118+
end
119+
120+
if parsed.modelVersion then
121+
metadata.model_version = parsed.modelVersion
122+
end
123+
if parsed.responseId then
124+
metadata.response_id = parsed.responseId
125+
end
126+
127+
if parsed.candidates and parsed.candidates[1] then
128+
local candidate = parsed.candidates[1]
129+
130+
if candidate.content and candidate.content.parts then
131+
for _, part in ipairs(candidate.content.parts) do
132+
if part.functionCall then
133+
table.insert(tool_calls, part)
134+
on_tool_call(part)
135+
elseif part.text then
136+
if part.thought == true then
137+
on_thinking(part.text)
138+
else
139+
full_content = full_content .. part.text
140+
on_content(part.text)
141+
end
142+
end
143+
end
144+
end
145+
146+
if candidate.finishReason then
147+
finish_reason = candidate.finishReason
148+
end
149+
end
150+
151+
if parsed.usageMetadata then
152+
usage = parsed.usageMetadata
153+
end
154+
155+
::continue_line::
156+
end
157+
158+
::continue::
159+
end
160+
161+
local result: StreamResult = {
162+
content = full_content,
163+
tool_calls = tool_calls,
164+
finish_reason = finish_reason,
165+
usage = usage,
166+
metadata = metadata
167+
}
168+
169+
on_done(result)
170+
return full_content, nil, result
171+
end
172+
173+
--- Process a streaming response and send chunks via output.streamer.
174+
--- Returns an aggregated Google-like response compatible with map_success_response().
175+
local function handle_stream_response(response, http_options)
176+
local streamer = output.streamer(
177+
http_options.stream_reply_to,
178+
http_options.stream_topic,
179+
http_options.stream_buffer_size or 10
180+
)
181+
182+
local full_content = ""
183+
local tool_call_parts = {}
184+
local finish_reason = nil
185+
local usage_metadata = nil
186+
local response_metadata = {}
187+
188+
local _, stream_err = client.process_stream(
189+
{ stream = response.stream, metadata = {} },
190+
{
191+
on_content = function(chunk: string)
192+
full_content = full_content .. chunk
193+
streamer:buffer_content(chunk)
194+
end,
195+
196+
on_tool_call = function(tool_part: any)
197+
table.insert(tool_call_parts, tool_part)
198+
if tool_part.functionCall then
199+
streamer:send_tool_call(
200+
tool_part.functionCall.name,
201+
tool_part.functionCall.args or {},
202+
tool_part.functionCall.name
203+
)
204+
end
205+
end,
206+
207+
on_thinking = function(text: string)
208+
streamer:send_thinking(text)
209+
end,
210+
211+
on_error = function(error_info: any)
212+
streamer:send_error("server_error", error_info.message)
213+
end,
214+
215+
on_done = function(result: StreamResult)
216+
streamer:flush()
217+
finish_reason = result.finish_reason
218+
usage_metadata = result.usage
219+
response_metadata = result.metadata
220+
end
221+
}
222+
)
223+
224+
if stream_err then
225+
return nil, {
226+
status_code = 500,
227+
message = "Stream processing failed: " .. tostring(stream_err)
228+
}
229+
end
230+
231+
-- Reconstruct Google-like response
232+
local parts = {}
233+
if full_content ~= "" then
234+
table.insert(parts, { text = full_content })
235+
end
236+
for _, tc_part in ipairs(tool_call_parts) do
237+
table.insert(parts, tc_part)
238+
end
239+
240+
return {
241+
candidates = {
242+
{
243+
content = { parts = parts, role = "model" },
244+
finishReason = finish_reason
245+
}
246+
},
247+
usageMetadata = usage_metadata,
248+
modelVersion = response_metadata.model_version,
249+
responseId = response_metadata.response_id,
250+
metadata = response_metadata,
251+
status_code = response.status_code or 200
252+
}
253+
end
254+
44255
function client.request(method, url, http_options)
45256
http_options.headers["Accept"] = "application/json"
46257

258+
if http_options.stream then
259+
url = url .. "?alt=sse"
260+
http_options.headers["Accept"] = "text/event-stream"
261+
end
262+
47263
local response = nil
48264
local err = nil
49265
if method == "GET" then
@@ -61,10 +277,18 @@ function client.request(method, url, http_options)
61277
end
62278

63279
if response.status_code < 200 or response.status_code >= 300 then
280+
if http_options.stream and response.stream and not response.body then
281+
response.body = response.stream:read()
282+
end
64283
local parsed_error = parse_error_response(response)
65284
return nil, parsed_error
66285
end
67286

287+
-- Streaming: process stream, send chunks via streamer, return aggregated response
288+
if http_options.stream and response.stream then
289+
return handle_stream_response(response, http_options)
290+
end
291+
68292
local parsed, parse_err = json.decode(response.body)
69293
if parse_err then
70294
local parse_error = {

src/llm/src/google/generate.lua

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,22 @@ function generate.handler(contract_args)
8484
})
8585
end
8686

87+
local endpoint_path = "generateContent"
88+
local request_options = { timeout = contract_args.timeout }
89+
90+
if contract_args.stream and contract_args.stream.reply_to then
91+
endpoint_path = "streamGenerateContent"
92+
request_options.stream = true
93+
request_options.stream_reply_to = contract_args.stream.reply_to
94+
request_options.stream_topic = contract_args.stream.topic
95+
request_options.stream_buffer_size = contract_args.stream.buffer_size
96+
end
97+
8798
local response = client_instance:request({
88-
endpoint_path = "generateContent",
99+
endpoint_path = endpoint_path,
89100
model = contract_args.model,
90101
payload = payload,
91-
options = { timeout = contract_args.timeout }
102+
options = request_options
92103
})
93104

94105
if response.status_code < 200 or response.status_code >= 300 then

src/llm/src/google/generative_ai/client.lua

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ function generative_ai_client.request(contract_args)
2626
if contract_args.options.method == "POST" then
2727
options.body = json.encode(contract_args.payload or {})
2828
end
29+
if contract_args.options.stream then
30+
options.stream = true
31+
options.stream_reply_to = contract_args.options.stream_reply_to
32+
options.stream_topic = contract_args.options.stream_topic
33+
options.stream_buffer_size = contract_args.options.stream_buffer_size
34+
end
2935

3036
local base_url = contract_args.options.base_url or generative_ai_client._config.get_generative_ai_base_url()
3137
if contract_args.model and contract_args.model ~= "" then

src/llm/src/google/mapper.lua

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,16 @@ function mapper.map_messages(contract_messages, options)
185185
and json.decode(msg.function_call.arguments)
186186
or msg.function_call.arguments
187187

188-
table.insert(processed_messages, { role = "model", parts = {
188+
local part = {
189189
functionCall = {
190190
name = msg.function_call.name,
191191
args = next(arguments or {}) ~= nil and arguments or nil
192192
}
193-
} })
193+
}
194+
if msg.function_call.provider_metadata and msg.function_call.provider_metadata.thought_signature then
195+
part.thoughtSignature = msg.function_call.provider_metadata.thought_signature
196+
end
197+
table.insert(processed_messages, { role = "model", parts = part })
194198
i = i + 1
195199
else
196200
-- Skip unknown message types
@@ -255,18 +259,25 @@ function mapper.map_options(contract_options)
255259
}
256260
end
257261

258-
function mapper.map_tool_calls(function_calls)
259-
if not function_calls then
262+
function mapper.map_tool_calls(content_parts)
263+
if not content_parts or #content_parts == 0 then
260264
return {}
261265
end
262266

263267
local contract_tool_calls = {}
264-
for i, function_call in ipairs(function_calls) do
265-
contract_tool_calls[i] = {
266-
id = (function_call.name or "func") .. "_" .. time.now():unix(),
267-
name = function_call.name,
268-
arguments = function_call.args or {},
269-
}
268+
for i, content_part in ipairs(content_parts) do
269+
if content_part.functionCall then
270+
contract_tool_calls[i] = {
271+
id = (content_part.functionCall.name or "func") .. "_" .. time.now():unix(),
272+
name = content_part.functionCall.name,
273+
arguments = content_part.functionCall.args or {},
274+
}
275+
if content_part.thoughtSignature then
276+
contract_tool_calls[i].provider_metadata = {
277+
thought_signature = content_part.thoughtSignature
278+
}
279+
end
280+
end
270281
end
271282

272283
return contract_tool_calls
@@ -327,7 +338,7 @@ function mapper.map_success_response(google_response)
327338
if content_part.text then
328339
content = content .. content_part.text
329340
elseif content_part.functionCall then
330-
table.insert(tool_calls, content_part.functionCall)
341+
table.insert(tool_calls, content_part)
331342
end
332343
end
333344
end

src/llm/src/google/vertex/client.lua

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ local vertex_client = {
88
}
99

1010
local PROJECT_REQUIRED_ENDPOINTS = {
11-
"generateContent"
11+
"generateContent",
12+
"streamGenerateContent"
1213
}
1314

1415
local function build_url(base_url, contract_args)
@@ -60,6 +61,12 @@ function vertex_client.request(contract_args)
6061
if contract_args.options.method == "POST" then
6162
options.body = json.encode(contract_args.payload or {})
6263
end
64+
if contract_args.options.stream then
65+
options.stream = true
66+
options.stream_reply_to = contract_args.options.stream_reply_to
67+
options.stream_topic = contract_args.options.stream_topic
68+
options.stream_buffer_size = contract_args.options.stream_buffer_size
69+
end
6370

6471
local base_url = contract_args.options.base_url
6572
if not base_url then

0 commit comments

Comments
 (0)