diff --git a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/schema/SchemaGenerationAllowlist.java b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/schema/SchemaGenerationAllowlist.java index 1360c9f0baa..8637f2b6857 100644 --- a/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/schema/SchemaGenerationAllowlist.java +++ b/smithy-typescript-codegen/src/main/java/software/amazon/smithy/typescript/codegen/schema/SchemaGenerationAllowlist.java @@ -8,6 +8,7 @@ import java.util.HashSet; import java.util.Set; import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.protocol.traits.Rpcv2CborTrait; import software.amazon.smithy.typescript.codegen.TypeScriptSettings; import software.amazon.smithy.utils.SmithyInternalApi; @@ -19,22 +20,30 @@ */ @SmithyInternalApi public abstract class SchemaGenerationAllowlist { - private static final Set ALLOWED = new HashSet<>(); + private static final Set ALLOWED = new HashSet<>(); + private static final Set PROTOCOLS = new HashSet<>(); static { - ALLOWED.add("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol"); - } - - @Deprecated - public static boolean allows(String serviceShapeId, TypeScriptSettings settings) { - return ALLOWED.contains(serviceShapeId) && settings.generateSchemas(); + ALLOWED.add(ShapeId.from("smithy.protocoltests.rpcv2Cbor#RpcV2Protocol")); + PROTOCOLS.add(Rpcv2CborTrait.ID); } public static boolean allows(ShapeId serviceShapeId, TypeScriptSettings settings) { - return ALLOWED.contains(serviceShapeId.toString()) && settings.generateSchemas(); + boolean allowedByProtocol = PROTOCOLS.contains(settings.getProtocol()); + boolean allowedByName = ALLOWED.contains(serviceShapeId); + return settings.generateSchemas() && (allowedByProtocol || allowedByName); } + @Deprecated public static void allow(String serviceShapeId) { + ALLOWED.add(ShapeId.from(serviceShapeId)); + } + + public static void allow(ShapeId serviceShapeId) { ALLOWED.add(serviceShapeId); } + + public static void allowProtocol(ShapeId protocolShapeId) { + PROTOCOLS.add(protocolShapeId); + } }