diff --git a/openapi/bundle.go b/openapi/bundle.go index 0f51a76..7ad818c 100644 --- a/openapi/bundle.go +++ b/openapi/bundle.go @@ -131,13 +131,14 @@ func Bundle(ctx context.Context, doc *OpenAPI, opts BundleOptions) error { } componentStorage := &componentStorage{ - schemaStorage: sequencedmap.New[string, *oas3.JSONSchema[oas3.Referenceable]](), - referenceStorage: sequencedmap.New[string, *sequencedmap.Map[string, any]](), - refs: make(map[string]string), - componentNames: make(map[string]bool), - schemaHashes: make(map[string]string), - schemaLocations: make(map[string]string), - rootLocation: targetLocation, + schemaStorage: sequencedmap.New[string, *oas3.JSONSchema[oas3.Referenceable]](), + referenceStorage: sequencedmap.New[string, *sequencedmap.Map[string, any]](), + refs: make(map[string]string), + componentNames: make(map[string]bool), + schemaHashes: make(map[string]string), + schemaLocations: make(map[string]string), + componentLocations: make(map[string]string), + rootLocation: targetLocation, } // Initialize existing component names and hashes to avoid conflicts @@ -160,6 +161,12 @@ func Bundle(ctx context.Context, doc *OpenAPI, opts BundleOptions) error { return fmt.Errorf("failed to rewrite references in bundled schemas: %w", err) } + // Rewrite references within bundled components (responses, headers, etc.) + err = rewriteRefsInBundledComponents(ctx, componentStorage) + if err != nil { + return fmt.Errorf("failed to rewrite references in bundled components: %w", err) + } + // Second pass: update all references to point to new component names err = updateReferencesToComponents(ctx, doc, componentStorage) if err != nil { @@ -173,13 +180,14 @@ func Bundle(ctx context.Context, doc *OpenAPI, opts BundleOptions) error { } type componentStorage struct { - schemaStorage *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Referenceable]] - referenceStorage *sequencedmap.Map[string, *sequencedmap.Map[string, any]] - refs map[string]string // absolute ref -> component name - componentNames map[string]bool // track used names to avoid conflicts - schemaHashes map[string]string // component name -> hash for conflict detection - schemaLocations map[string]string // component name -> absolute source location (for rewriting refs) - rootLocation string // absolute path to root document for relative path calculation + schemaStorage *sequencedmap.Map[string, *oas3.JSONSchema[oas3.Referenceable]] + referenceStorage *sequencedmap.Map[string, *sequencedmap.Map[string, any]] + refs map[string]string // absolute ref -> component name + componentNames map[string]bool // track used names to avoid conflicts + schemaHashes map[string]string // component name -> hash for conflict detection + schemaLocations map[string]string // component name -> absolute source location (for rewriting refs) + componentLocations map[string]string // componentType/componentName -> absolute source location + rootLocation string // absolute path to root document for relative path calculation } func bundleObject[T any](ctx context.Context, obj *T, namingStrategy BundleNamingStrategy, opts ResolveOptions, componentStorage *componentStorage) error { @@ -315,12 +323,10 @@ func rewriteRefsInBundledSchemas(ctx context.Context, componentStorage *componen return nil } -// rewriteRefsInSchema rewrites references within a single schema -func rewriteRefsInSchema(ctx context.Context, schema *oas3.JSONSchema[oas3.Referenceable], componentStorage *componentStorage, sourceLocation string) error { - if schema == nil { - return nil - } - +// prepareSourceURI extracts and normalizes a source URI for use with handleReference. +// It extracts just the URI part (removing any fragment) and converts it to native path +// format for filepath operations. +func prepareSourceURI(sourceLocation string) string { // Extract just the URI part from sourceLocation (remove fragment if present) // sourceLocation might be like "/path/to/file.yaml#/components/schemas/SchemaName" // but we need just "/path/to/file.yaml" for resolving relative references @@ -329,12 +335,24 @@ func rewriteRefsInSchema(ctx context.Context, schema *oas3.JSONSchema[oas3.Refer sourceURI = sourceLocation // Fallback if no URI part } - // On Windows, normalize sourceURI to backslashes before using with filepath operations - // This prevents malformed paths when joining with relative references + // On Windows, convert forward slashes to backslashes for filepath operations + // componentLocations stores paths with forward slashes (from handleReference normalization) + // but filepath.Join needs native separators to work correctly if filepath.Separator == '\\' && filepath.IsAbs(sourceURI) { sourceURI = filepath.FromSlash(sourceURI) } + return sourceURI +} + +// rewriteRefsInSchema rewrites references within a single schema +func rewriteRefsInSchema(ctx context.Context, schema *oas3.JSONSchema[oas3.Referenceable], componentStorage *componentStorage, sourceLocation string) error { + if schema == nil { + return nil + } + + sourceURI := prepareSourceURI(sourceLocation) + // Walk through the schema and rewrite references for item := range oas3.Walk(ctx, schema) { err := item.Match(oas3.SchemaMatcher{ @@ -365,6 +383,229 @@ func rewriteRefsInSchema(ctx context.Context, schema *oas3.JSONSchema[oas3.Refer return nil } +// rewriteRefsInBundledComponents rewrites references within bundled components (responses, headers, etc.) +// to point to their new component locations. This handles cases where a bundled response contains +// header references that also need to be updated. +func rewriteRefsInBundledComponents(ctx context.Context, componentStorage *componentStorage) error { + // Walk through each component type in referenceStorage + for componentType, components := range componentStorage.referenceStorage.All() { + for componentName, component := range components.All() { + // Get the source location for this component + sourceLocation := componentStorage.componentLocations[componentType+"/"+componentName] + + // Walk through the component and update references + err := walkAndUpdateRefsInComponent(ctx, component, componentStorage, sourceLocation) + if err != nil { + return fmt.Errorf("failed to rewrite refs in %s component: %w", componentType, err) + } + } + } + return nil +} + +// walkAndUpdateRefsInComponent walks through a component and updates its internal references +func walkAndUpdateRefsInComponent(ctx context.Context, component any, componentStorage *componentStorage, sourceLocation string) error { + // We need to handle each component type that can contain references + // The Walk function will traverse the component and find all references + + // Use type-specific walking based on the component type + switch c := component.(type) { + case *ReferencedResponse: + return walkAndUpdateRefsInResponse(ctx, c, componentStorage, sourceLocation) + case *ReferencedParameter: + return walkAndUpdateRefsInParameter(ctx, c, componentStorage, sourceLocation) + case *ReferencedRequestBody: + return walkAndUpdateRefsInRequestBody(ctx, c, componentStorage, sourceLocation) + case *ReferencedCallback: + return walkAndUpdateRefsInCallback(ctx, c, componentStorage, sourceLocation) + case *ReferencedPathItem: + return walkAndUpdateRefsInPathItem(ctx, c, componentStorage, sourceLocation) + case *ReferencedLink: + return walkAndUpdateRefsInLink(ctx, c, componentStorage, sourceLocation) + case *ReferencedExample: + return walkAndUpdateRefsInExample(ctx, c, componentStorage, sourceLocation) + case *ReferencedSecurityScheme: + return walkAndUpdateRefsInSecurityScheme(ctx, c, componentStorage, sourceLocation) + case *ReferencedHeader: + return walkAndUpdateRefsInHeader(ctx, c, componentStorage, sourceLocation) + } + return nil +} + +// walkAndUpdateRefsInResponse walks through a response and updates its internal references +func walkAndUpdateRefsInResponse(_ context.Context, response *ReferencedResponse, componentStorage *componentStorage, sourceLocation string) error { + if response == nil || response.Object == nil { + return nil + } + + // Walk through the response's headers and update references + if response.Object.Headers != nil { + for _, header := range response.Object.Headers.All() { + if header != nil && header.IsReference() { + updateComponentRefWithSource(header.Reference, componentStorage, "headers", sourceLocation) + } + } + } + + // Walk through the response's content schemas + if response.Object.Content != nil { + for _, mediaType := range response.Object.Content.All() { + if mediaType != nil && mediaType.Schema != nil && mediaType.Schema.IsReference() { + updateSchemaRefWithSource(mediaType.Schema, componentStorage, sourceLocation) + } + } + } + + return nil +} + +// walkAndUpdateRefsInParameter walks through a parameter and updates its internal references +func walkAndUpdateRefsInParameter(_ context.Context, param *ReferencedParameter, componentStorage *componentStorage, sourceLocation string) error { + if param == nil || param.Object == nil { + return nil + } + + // Walk through the parameter's schema + if param.Object.Schema != nil && param.Object.Schema.IsReference() { + updateSchemaRefWithSource(param.Object.Schema, componentStorage, sourceLocation) + } + + // Walk through parameter examples + if param.Object.Examples != nil { + for _, example := range param.Object.Examples.All() { + if example != nil && example.IsReference() { + updateComponentRefWithSource(example.Reference, componentStorage, "examples", sourceLocation) + } + } + } + + return nil +} + +// walkAndUpdateRefsInRequestBody walks through a request body and updates its internal references +func walkAndUpdateRefsInRequestBody(_ context.Context, body *ReferencedRequestBody, componentStorage *componentStorage, sourceLocation string) error { + if body == nil || body.Object == nil { + return nil + } + + // Walk through the request body's content schemas + if body.Object.Content != nil { + for _, mediaType := range body.Object.Content.All() { + if mediaType != nil && mediaType.Schema != nil && mediaType.Schema.IsReference() { + updateSchemaRefWithSource(mediaType.Schema, componentStorage, sourceLocation) + } + } + } + + return nil +} + +// walkAndUpdateRefsInCallback walks through a callback and updates its internal references +func walkAndUpdateRefsInCallback(_ context.Context, callback *ReferencedCallback, componentStorage *componentStorage, sourceLocation string) error { + if callback == nil || callback.Object == nil { + return nil + } + + // Callbacks contain path items with operations + for _, pathItem := range callback.Object.All() { + if pathItem != nil && pathItem.IsReference() { + updateComponentRefWithSource(pathItem.Reference, componentStorage, "pathItems", sourceLocation) + } + } + + return nil +} + +// walkAndUpdateRefsInPathItem walks through a path item and updates its internal references +func walkAndUpdateRefsInPathItem(_ context.Context, pathItem *ReferencedPathItem, componentStorage *componentStorage, sourceLocation string) error { + if pathItem == nil || pathItem.Object == nil { + return nil + } + + // Path items can have parameters + if pathItem.Object.Parameters != nil { + for _, param := range pathItem.Object.Parameters { + if param != nil && param.IsReference() { + updateComponentRefWithSource(param.Reference, componentStorage, "parameters", sourceLocation) + } + } + } + + return nil +} + +// walkAndUpdateRefsInLink walks through a link and updates its internal references +func walkAndUpdateRefsInLink(_ context.Context, _ *ReferencedLink, _ *componentStorage, _ string) error { + // Links don't typically contain component references that need updating + return nil +} + +// walkAndUpdateRefsInExample walks through an example and updates its internal references +func walkAndUpdateRefsInExample(_ context.Context, _ *ReferencedExample, _ *componentStorage, _ string) error { + // Examples don't typically contain component references that need updating + return nil +} + +// walkAndUpdateRefsInSecurityScheme walks through a security scheme and updates its internal references +func walkAndUpdateRefsInSecurityScheme(_ context.Context, _ *ReferencedSecurityScheme, _ *componentStorage, _ string) error { + // Security schemes don't typically contain component references that need updating + return nil +} + +// walkAndUpdateRefsInHeader walks through a header and updates its internal references +func walkAndUpdateRefsInHeader(_ context.Context, header *ReferencedHeader, componentStorage *componentStorage, sourceLocation string) error { + if header == nil || header.Object == nil { + return nil + } + + // Walk through the header's schema + if header.Object.Schema != nil && header.Object.Schema.IsReference() { + updateSchemaRefWithSource(header.Object.Schema, componentStorage, sourceLocation) + } + + // Walk through header examples + if header.Object.Examples != nil { + for _, example := range header.Object.Examples.All() { + if example != nil && example.IsReference() { + updateComponentRefWithSource(example.Reference, componentStorage, "examples", sourceLocation) + } + } + } + + return nil +} + +// updateSchemaRefWithSource updates a schema reference using a specific source location for resolution +func updateSchemaRefWithSource(schema *oas3.JSONSchema[oas3.Referenceable], componentStorage *componentStorage, sourceLocation string) { + if schema == nil || !schema.IsReference() { + return + } + + sourceURI := prepareSourceURI(sourceLocation) + ref := schema.GetRef() + absRef, _ := handleReference(ref, sourceURI) + + if newName, exists := componentStorage.refs[absRef]; exists { + newRef := "#/components/schemas/" + newName + *schema.GetLeft().Ref = references.Reference(newRef) + } +} + +// updateComponentRefWithSource updates a component reference using a specific source location for resolution +func updateComponentRefWithSource(ref *references.Reference, componentStorage *componentStorage, componentSection string, sourceLocation string) { + if ref == nil { + return + } + + sourceURI := prepareSourceURI(sourceLocation) + absRef, _ := handleReference(*ref, sourceURI) + + if newName, exists := componentStorage.refs[absRef]; exists { + newRef := "#/components/" + componentSection + "/" + newName + *ref = references.Reference(newRef) + } +} + // bundleGenericReference handles bundling of generic OpenAPI component references func bundleGenericReference[T any, V interfaces.Validator[T], C marshaller.CoreModeler](ctx context.Context, ref *Reference[T, V, C], namingStrategy BundleNamingStrategy, opts ResolveOptions, componentStorage *componentStorage, componentType string) error { if ref == nil || !ref.IsReference() { @@ -381,11 +622,6 @@ func bundleGenericReference[T any, V interfaces.Validator[T], C marshaller.CoreM return nil } - // Check if we've already processed this reference - if _, exists := componentStorage.refs[refStr]; exists { - return nil - } - // Resolve the external reference resolveOpts := ResolveOptions{ RootDocument: opts.RootDocument, @@ -398,15 +634,36 @@ func bundleGenericReference[T any, V interfaces.Validator[T], C marshaller.CoreM return fmt.Errorf("failed to resolve external %s reference %s: %w", componentType, refStr, resolveErr) } - // Generate component name - componentName, err := generateComponentName(refStr, namingStrategy, componentStorage.componentNames, componentStorage.rootLocation) + // Get the final absolute reference by following the resolution chain + // This handles chained references (e.g., common.yaml -> headers.yaml) + finalAbsRef := getFinalAbsoluteRef(ref, refStr) + + // Normalize finalAbsRef to forward slashes for consistent map keys across platforms + // On Windows, resolution chains may introduce backslashes, but we need forward slashes + // to match the keys created by handleReference which always normalizes to forward slashes + finalAbsRef = filepath.ToSlash(finalAbsRef) + + // Check if we've already processed this reference (using final absolute ref for deduplication) + if existingName, exists := componentStorage.refs[finalAbsRef]; exists { + // Also map the intermediate reference to the same component name + if refStr != finalAbsRef { + componentStorage.refs[refStr] = existingName + } + return nil + } + + // Generate component name using the final absolute reference + componentName, err := generateComponentName(finalAbsRef, namingStrategy, componentStorage.componentNames, componentStorage.rootLocation) if err != nil { - return fmt.Errorf("failed to generate component name for %s: %w", refStr, err) + return fmt.Errorf("failed to generate component name for %s: %w", finalAbsRef, err) } componentStorage.componentNames[componentName] = true - // Store the mapping - componentStorage.refs[refStr] = componentName + // Store the mapping (both original and final refs point to the same component) + componentStorage.refs[finalAbsRef] = componentName + if refStr != finalAbsRef { + componentStorage.refs[refStr] = componentName + } // Get the resolved content and create a new non-reference version resolvedValue := ref.GetObject() @@ -426,7 +683,16 @@ func bundleGenericReference[T any, V interfaces.Validator[T], C marshaller.CoreM if !componentStorage.referenceStorage.GetOrZero(componentType).Has(componentName) { componentStorage.referenceStorage.GetOrZero(componentType).Set(componentName, bundledRef) - targetDocInfo := ref.GetReferenceResolutionInfo() + // Get the final resolution info for nested bundling + targetDocInfo := getFinalResolutionInfo(ref) + if targetDocInfo == nil { + // Fall back to the immediate resolution info if final resolution info is unavailable + targetDocInfo = ref.GetReferenceResolutionInfo() + } + if targetDocInfo == nil { + return fmt.Errorf("failed to get resolution info for %s reference %s", componentType, refStr) + } + componentStorage.componentLocations[componentType+"/"+componentName] = targetDocInfo.AbsoluteReference if err := bundleObject(ctx, bundledRef, namingStrategy, references.ResolveOptions{ RootDocument: opts.RootDocument, @@ -440,6 +706,74 @@ func bundleGenericReference[T any, V interfaces.Validator[T], C marshaller.CoreM return nil } +// getFinalAbsoluteRef follows the reference resolution chain to get the final absolute reference. +// This is needed for proper deduplication when we have chained references like: +// testapi.yaml -> common.yaml#/components/headers/X -> headers.yaml#/components/headers/X +// Both should resolve to the same final component. +func getFinalAbsoluteRef[T any, V interfaces.Validator[T], C marshaller.CoreModeler](ref *Reference[T, V, C], initialAbsRef string) string { + if ref == nil { + return initialAbsRef + } + + resInfo := ref.GetReferenceResolutionInfo() + if resInfo == nil { + return initialAbsRef + } + + // Check if the resolved object is itself a reference (chained reference) + if resInfo.Object != nil && resInfo.Object.IsReference() { + // Follow the chain to get the final resolution info + nextRefInfo := resInfo.Object.GetReferenceResolutionInfo() + if nextRefInfo != nil { + // Build the absolute reference from the final resolution + finalRef := nextRefInfo.AbsoluteReference + if nextRefInfo.Object != nil && nextRefInfo.Object.Reference != nil { + // Add the fragment from the chained reference + fragment := string(nextRefInfo.Object.Reference.GetJSONPointer()) + if fragment != "" { + finalRef = finalRef + "#" + fragment + } + } else { + // Use the original reference's fragment with the final file location + origRef := resInfo.Object.GetReference() + fragment := string(origRef.GetJSONPointer()) + if fragment != "" { + finalRef = finalRef + "#" + fragment + } + } + // Recursively follow more chains if needed + return getFinalAbsoluteRef(resInfo.Object, finalRef) + } + } + + return initialAbsRef +} + +// getFinalResolutionInfo follows the reference resolution chain to get the final resolution info. +// This returns the resolution info for the last step in a chained reference. +func getFinalResolutionInfo[T any, V interfaces.Validator[T], C marshaller.CoreModeler](ref *Reference[T, V, C]) *references.ResolveResult[Reference[T, V, C]] { + if ref == nil { + return nil + } + + resInfo := ref.GetReferenceResolutionInfo() + if resInfo == nil { + return nil + } + + // Check if the resolved object is itself a reference (chained reference) + if resInfo.Object != nil && resInfo.Object.IsReference() { + // Follow the chain to get the final resolution info + nextRefInfo := resInfo.Object.GetReferenceResolutionInfo() + if nextRefInfo != nil { + // Recursively follow more chains + return getFinalResolutionInfo(resInfo.Object) + } + } + + return resInfo +} + // generateComponentName creates a new component name based on the reference and naming strategy func generateComponentName(ref string, strategy BundleNamingStrategy, usedNames map[string]bool, targetLocation string) (string, error) { // Convert absolute path back to relative for component naming diff --git a/openapi/testdata/bundle/issue50/common.yaml b/openapi/testdata/bundle/issue50/common.yaml index 9886fd2..a78026a 100644 --- a/openapi/testdata/bundle/issue50/common.yaml +++ b/openapi/testdata/bundle/issue50/common.yaml @@ -19,13 +19,13 @@ components: description: Internal server error headers: X-RateLimit-Limit: - $ref: "#/components/headers/X-RateLimit-Limit" + $ref: "./headers.yaml#/components/headers/X-RateLimit-Limit" X-RateLimit-Remaining: - $ref: "#/components/headers/X-RateLimit-Remaining" + $ref: "./headers.yaml#/components/headers/X-RateLimit-Remaining" X-RateLimit-Reset: - $ref: "#/components/headers/X-RateLimit-Reset" + $ref: "./headers.yaml#/components/headers/X-RateLimit-Reset" Retry-After: - $ref: "#/components/headers/Retry-After" + $ref: "./headers.yaml#/components/headers/Retry-After" content: application/json: schema: @@ -37,36 +37,6 @@ components: title: Internal Server Error instance: accounts/123 detail: "An unexpected error occurred on the server." - - headers: - X-RateLimit-Limit: - schema: - type: integer - format: int32 - minimum: 1 - maximum: 10000 - example: 100 - X-RateLimit-Remaining: - schema: - type: integer - format: int32 - minimum: 0 - maximum: 10000 - example: 99 - X-RateLimit-Reset: - schema: - type: integer - format: int32 - minimum: 0 - maximum: 2147483647 - example: 1652364907 - Retry-After: - schema: - type: integer - format: int32 - minimum: 0 - maximum: 3600 - example: 60 schemas: IndividualEntity: title: IndividualEntity @@ -182,3 +152,12 @@ components: required: true schema: $ref: "#/components/schemas/UUID" + headers: + X-RateLimit-Limit: + $ref: "./headers.yaml#/components/headers/X-RateLimit-Limit" + X-RateLimit-Remaining: + $ref: "./headers.yaml#/components/headers/X-RateLimit-Remaining" + X-RateLimit-Reset: + $ref: "./headers.yaml#/components/headers/X-RateLimit-Reset" + Retry-After: + $ref: "./headers.yaml#/components/headers/Retry-After" diff --git a/openapi/testdata/bundle/issue50/expected.yaml b/openapi/testdata/bundle/issue50/expected.yaml index 2a0519d..1ed2658 100644 --- a/openapi/testdata/bundle/issue50/expected.yaml +++ b/openapi/testdata/bundle/issue50/expected.yaml @@ -269,6 +269,7 @@ components: $ref: '#/components/schemas/UUID' headers: X-RateLimit-Limit: + description: The maximum number of requests that the client is allowed to make in this window. schema: type: integer maximum: 10000 @@ -276,6 +277,7 @@ components: format: int32 example: 100 X-RateLimit-Remaining: + description: The number of requests remaining in the current rate limit window. schema: type: integer maximum: 10000 @@ -283,6 +285,7 @@ components: format: int32 example: 99 X-RateLimit-Reset: + description: The time at which the current rate limit window resets in UTC epoch seconds. schema: type: integer maximum: 2.147483647e+09 @@ -290,6 +293,7 @@ components: format: int32 example: 1652364907 Retry-After: + description: The number of seconds to wait before making a new request when rate limit is exceeded. schema: type: integer maximum: 3600 diff --git a/openapi/testdata/bundle/issue50/headers.yaml b/openapi/testdata/bundle/issue50/headers.yaml new file mode 100644 index 0000000..cb42f17 --- /dev/null +++ b/openapi/testdata/bundle/issue50/headers.yaml @@ -0,0 +1,49 @@ +openapi: 3.0.0 +info: + title: common definitions + version: "1.0" + description: common definitions + license: + name: apache 2 + url: "https://apache.org/licenses/LICENSE-2.0" +paths: + /dummy: + get: + summary: Dummy endpoint to satisfy OpenAPI validator + responses: + "200": + description: Successful response +components: + headers: + X-RateLimit-Limit: + description: The maximum number of requests that the client is allowed to make in this window. + schema: + type: integer + format: int32 + minimum: 1 + maximum: 10000 + example: 100 + X-RateLimit-Remaining: + description: The number of requests remaining in the current rate limit window. + schema: + type: integer + format: int32 + minimum: 0 + maximum: 10000 + example: 99 + X-RateLimit-Reset: + description: The time at which the current rate limit window resets in UTC epoch seconds. + schema: + type: integer + format: int32 + minimum: 0 + maximum: 2147483647 + example: 1652364907 + Retry-After: + description: The number of seconds to wait before making a new request when rate limit is exceeded. + schema: + type: integer + format: int32 + minimum: 0 + maximum: 3600 + example: 60