Skip to content

Commit 8a739ab

Browse files
feat: add bedrock reasoning stream support (envoyproxy#1173)
**Description** Adds reasoning support for AWS Bedrock streaming. --------- Signed-off-by: Alexa Griffith <[email protected]> Signed-off-by: Alexa Griffith <[email protected]> Signed-off-by: Dan Sun <[email protected]> Co-authored-by: Dan Sun <[email protected]>
1 parent 23dd567 commit 8a739ab

File tree

6 files changed

+763
-46
lines changed

6 files changed

+763
-46
lines changed

internal/apischema/awsbedrock/awsbedrock.go

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -255,16 +255,7 @@ type ReasoningTextBlock struct {
255255
// The reasoning that the model used to return the output.
256256
Text string `json:"text"`
257257
// A token that verifies that the reasoning text was generated by the model.
258-
Signature *string `json:"signature,omitempty"`
259-
}
260-
261-
// RedactedContentBlock contains content that has been redacted.
262-
// This is based on the structure of ReasoningTextBlock as per AWS documentation patterns.
263-
type RedactedContentBlock struct {
264-
// The redacted text.
265-
Text string `json:"text"`
266-
// A token that verifies the redaction.
267-
Signature *string `json:"signature,omitempty"`
258+
Signature string `json:"signature,omitzero"`
268259
}
269260

270261
// ReasoningContentBlock contains the reasoning trace for the inference that the model ran.
@@ -273,7 +264,7 @@ type ReasoningContentBlock struct {
273264
// The reasoning that the model used to return the output.
274265
ReasoningText *ReasoningTextBlock `json:"reasoningText,omitempty"`
275266
// The content that has been redacted from the reasoning trace.
276-
RedactedContent *RedactedContentBlock `json:"redactedContent,omitempty"`
267+
RedactedContent []byte `json:"redactedContent,omitempty"`
277268
}
278269

279270
// ToolResultContentBlock The tool result content block.
@@ -413,8 +404,9 @@ type ConverseStreamEvent struct {
413404
// ConverseStreamEventContentBlockDelta is defined in the AWS Bedrock API:
414405
// https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ContentBlockDelta.html
415406
type ConverseStreamEventContentBlockDelta struct {
416-
Text *string `json:"text,omitempty"`
417-
ToolUse *ToolUseBlockDelta `json:"toolUse,omitempty"`
407+
Text *string `json:"text,omitempty"`
408+
ToolUse *ToolUseBlockDelta `json:"toolUse,omitempty"`
409+
ReasoningContent *ReasoningContentBlock `json:"reasoningContent,omitempty"`
418410
}
419411

420412
// ContentBlockStart is the start information.
@@ -512,7 +504,7 @@ type ToolInputSchema struct {
512504
// ToolSpecification The specification for the tool.
513505
type ToolSpecification struct {
514506
// The description for the tool.
515-
Description *string `json:"description,omitempty"`
507+
Description *string `json:"description,omitzero"`
516508

517509
// The schema for the tool in JSON format.
518510
//

internal/apischema/openai/openai.go

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,13 @@ func (s *StringOrAssistantRoleContentUnion) UnmarshalJSON(data []byte) error {
239239
return nil
240240
}
241241

242+
var singleContent ChatCompletionAssistantMessageParamContent
243+
err = json.Unmarshal(data, &singleContent)
244+
if err == nil {
245+
s.Value = singleContent
246+
return nil
247+
}
248+
242249
return errors.New("cannot unmarshal JSON data as string or assistant content parts")
243250
}
244251

@@ -448,8 +455,10 @@ type ChatCompletionAssistantMessageParamAudio struct {
448455
type ChatCompletionAssistantMessageParamContentType string
449456

450457
const (
451-
ChatCompletionAssistantMessageParamContentTypeText ChatCompletionAssistantMessageParamContentType = "text"
452-
ChatCompletionAssistantMessageParamContentTypeRefusal ChatCompletionAssistantMessageParamContentType = "refusal"
458+
ChatCompletionAssistantMessageParamContentTypeText ChatCompletionAssistantMessageParamContentType = "text"
459+
ChatCompletionAssistantMessageParamContentTypeRefusal ChatCompletionAssistantMessageParamContentType = "refusal"
460+
ChatCompletionAssistantMessageParamContentTypeThinking ChatCompletionAssistantMessageParamContentType = "thinking"
461+
ChatCompletionAssistantMessageParamContentTypeRedactedThinking ChatCompletionAssistantMessageParamContentType = "redacted_thinking"
453462
)
454463

455464
// ChatCompletionAssistantMessageParamContent Learn about
@@ -461,6 +470,10 @@ type ChatCompletionAssistantMessageParamContent struct {
461470
Refusal *string `json:"refusal,omitempty"`
462471
// The text content.
463472
Text *string `json:"text,omitempty"`
473+
474+
// The signature for a thinking block.
475+
Signature *string `json:"signature,omitempty"`
476+
RedactedContent []byte `json:"redactedContent,omitempty"`
464477
}
465478

466479
// ChatCompletionAssistantMessageParam Messages sent by the model in response to user messages.
@@ -1286,10 +1299,11 @@ type ChatCompletionResponseChunkChoice struct {
12861299
// ChatCompletionResponseChunkChoiceDelta is described in the OpenAI API documentation:
12871300
// https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-choices
12881301
type ChatCompletionResponseChunkChoiceDelta struct {
1289-
Content *string `json:"content,omitempty"`
1290-
Role string `json:"role,omitempty"`
1291-
ToolCalls []ChatCompletionMessageToolCallParam `json:"tool_calls,omitempty"`
1292-
Annotations *[]Annotation `json:"annotations,omitempty"`
1302+
Content *string `json:"content,omitempty"`
1303+
Role string `json:"role,omitempty"`
1304+
ToolCalls []ChatCompletionMessageToolCallParam `json:"tool_calls,omitempty"`
1305+
Annotations *[]Annotation `json:"annotations,omitempty"`
1306+
ReasoningContent *AWSBedRockStreamReasoningContent `json:"reasoning_content,omitempty"`
12931307
}
12941308

12951309
// Error is described in the OpenAI API documentation
@@ -1508,5 +1522,15 @@ type AWSBedRockResponseVendorFields struct {
15081522
// Note: This object is a Union. Only one member of this object can be specified or returned.
15091523
// Required: No
15101524
// See https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ReasoningContentBlock.html for more information.
1525+
ReasoningContent *AWSBedRockReasoningContent `json:"reasoning_content,omitzero"`
1526+
}
1527+
1528+
type AWSBedRockReasoningContent struct {
15111529
ReasoningContent *awsbedrock.ReasoningContentBlock `json:"reasoningContent,omitzero"`
15121530
}
1531+
1532+
type AWSBedRockStreamReasoningContent struct {
1533+
Text string `json:"text,omitzero"`
1534+
Signature string `json:"signature,omitzero"`
1535+
RedactedContent []byte `json:"redactedContent,omitzero"`
1536+
}

internal/apischema/openai/openai_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,3 +1725,64 @@ func TestChatCompletionResponseChunkChoiceDelta_annotations_round_trip(t *testin
17251725
require.NoError(t, err)
17261726
require.JSONEq(t, `{}`, string(marshaled))
17271727
}
1728+
1729+
func TestStringOrAssistantRoleContentUnionUnmarshal(t *testing.T) {
1730+
testCases := []struct {
1731+
name string
1732+
input string
1733+
expected StringOrAssistantRoleContentUnion
1734+
expErr string
1735+
}{
1736+
{
1737+
name: "string value",
1738+
input: `"hello"`,
1739+
expected: StringOrAssistantRoleContentUnion{
1740+
Value: "hello",
1741+
},
1742+
},
1743+
{
1744+
name: "array of content objects",
1745+
input: `[{"type": "text", "text": "hello from array"}]`,
1746+
expected: StringOrAssistantRoleContentUnion{
1747+
Value: []ChatCompletionAssistantMessageParamContent{
1748+
{
1749+
Type: ChatCompletionAssistantMessageParamContentTypeText,
1750+
Text: ptr.To("hello from array"),
1751+
},
1752+
},
1753+
},
1754+
},
1755+
{
1756+
name: "single content object",
1757+
input: `{"type": "text", "text": "hello from single object"}`,
1758+
expected: StringOrAssistantRoleContentUnion{
1759+
Value: ChatCompletionAssistantMessageParamContent{
1760+
Type: ChatCompletionAssistantMessageParamContentTypeText,
1761+
Text: ptr.To("hello from single object"),
1762+
},
1763+
},
1764+
},
1765+
{
1766+
name: "invalid json",
1767+
input: `12345`,
1768+
expErr: "cannot unmarshal JSON data as string or assistant content parts",
1769+
},
1770+
}
1771+
1772+
for _, tc := range testCases {
1773+
t.Run(tc.name, func(t *testing.T) {
1774+
var result StringOrAssistantRoleContentUnion
1775+
err := json.Unmarshal([]byte(tc.input), &result)
1776+
1777+
if tc.expErr != "" {
1778+
require.ErrorContains(t, err, tc.expErr)
1779+
return
1780+
}
1781+
1782+
require.NoError(t, err)
1783+
if !cmp.Equal(tc.expected, result) {
1784+
t.Errorf("Unmarshal diff(got, expected) = %s\n", cmp.Diff(result, tc.expected))
1785+
}
1786+
})
1787+
}
1788+
}

internal/extproc/translator/openai_awsbedrock.go

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,15 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIToolsToBedrockToolC
120120
for i := range openAIReq.Tools {
121121
toolDefinition := &openAIReq.Tools[i]
122122
if toolDefinition.Function != nil {
123-
var toolName, toolDes string
124-
toolName = toolDefinition.Function.Name
125-
toolDes = toolDefinition.Function.Description
123+
toolName := toolDefinition.Function.Name
124+
var toolDesc *string
125+
if toolDefinition.Function.Description != "" {
126+
toolDesc = &toolDefinition.Function.Description
127+
}
126128
tool := &awsbedrock.Tool{
127129
ToolSpec: &awsbedrock.ToolSpecification{
128130
Name: &toolName,
129-
Description: &toolDes,
131+
Description: toolDesc,
130132
InputSchema: &awsbedrock.ToolInputSchema{
131133
JSON: toolDefinition.Function.Parameters,
132134
},
@@ -247,21 +249,58 @@ func unmarshalToolCallArguments(arguments string) (map[string]any, error) {
247249
func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) openAIMessageToBedrockMessageRoleAssistant(
248250
openAiMessage *openai.ChatCompletionAssistantMessageParam, role string,
249251
) (*awsbedrock.Message, error) {
250-
var bedrockMessage *awsbedrock.Message
252+
bedrockMessage := &awsbedrock.Message{Role: role}
251253
contentBlocks := make([]*awsbedrock.ContentBlock, 0)
254+
255+
var contentParts []openai.ChatCompletionAssistantMessageParamContent
252256
if v, ok := openAiMessage.Content.Value.(string); ok && len(v) > 0 {
253-
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: &v})
254-
} else if content, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok {
255-
if content.Type == openai.ChatCompletionAssistantMessageParamContentTypeRefusal {
256-
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: content.Refusal})
257-
} else if content.Text != nil {
258-
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: content.Text})
257+
// Case 1: Content is a simple string.
258+
contentParts = append(contentParts, openai.ChatCompletionAssistantMessageParamContent{Type: openai.ChatCompletionAssistantMessageParamContentTypeText, Text: &v})
259+
} else if singleContent, ok := openAiMessage.Content.Value.(openai.ChatCompletionAssistantMessageParamContent); ok {
260+
// Case 2: Content is a single object.
261+
contentParts = append(contentParts, singleContent)
262+
} else if sliceContent, ok := openAiMessage.Content.Value.([]openai.ChatCompletionAssistantMessageParamContent); ok {
263+
// Case 3: Content is already a slice of objects.
264+
contentParts = sliceContent
265+
}
266+
267+
for _, content := range contentParts {
268+
switch content.Type {
269+
case openai.ChatCompletionAssistantMessageParamContentTypeText:
270+
if content.Text != nil {
271+
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: content.Text})
272+
}
273+
case openai.ChatCompletionAssistantMessageParamContentTypeThinking:
274+
if content.Text != nil {
275+
reasoningText := &awsbedrock.ReasoningTextBlock{
276+
Text: *content.Text,
277+
}
278+
if content.Signature != nil {
279+
reasoningText.Signature = *content.Signature
280+
}
281+
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{
282+
ReasoningContent: &awsbedrock.ReasoningContentBlock{
283+
ReasoningText: reasoningText,
284+
},
285+
})
286+
}
287+
case openai.ChatCompletionAssistantMessageParamContentTypeRedactedThinking:
288+
if content.RedactedContent != nil {
289+
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{
290+
ReasoningContent: &awsbedrock.ReasoningContentBlock{
291+
RedactedContent: content.RedactedContent,
292+
},
293+
})
294+
}
295+
case openai.ChatCompletionAssistantMessageParamContentTypeRefusal:
296+
if content.Refusal != nil {
297+
contentBlocks = append(contentBlocks, &awsbedrock.ContentBlock{Text: content.Refusal})
298+
}
259299
}
260300
}
261-
bedrockMessage = &awsbedrock.Message{
262-
Role: role,
263-
Content: contentBlocks,
264-
}
301+
302+
bedrockMessage.Content = contentBlocks
303+
265304
for i := range openAiMessage.ToolCalls {
266305
toolCall := &openAiMessage.ToolCalls[i]
267306
input, err := unmarshalToolCallArguments(toolCall.Function.Arguments)
@@ -628,11 +667,11 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(_ map[string
628667
if choice.Message.Content == nil {
629668
choice.Message.Content = output.Text
630669
}
631-
case output.ReasoningContent != nil && output.ReasoningContent.ReasoningText != nil:
670+
case output.ReasoningContent != nil:
632671
if choice.Message.AWSBedRockResponseVendorFields == nil {
633672
choice.Message.AWSBedRockResponseVendorFields = &openai.AWSBedRockResponseVendorFields{}
634673
}
635-
choice.Message.ReasoningContent = output.ReasoningContent
674+
choice.Message.ReasoningContent = &openai.AWSBedRockReasoningContent{ReasoningContent: output.ReasoningContent}
636675
}
637676
}
638677
openAIResp.Choices = append(openAIResp.Choices, choice)
@@ -661,10 +700,7 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) extractAmazonEventStreamE
661700
for {
662701
msg, err := dec.Decode(r, nil)
663702
if err != nil {
664-
// When failed, we stop processing the events.
665-
// Copy the unread bytes to the beginning of the buffer.
666-
copy(o.bufferedBody, o.bufferedBody[lastRead:])
667-
o.bufferedBody = o.bufferedBody[:len(o.bufferedBody)-int(lastRead)]
703+
o.bufferedBody = o.bufferedBody[lastRead:]
668704
return
669705
}
670706
var event awsbedrock.ConverseStreamEvent
@@ -700,15 +736,16 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
700736
})
701737
o.role = *event.Role
702738
case event.Delta != nil:
703-
if event.Delta.Text != nil {
739+
switch {
740+
case event.Delta.Text != nil:
704741
chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{
705742
Index: 0,
706743
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
707744
Role: o.role,
708745
Content: event.Delta.Text,
709746
},
710747
})
711-
} else if event.Delta.ToolUse != nil {
748+
case event.Delta.ToolUse != nil:
712749
chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{
713750
Index: 0,
714751
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
@@ -723,6 +760,25 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) convertEvent(event *awsbe
723760
},
724761
},
725762
})
763+
case event.Delta.ReasoningContent != nil:
764+
reasoningDelta := &openai.AWSBedRockStreamReasoningContent{}
765+
766+
// Map all relevant fields from the Bedrock delta to our flattened OpenAI delta struct.
767+
if event.Delta.ReasoningContent.ReasoningText != nil {
768+
reasoningDelta.Text = event.Delta.ReasoningContent.ReasoningText.Text
769+
reasoningDelta.Signature = event.Delta.ReasoningContent.ReasoningText.Signature
770+
}
771+
if event.Delta.ReasoningContent.RedactedContent != nil {
772+
reasoningDelta.RedactedContent = event.Delta.ReasoningContent.RedactedContent
773+
}
774+
775+
chunk.Choices = append(chunk.Choices, openai.ChatCompletionResponseChunkChoice{
776+
Index: 0,
777+
Delta: &openai.ChatCompletionResponseChunkChoiceDelta{
778+
Role: o.role,
779+
ReasoningContent: reasoningDelta,
780+
},
781+
})
726782
}
727783
case event.Start != nil:
728784
if event.Start.ToolUse != nil {

0 commit comments

Comments
 (0)