diff --git a/go/rpc/generated_rpc_union_test.go b/go/rpc/generated_rpc_union_test.go index a8a34ec60..b33141926 100644 --- a/go/rpc/generated_rpc_union_test.go +++ b/go/rpc/generated_rpc_union_test.go @@ -185,6 +185,47 @@ func TestCommandsInvokeUnmarshalsSlashCommandInvocationResult(t *testing.T) { } } +func TestQueuedCommandResultBoolDiscriminatorJSONUnion(t *testing.T) { + stopProcessingQueue := true + var handled QueuedCommandResult = &QueuedCommandHandled{StopProcessingQueue: &stopProcessingQueue} + raw, err := json.Marshal(handled) + if err != nil { + t.Fatalf("marshal handled result: %v", err) + } + if string(raw) != `{"handled":true,"stopProcessingQueue":true}` { + t.Fatalf("marshal handled result = %s", raw) + } + + decodedHandled, err := unmarshalQueuedCommandResult([]byte(`{"handled":true,"stopProcessingQueue":true}`)) + if err != nil { + t.Fatalf("unmarshal handled result: %v", err) + } + decodedHandledValue, ok := decodedHandled.(*QueuedCommandHandled) + if !ok { + t.Fatalf("unmarshal handled result = %T, want *QueuedCommandHandled", decodedHandled) + } + if decodedHandledValue.StopProcessingQueue == nil || !*decodedHandledValue.StopProcessingQueue { + t.Fatalf("unmarshal handled stopProcessingQueue = %v, want true", decodedHandledValue.StopProcessingQueue) + } + + var notHandled QueuedCommandResult = &QueuedCommandNotHandled{} + raw, err = json.Marshal(notHandled) + if err != nil { + t.Fatalf("marshal not handled result: %v", err) + } + if string(raw) != `{"handled":false}` { + t.Fatalf("marshal not handled result = %s", raw) + } + + decodedNotHandled, err := unmarshalQueuedCommandResult([]byte(`{"handled":false}`)) + if err != nil { + t.Fatalf("unmarshal not handled result: %v", err) + } + if _, ok := decodedNotHandled.(*QueuedCommandNotHandled); !ok { + t.Fatalf("unmarshal not handled result = %T, want *QueuedCommandNotHandled", decodedNotHandled) + } +} + func TestUIElicitationFieldValueJSONUnion(t *testing.T) { raw, err := json.Marshal(UIElicitationBooleanValue(true)) if err != nil { diff --git a/go/rpc/zrpc.go b/go/rpc/zrpc.go index f83d26baa..9f13a52cb 100644 --- a/go/rpc/zrpc.go +++ b/go/rpc/zrpc.go @@ -1233,24 +1233,28 @@ type PluginList struct { Plugins []Plugin `json:"plugins"` } +// Result of the queued command execution +type QueuedCommandResult interface { + queuedCommandResult() + Handled() bool +} + type QueuedCommandHandled struct { - // The command was handled - Handled bool `json:"handled"` // If true, stop processing remaining queued items StopProcessingQueue *bool `json:"stopProcessingQueue,omitempty"` } +func (QueuedCommandHandled) queuedCommandResult() {} +func (QueuedCommandHandled) Handled() bool { + return true +} + type QueuedCommandNotHandled struct { - // The command was not handled - Handled bool `json:"handled"` } -// Result of the queued command execution -type QueuedCommandResult struct { - // The command was handled - Handled any `json:"handled"` - // If true, stop processing remaining queued items - StopProcessingQueue *bool `json:"stopProcessingQueue,omitempty"` +func (QueuedCommandNotHandled) queuedCommandResult() {} +func (QueuedCommandNotHandled) Handled() bool { + return false } // Experimental: RemoteDisableResult is part of an experimental API and may change or be diff --git a/go/rpc/zrpc_encoding.go b/go/rpc/zrpc_encoding.go index bf77c3c2e..ed6610c74 100644 --- a/go/rpc/zrpc_encoding.go +++ b/go/rpc/zrpc_encoding.go @@ -8,6 +8,80 @@ import ( "errors" ) +func unmarshalQueuedCommandResult(data []byte) (QueuedCommandResult, error) { + if string(data) == "null" { + return nil, nil + } + type rawUnion struct { + Handled *bool `json:"handled"` + } + var raw rawUnion + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + if raw.Handled == nil { + return nil, errors.New("data did not match any union variant for QueuedCommandResult") + } + + switch *raw.Handled { + case false: + var d QueuedCommandNotHandled + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + case true: + var d QueuedCommandHandled + if err := json.Unmarshal(data, &d); err != nil { + return nil, err + } + return &d, nil + } + return nil, errors.New("data did not match any union variant for QueuedCommandResult") +} + +func (r QueuedCommandHandled) MarshalJSON() ([]byte, error) { + type alias QueuedCommandHandled + return json.Marshal(struct { + Handled bool `json:"handled"` + alias + }{ + Handled: r.Handled(), + alias: alias(r), + }) +} + +func (r QueuedCommandNotHandled) MarshalJSON() ([]byte, error) { + type alias QueuedCommandNotHandled + return json.Marshal(struct { + Handled bool `json:"handled"` + alias + }{ + Handled: r.Handled(), + alias: alias(r), + }) +} + +func (r *CommandsRespondToQueuedCommandRequest) UnmarshalJSON(data []byte) error { + type rawCommandsRespondToQueuedCommandRequest struct { + RequestID string `json:"requestId"` + Result json.RawMessage `json:"result"` + } + var raw rawCommandsRespondToQueuedCommandRequest + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + r.RequestID = raw.RequestID + if raw.Result != nil { + value, err := unmarshalQueuedCommandResult(raw.Result) + if err != nil { + return err + } + r.Result = value + } + return nil +} + func unmarshalExternalToolTextResultForLlmContent(data []byte) (ExternalToolTextResultForLlmContent, error) { if string(data) == "null" { return nil, nil diff --git a/scripts/codegen/go.ts b/scripts/codegen/go.ts index 0f251e626..c4643c320 100644 --- a/scripts/codegen/go.ts +++ b/scripts/codegen/go.ts @@ -326,15 +326,19 @@ interface GoDiscriminatedUnionInfo { unmarshalFuncName: string; } +type GoDiscriminatorValue = string | boolean; +type GoDiscriminatorValueKind = "string" | "boolean"; + interface GoDiscriminatedUnionVariant { schema: JSONSchema7; typeName: string; - discriminatorValues: string[]; + discriminatorValues: GoDiscriminatorValue[]; } interface GoDiscriminatorInfo { property: string; - mapping: Map; + valueKind: GoDiscriminatorValueKind; + mapping: Map; variants: GoDiscriminatedUnionVariant[]; } @@ -430,8 +434,23 @@ function sortedGoEventEnvelopeProperties(properties: GoEventEnvelopeProperty[]): return [...properties].sort((left, right) => compareGoFieldNames(left.fieldName, right.fieldName)); } +interface GoDiscriminatorValues { + kind: GoDiscriminatorValueKind; + values: GoDiscriminatorValue[]; +} + +function goDiscriminatorValues(schema: JSONSchema7, ctx: GoCodegenCtx): GoDiscriminatorValues | undefined { + const stringValues = goStringEnumValues(schema, ctx); + if (stringValues) return { kind: "string", values: stringValues }; + + const booleanValues = goBooleanDiscriminatorValues(schema, ctx); + if (booleanValues) return { kind: "boolean", values: booleanValues }; + + return undefined; +} + /** - * Find a string-valued discriminator property shared by all anyOf variants. + * Find a literal-valued discriminator property shared by all anyOf variants. */ function findGoDiscriminator( variants: JSONSchema7[], @@ -444,10 +463,10 @@ function findGoDiscriminator( for (const [propName, propSchema] of Object.entries(firstVariant.properties)) { if (typeof propSchema !== "object") continue; - const firstValues = goStringEnumValues(propSchema as JSONSchema7, ctx); - if (!firstValues || firstValues.length === 0) continue; + const firstDiscriminatorValues = goDiscriminatorValues(propSchema as JSONSchema7, ctx); + if (!firstDiscriminatorValues || firstDiscriminatorValues.values.length === 0) continue; - const mapping = new Map(); + const mapping = new Map(); const unionVariants: GoDiscriminatedUnionVariant[] = []; let valid = true; for (const variantSource of variants) { @@ -456,9 +475,10 @@ function findGoDiscriminator( if (!(variant.required || []).includes(propName)) { valid = false; break; } const vp = variant.properties[propName]; if (typeof vp !== "object") { valid = false; break; } - const discriminatorValues = goStringEnumValues(vp as JSONSchema7, ctx); - if (!discriminatorValues || discriminatorValues.length === 0) { valid = false; break; } - const dedupedValues = [...new Set(discriminatorValues)]; + const discriminatorValues = goDiscriminatorValues(vp as JSONSchema7, ctx); + if (!discriminatorValues || discriminatorValues.values.length === 0 || discriminatorValues.kind !== firstDiscriminatorValues.kind) { valid = false; break; } + const dedupedValues = [...new Set(discriminatorValues.values)]; + if (discriminatorValues.kind === "boolean" && dedupedValues.length > 1) { valid = false; break; } const unionVariant = { schema: variant, typeName: goDiscriminatedUnionVariantTypeName(unionTypeName, dedupedValues[0], variantSource, variant, ctx), @@ -472,7 +492,7 @@ function findGoDiscriminator( } } if (valid && mapping.size > 0 && unionVariants.length === variants.length) { - return { property: propName, mapping, variants: unionVariants }; + return { property: propName, valueKind: firstDiscriminatorValues.kind, mapping, variants: unionVariants }; } } return null; @@ -580,7 +600,7 @@ function goEnumConstSuffix(value: string): string { function goDiscriminatedUnionVariantTypeName( unionTypeName: string, - discriminatorValue: string, + discriminatorValue: GoDiscriminatorValue, variantSource: JSONSchema7, variant: JSONSchema7, ctx: GoCodegenCtx @@ -592,7 +612,24 @@ function goDiscriminatedUnionVariantTypeName( if (definitionRef) { return goDefinitionName(refTypeName(definitionRef, ctx.definitions)); } - return `${unionTypeName}${goEnumConstSuffix(discriminatorValue)}`; + return `${unionTypeName}${goDiscriminatorConstSuffix(discriminatorValue)}`; +} + +function goDiscriminatorConstSuffix(value: GoDiscriminatorValue): string { + return typeof value === "boolean" ? (value ? "True" : "False") : goEnumConstSuffix(value); +} + +function compareGoDiscriminatorValues(left: GoDiscriminatorValue, right: GoDiscriminatorValue): number { + if (typeof left === "boolean" && typeof right === "boolean") { + return Number(left) - Number(right); + } + return String(left).localeCompare(String(right)); +} + +function goDiscriminatorValueExpr(value: GoDiscriminatorValue, enumName: string | undefined): string { + if (typeof value === "boolean") return value ? "true" : "false"; + if (!enumName) throw new Error(`Missing enum name for string discriminator value ${value}`); + return `${enumName}${goEnumConstSuffix(value)}`; } function schemaForConstValue(value: unknown): JSONSchema7 { @@ -1458,24 +1495,29 @@ function emitGoFlatDiscriminatedUnion( if (ctx.generatedNames.has(typeName)) return; ctx.generatedNames.add(typeName); - // Discriminator field: generate an enum from the const values const discriminatorProp = discriminator.property; const mapping = discriminator.mapping; const unionVariants = [...discriminator.variants].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName)); const discGoName = toGoFieldName(discriminatorProp); const discriminatorMethodName = discGoName; - const discValues = [...mapping.keys()]; - const discEnumName = getOrCreateGoEnum( - typeName + discGoName, - discValues, - ctx, - `${discGoName} discriminator for ${typeName}.`, - false, - experimental - ); + let discEnumName: string | undefined; + let discGoType = "bool"; + if (discriminator.valueKind === "string") { + const discValues = [...mapping.keys()].filter((value): value is string => typeof value === "string"); + discEnumName = getOrCreateGoEnum( + typeName + discGoName, + discValues, + ctx, + `${discGoName} discriminator for ${typeName}.`, + false, + experimental + ); + discGoType = discEnumName; + } const unmarshalFuncName = goUnexportedFunctionName("unmarshal", typeName); const rawDataName = `Raw${typeName}${ctx.discriminatedUnionRawVariantSuffix ?? "Data"}`; + const hasRawVariant = discriminator.valueKind === "string"; const markerName = `${typeName.charAt(0).toLowerCase()}${typeName.slice(1)}`; ctx.discriminatedUnions.set(typeName, { typeName, unmarshalFuncName }); @@ -1488,7 +1530,7 @@ function emitGoFlatDiscriminatedUnion( } lines.push(`type ${typeName} interface {`); lines.push(`\t${markerName}()`); - lines.push(`\t${discriminatorMethodName}() ${discEnumName}`); + lines.push(`\t${discriminatorMethodName}() ${discGoType}`); lines.push(`}`); lines.push(``); @@ -1513,16 +1555,23 @@ function emitGoFlatDiscriminatedUnion( unmarshalLines.push(`\t\treturn nil, nil`); unmarshalLines.push(`\t}`); unmarshalLines.push(`\ttype rawUnion struct {`); - unmarshalLines.push(`\t\t${discGoName} ${discEnumName} \`json:"${discriminatorProp}"\``); + const rawDiscGoType = discriminator.valueKind === "boolean" ? `*${discGoType}` : discGoType; + const rawDiscExpr = discriminator.valueKind === "boolean" ? `*raw.${discGoName}` : `raw.${discGoName}`; + unmarshalLines.push(`\t\t${discGoName} ${rawDiscGoType} \`json:"${discriminatorProp}"\``); unmarshalLines.push(`\t}`); unmarshalLines.push(`\tvar raw rawUnion`); unmarshalLines.push(`\tif err := json.Unmarshal(data, &raw); err != nil {`); unmarshalLines.push(`\t\treturn nil, err`); unmarshalLines.push(`\t}`); + if (discriminator.valueKind === "boolean") { + unmarshalLines.push(`\tif raw.${discGoName} == nil {`); + unmarshalLines.push(`\t\treturn nil, errors.New("data did not match any union variant for ${typeName}")`); + unmarshalLines.push(`\t}`); + } unmarshalLines.push(``); - unmarshalLines.push(`\tswitch raw.${discGoName} {`); - for (const discriminatorValue of [...mapping.keys()].sort()) { - const constName = `${discEnumName}${goEnumConstSuffix(discriminatorValue)}`; + unmarshalLines.push(`\tswitch ${rawDiscExpr} {`); + for (const discriminatorValue of [...mapping.keys()].sort(compareGoDiscriminatorValues)) { + const constName = goDiscriminatorValueExpr(discriminatorValue, discEnumName); const mappedVariants = [...mapping.get(discriminatorValue)!].sort((left, right) => compareGoTypeNames(left.typeName, right.typeName)); unmarshalLines.push(`\tcase ${constName}:`); if (mappedVariants.length === 1) { @@ -1542,36 +1591,47 @@ function emitGoFlatDiscriminatedUnion( unmarshalLines.push(`\t\t\treturn &d, nil`); unmarshalLines.push(`\t\t}`); } - unmarshalLines.push(`\t\treturn &${rawDataName}{Discriminator: raw.${discGoName}, Raw: data}, nil`); + if (hasRawVariant) { + unmarshalLines.push(`\t\treturn &${rawDataName}{Discriminator: ${rawDiscExpr}, Raw: data}, nil`); + } else { + unmarshalLines.push(`\t\treturn nil, errors.New("data did not match any union variant for ${typeName}")`); + } } } - unmarshalLines.push(`\tdefault:`); - unmarshalLines.push(`\t\treturn &${rawDataName}{Discriminator: raw.${discGoName}, Raw: data}, nil`); + if (hasRawVariant) { + unmarshalLines.push(`\tdefault:`); + unmarshalLines.push(`\t\treturn &${rawDataName}{Discriminator: ${rawDiscExpr}, Raw: data}, nil`); + } unmarshalLines.push(`\t}`); + if (discriminator.valueKind === "boolean") { + unmarshalLines.push(`\treturn nil, errors.New("data did not match any union variant for ${typeName}")`); + } unmarshalLines.push(`}`); pushGoEncodingBlock(unmarshalLines, ctx); - lines.push(`type ${rawDataName} struct {`); - lines.push(`\tDiscriminator ${discEnumName}`); - lines.push(`\tRaw json.RawMessage`); - lines.push(`}`); - lines.push(``); - lines.push(`func (${rawDataName}) ${markerName}() {}`); - lines.push(`func (r ${rawDataName}) ${discriminatorMethodName}() ${discEnumName} {`); - lines.push(`\treturn r.Discriminator`); - lines.push(`}`); - pushGoEncodingBlock([ - `func (r ${rawDataName}) MarshalJSON() ([]byte, error) {`, - `\tif r.Raw != nil {`, - `\t\treturn r.Raw, nil`, - `\t}`, - `\treturn json.Marshal(struct {`, - `\t\t${discGoName} ${discEnumName} \`json:"${discriminatorProp}"\``, - `\t}{`, - `\t\t${discGoName}: r.Discriminator,`, - `\t})`, - `}`, - ], ctx); + if (hasRawVariant) { + lines.push(`type ${rawDataName} struct {`); + lines.push(`\tDiscriminator ${discGoType}`); + lines.push(`\tRaw json.RawMessage`); + lines.push(`}`); + lines.push(``); + lines.push(`func (${rawDataName}) ${markerName}() {}`); + lines.push(`func (r ${rawDataName}) ${discriminatorMethodName}() ${discGoType} {`); + lines.push(`\treturn r.Discriminator`); + lines.push(`}`); + pushGoEncodingBlock([ + `func (r ${rawDataName}) MarshalJSON() ([]byte, error) {`, + `\tif r.Raw != nil {`, + `\t\treturn r.Raw, nil`, + `\t}`, + `\treturn json.Marshal(struct {`, + `\t\t${discGoName} ${discGoType} \`json:"${discriminatorProp}"\``, + `\t}{`, + `\t\t${discGoName}: r.Discriminator,`, + `\t})`, + `}`, + ], ctx); + } for (const mappedVariant of unionVariants) { const variant = mappedVariant.schema; @@ -1611,23 +1671,26 @@ function emitGoFlatDiscriminatedUnion( pushGoStructUnmarshalJSON(lines, variantTypeName, fields, ctx); lines.push(``); lines.push(`func (${variantTypeName}) ${markerName}() {}`); - const defaultConstName = `${discEnumName}${goEnumConstSuffix(mappedVariant.discriminatorValues[0])}`; + const defaultConstName = goDiscriminatorValueExpr(mappedVariant.discriminatorValues[0], discEnumName); if (mappedVariant.discriminatorValues.length <= 1) { - lines.push(`func (${variantTypeName}) ${discriminatorMethodName}() ${discEnumName} {`); + lines.push(`func (${variantTypeName}) ${discriminatorMethodName}() ${discGoType} {`); lines.push(`\treturn ${defaultConstName}`); + } else if (discriminator.valueKind === "boolean") { + lines.push(`func (r ${variantTypeName}) ${discriminatorMethodName}() ${discGoType} {`); + lines.push(`\treturn r.Discriminator`); } else { - lines.push(`func (r ${variantTypeName}) ${discriminatorMethodName}() ${discEnumName} {`); + lines.push(`func (r ${variantTypeName}) ${discriminatorMethodName}() ${discGoType} {`); lines.push(`\tif r.Discriminator == "" {`); lines.push(`\t\treturn ${defaultConstName}`); lines.push(`\t}`); - lines.push(`\treturn ${discEnumName}(r.Discriminator)`); + lines.push(`\treturn ${discGoType}(r.Discriminator)`); } lines.push(`}`); pushGoEncodingBlock([ `func (r ${variantTypeName}) MarshalJSON() ([]byte, error) {`, `\ttype alias ${variantTypeName}`, `\treturn json.Marshal(struct {`, - `\t\t${discGoName} ${discEnumName} \`json:"${discriminatorProp}"\``, + `\t\t${discGoName} ${discGoType} \`json:"${discriminatorProp}"\``, `\t\talias`, `\t}{`, `\t\t${discGoName}: r.${discriminatorMethodName}(),`, @@ -1893,6 +1956,49 @@ function goStringEnumValues(schema: JSONSchema7, ctx: GoCodegenCtx): string[] | return undefined; } +function goBooleanValues(schema: JSONSchema7, ctx: GoCodegenCtx): boolean[] | undefined { + const resolved = resolveSchema(schema, ctx.definitions) ?? schema; + if (typeof resolved.const === "boolean") return [resolved.const]; + if (Array.isArray(resolved.enum) && resolved.enum.every((value) => typeof value === "boolean")) { + return resolved.enum as boolean[]; + } + if (resolved.type === "boolean") return [true, false]; + + const unionMembers = goNonNullUnionMembers(resolved); + if (unionMembers.length > 0) { + const values: boolean[] = []; + for (const member of unionMembers) { + const memberValues = goBooleanValues(member, ctx); + if (!memberValues) return undefined; + values.push(...memberValues); + } + return [...new Set(values)]; + } + + return undefined; +} + +function goBooleanDiscriminatorValues(schema: JSONSchema7, ctx: GoCodegenCtx): boolean[] | undefined { + const resolved = resolveSchema(schema, ctx.definitions) ?? schema; + if (typeof resolved.const === "boolean") return [resolved.const]; + if (Array.isArray(resolved.enum) && resolved.enum.every((value) => typeof value === "boolean")) { + return resolved.enum as boolean[]; + } + + const unionMembers = goNonNullUnionMembers(resolved); + if (unionMembers.length > 0) { + const values: boolean[] = []; + for (const member of unionMembers) { + const memberValues = goBooleanDiscriminatorValues(member, ctx); + if (!memberValues) return undefined; + values.push(...memberValues); + } + return [...new Set(values)]; + } + + return undefined; +} + function mergeGoFlattenedPropertySchema( typeName: string, propName: string, @@ -1910,6 +2016,11 @@ function mergeGoFlattenedPropertySchema( }; } + const booleanValues = schemas.map((schema) => goBooleanValues(schema, ctx)); + if (booleanValues.every((values): values is boolean[] => values !== undefined)) { + return { type: "boolean" }; + } + const firstSchemaKey = stableStringify(resolveSchema(schemas[0], ctx.definitions) ?? schemas[0]); if (schemas.every((schema) => stableStringify(resolveSchema(schema, ctx.definitions) ?? schema) === firstSchemaKey)) { return schemas[0];