diff --git a/src/main/java/com/wolfssl/provider/jce/WolfCryptASN1Util.java b/src/main/java/com/wolfssl/provider/jce/WolfCryptASN1Util.java index 9d81defc..5300b591 100644 --- a/src/main/java/com/wolfssl/provider/jce/WolfCryptASN1Util.java +++ b/src/main/java/com/wolfssl/provider/jce/WolfCryptASN1Util.java @@ -35,11 +35,47 @@ public class WolfCryptASN1Util { /* ASN.1 Universal Tags */ - private static final byte ASN1_INTEGER = 0x02; - private static final byte ASN1_BIT_STRING = 0x03; - private static final byte ASN1_OCTET_STRING = 0x04; - private static final byte ASN1_OBJECT_IDENTIFIER = 0x06; - private static final byte ASN1_SEQUENCE = 0x30; + static final byte ASN1_INTEGER = 0x02; + static final byte ASN1_BIT_STRING = 0x03; + static final byte ASN1_OCTET_STRING = 0x04; + static final byte ASN1_NULL = 0x05; + static final byte ASN1_OBJECT_IDENTIFIER = 0x06; + static final byte ASN1_SEQUENCE = 0x30; + + /* ASN.1 Context-Specific Tags (Constructed) */ + static final byte ASN1_CONTEXT_SPECIFIC_0 = (byte)0xa0; + static final byte ASN1_CONTEXT_SPECIFIC_1 = (byte)0xa1; + static final byte ASN1_CONTEXT_SPECIFIC_2 = (byte)0xa2; + static final byte ASN1_CONTEXT_SPECIFIC_3 = (byte)0xa3; + + /* Hash Algorithm OIDs. No tag/length, just encoded value. */ + private static final byte[] OID_SHA1 = + { 0x2b, 0x0e, 0x03, 0x02, 0x1a }; + private static final byte[] OID_SHA224 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04 }; + private static final byte[] OID_SHA256 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01}; + private static final byte[] OID_SHA384 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02 }; + private static final byte[] OID_SHA512 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03 }; + private static final byte[] OID_SHA512_224 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x05 }; + private static final byte[] OID_SHA512_256 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x06 }; + private static final byte[] OID_SHA3_224 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x07 }; + private static final byte[] OID_SHA3_256 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x08 }; + private static final byte[] OID_SHA3_384 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x09 }; + private static final byte[] OID_SHA3_512 = + { 0x60, (byte)0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0a }; + + /* MGF1 OID: 1.2.840.113549.1.1.8 */ + private static final byte[] OID_MGF1 = + { 0x2a, (byte)0x86, 0x48, (byte)0x86, (byte)0xf7, + 0x0d, 0x01, 0x01, 0x08 }; /* DH Algorithm OID: 1.2.840.113549.1.3.1 (pkcs-3) */ private static final byte[] DH_ALGORITHM_OID = { @@ -49,14 +85,12 @@ public class WolfCryptASN1Util { (byte)0x01 }; - /** - * Private constructor, all methods are static. - */ + /** Private constructor, all methods are static. */ private WolfCryptASN1Util() { } /** - * Encode a BigInteger as a DER INTEGER. + * Encode BigInteger as a DER INTEGER. * * DER INTEGER format: * - Tag: 0x02 @@ -83,8 +117,8 @@ public static byte[] encodeDERInteger(BigInteger value) valueBytes = value.toByteArray(); /* BigInteger.toByteArray() handles sign bit correctly: - * - For positive numbers, adds leading 0x00 if MSB is set - * - For negative numbers (shouldn't happen), uses two's complement */ + * - Positive numbers, adds leading 0x00 if MSB is set + * - Negative numbers, uses two's complement */ out = new ByteArrayOutputStream(); try { @@ -480,5 +514,395 @@ public static byte[] bigIntegerToByteArray(BigInteger value) return bytes; } + + /** + * Encode an integer as a DER INTEGER. + * + * DER INTEGER format: + * - Tag: 0x02 + * - Length: variable + * - Value: big-endian bytes, with leading 0x00 if MSB is set + * + * @param value the integer value to encode (must be non-negative) + * + * @return DER-encoded INTEGER (tag + length + value) + * + * @throws IllegalArgumentException if value is negative + */ + static byte[] encodeDERInteger(int value) + throws IllegalArgumentException { + + ByteArrayOutputStream result; + + if (value < 0) { + throw new IllegalArgumentException( + "Negative integers not supported"); + } + + result = new ByteArrayOutputStream(); + + try { + result.write(ASN1_INTEGER); + + if (value == 0) { + result.write(0x01); + result.write(0x00); + } + else { + /* Determine minimum bytes needed */ + int temp = value; + int numBytes = 0; + while (temp > 0) { + numBytes++; + temp >>= 8; + } + + /* Build byte array */ + byte[] bytes = new byte[numBytes]; + int tempValue = value; + for (int i = numBytes - 1; i >= 0; i--) { + bytes[i] = (byte) (tempValue & 0xff); + tempValue >>= 8; + } + + /* Add padding byte if high bit set (to keep positive) */ + if ((bytes[0] & 0x80) != 0) { + result.write(encodeDERLength(numBytes + 1)); + result.write(0x00); + result.write(bytes); + } + else { + result.write(encodeDERLength(numBytes)); + result.write(bytes); + } + } + + return result.toByteArray(); + + } catch (IOException e) { + throw new IllegalArgumentException( + "Failed to encode INTEGER: " + e.getMessage(), e); + } + } + + /** + * Decode DER length from encoded bytes, returning length and + * new offset. + * + * Reads DER length encoding at the specified offset and returns + * an array containing [length, newOffset] where newOffset points + * to the first byte after the length encoding. + * + * @param data DER-encoded data + * @param offset index where length encoding starts + * + * @return int array with [length, newOffset] + * + * @throws IllegalArgumentException if data is null, offset invalid, + * or encoding is invalid + */ + static int[] decodeDERLengthWithOffset(byte[] data, int offset) + throws IllegalArgumentException { + + int firstByte, numBytes, length; + + if (data == null) { + throw new IllegalArgumentException("Data cannot be null"); + } + if (offset < 0 || offset >= data.length) { + throw new IllegalArgumentException( + "Invalid offset: " + offset); + } + + firstByte = data[offset++] & 0xff; + + if ((firstByte & 0x80) == 0) { + /* Short form */ + return new int[] {firstByte, offset}; + } + else { + /* Long form */ + numBytes = firstByte & 0x7f; + + if (numBytes == 0) { + throw new IllegalArgumentException( + "Indefinite length encoding not supported"); + } + + if (numBytes > 4) { + throw new IllegalArgumentException("Length too large"); + } + + if (offset + numBytes > data.length) { + throw new IllegalArgumentException( + "Invalid DER length: extends beyond data"); + } + + length = 0; + for (int i = 0; i < numBytes; i++) { + length = (length << 8) | (data[offset++] & 0xff); + } + + return new int[] {length, offset}; + } + } + + /** + * Decode a DER INTEGER value. + * + * Decodes a DER-encoded INTEGER (including tag and length) and + * returns the integer value. Only supports non-negative integers + * that fit in a Java int (32 bits). + * + * @param data DER-encoded INTEGER (including tag and length) + * + * @return the decoded integer value + * + * @throws IllegalArgumentException if data is null, invalid, + * or represents a negative number + */ + static int decodeDERInteger(byte[] data) + throws IllegalArgumentException { + + int idx, len, value; + int[] lenInfo; + + if (data == null || data.length < 3) { + throw new IllegalArgumentException( + "Invalid INTEGER: too short"); + } + + idx = 0; + + /* Check INTEGER tag */ + if (data[idx++] != ASN1_INTEGER) { + throw new IllegalArgumentException( + "Invalid INTEGER: expected tag 0x02"); + } + + /* Get length */ + lenInfo = decodeDERLengthWithOffset(data, idx); + len = lenInfo[0]; + idx = lenInfo[1]; + + if (len < 1 || len > 4) { + throw new IllegalArgumentException( + "Invalid INTEGER: length out of range"); + } + + if (idx + len > data.length) { + throw new IllegalArgumentException( + "Invalid INTEGER: value extends beyond data"); + } + + /* Check for negative integers (high bit set on first byte). + * We don't support negative integers. */ + if ((data[idx] & 0x80) != 0) { + throw new IllegalArgumentException( + "Negative integers not supported"); + } + + /* Decode value */ + value = 0; + for (int i = 0; i < len; i++) { + value = (value << 8) | (data[idx++] & 0xff); + } + + return value; + } + + /** + * Encode contents as a DER OBJECT IDENTIFIER. + * + * DER OBJECT IDENTIFIER format: + * - Tag: 0x06 + * - Length: variable + * - Contents: encoded OID bytes + * + * @param oidBytes the OID bytes (already encoded, without tag/length) + * + * @return DER-encoded OBJECT IDENTIFIER (tag + length + oidBytes) + * + * @throws IllegalArgumentException if oidBytes is null + */ + static byte[] encodeDERObjectIdentifier(byte[] oidBytes) + throws IllegalArgumentException { + + ByteArrayOutputStream out; + + if (oidBytes == null) { + throw new IllegalArgumentException("OID bytes cannot be null"); + } + + out = new ByteArrayOutputStream(); + + try { + out.write(ASN1_OBJECT_IDENTIFIER); + out.write(encodeDERLength(oidBytes.length)); + out.write(oidBytes); + + return out.toByteArray(); + + } catch (IOException e) { + throw new IllegalArgumentException( + "Failed to encode OBJECT IDENTIFIER: " + e.getMessage(), e); + } + } + + /** + * Encode NULL as a DER NULL. + * + * DER NULL format: + * - Tag: 0x05 + * - Length: 0x00 + * + * @return DER-encoded NULL (0x05 0x00) + */ + static byte[] encodeDERNull() { + return new byte[] {0x05, 0x00}; + } + + /** + * Get hash algorithm OID bytes. + * + * Returns the OID bytes (without tag and length) for the specified + * hash algorithm. + * + * @param hashAlgorithm the hash algorithm name + * (e.g., "SHA-1", "SHA-256", etc.) + * + * @return OID bytes (cloned array) + * + * @throws IllegalArgumentException if hash algorithm not supported + */ + static byte[] getHashAlgorithmOID(String hashAlgorithm) + throws IllegalArgumentException { + + if (hashAlgorithm == null) { + throw new IllegalArgumentException( + "Hash algorithm cannot be null"); + } + + switch (hashAlgorithm.toUpperCase()) { + case "SHA-1": + return OID_SHA1.clone(); + case "SHA-224": + return OID_SHA224.clone(); + case "SHA-256": + return OID_SHA256.clone(); + case "SHA-384": + return OID_SHA384.clone(); + case "SHA-512": + return OID_SHA512.clone(); + case "SHA-512/224": + return OID_SHA512_224.clone(); + case "SHA-512/256": + return OID_SHA512_256.clone(); + case "SHA3-224": + return OID_SHA3_224.clone(); + case "SHA3-256": + return OID_SHA3_256.clone(); + case "SHA3-384": + return OID_SHA3_384.clone(); + case "SHA3-512": + return OID_SHA3_512.clone(); + default: + throw new IllegalArgumentException( + "Unsupported hash algorithm: " + hashAlgorithm); + } + } + + /** + * Get hash algorithm name from OID bytes. + * + * Returns the hash algorithm name for the given OID bytes + * (without tag and length). + * + * @param oidBytes the OID bytes + * + * @return hash algorithm name + * + * @throws IllegalArgumentException if OID not recognized + */ + static String getHashAlgorithmName(byte[] oidBytes) + throws IllegalArgumentException { + + if (oidBytes == null) { + throw new IllegalArgumentException("OID bytes cannot be null"); + } + + if (bytesEqual(oidBytes, OID_SHA1)) { + return "SHA-1"; + } + else if (bytesEqual(oidBytes, OID_SHA224)) { + return "SHA-224"; + } + else if (bytesEqual(oidBytes, OID_SHA256)) { + return "SHA-256"; + } + else if (bytesEqual(oidBytes, OID_SHA384)) { + return "SHA-384"; + } + else if (bytesEqual(oidBytes, OID_SHA512)) { + return "SHA-512"; + } + else if (bytesEqual(oidBytes, OID_SHA512_224)) { + return "SHA-512/224"; + } + else if (bytesEqual(oidBytes, OID_SHA512_256)) { + return "SHA-512/256"; + } + else if (bytesEqual(oidBytes, OID_SHA3_224)) { + return "SHA3-224"; + } + else if (bytesEqual(oidBytes, OID_SHA3_256)) { + return "SHA3-256"; + } + else if (bytesEqual(oidBytes, OID_SHA3_384)) { + return "SHA3-384"; + } + else if (bytesEqual(oidBytes, OID_SHA3_512)) { + return "SHA3-512"; + } + else { + throw new IllegalArgumentException("Unrecognized hash OID"); + } + } + + /** + * Get MGF1 algorithm OID bytes. + * + * Returns the OID bytes (without tag and length) for MGF1. + * + * @return MGF1 OID bytes (cloned array) + */ + static byte[] getMGF1OID() { + return OID_MGF1.clone(); + } + + /** + * Compare two byte arrays for equality, constant time. + * + * @param a first array + * @param b second array + * + * @return true if arrays are equal, false otherwise + */ + static boolean bytesEqual(byte[] a, byte[] b) { + int result = 0, i = 0; + + if (a == null || b == null || a.length != b.length) { + return false; + } + + for (i = 0; i < a.length; i++) { + result |= a[i] ^ b[i]; + } + + if (result == 0) { + return true; + } + return false; + } } diff --git a/src/main/java/com/wolfssl/provider/jce/WolfCryptPssParameters.java b/src/main/java/com/wolfssl/provider/jce/WolfCryptPssParameters.java index f4e01f48..7af2896f 100644 --- a/src/main/java/com/wolfssl/provider/jce/WolfCryptPssParameters.java +++ b/src/main/java/com/wolfssl/provider/jce/WolfCryptPssParameters.java @@ -21,6 +21,7 @@ package com.wolfssl.provider.jce; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.security.AlgorithmParametersSpi; import java.security.spec.AlgorithmParameterSpec; @@ -54,22 +55,33 @@ public WolfCryptPssParameters() { @Override protected void engineInit(AlgorithmParameterSpec paramSpec) - throws InvalidParameterSpecException { + throws InvalidParameterSpecException { + + PSSParameterSpec pss; if (!(paramSpec instanceof PSSParameterSpec)) { throw new InvalidParameterSpecException( "Only PSSParameterSpec supported"); } - PSSParameterSpec pss = (PSSParameterSpec)paramSpec; + pss = (PSSParameterSpec)paramSpec; validatePSSParameters(pss); this.pssSpec = pss; } @Override protected void engineInit(byte[] params) throws IOException { - /* ASN.1 DER decoding would be implemented here */ - throw new IOException("DER encoding/decoding not yet implemented"); + if (params == null || params.length == 0) { + throw new IOException("Parameters cannot be null or empty"); + } + + try { + this.pssSpec = decodePssParameters(params); + validatePSSParameters(this.pssSpec); + } catch (InvalidParameterSpecException e) { + throw new IOException( + "Failed to decode PSS parameters: " + e.getMessage(), e); + } } @Override @@ -99,8 +111,11 @@ protected T engineGetParameterSpec( @Override protected byte[] engineGetEncoded() throws IOException { - /* ASN.1 DER encoding would be implemented here */ - throw new IOException("DER encoding/decoding not yet implemented"); + if (this.pssSpec == null) { + throw new IOException("PSS parameters not initialized"); + } + + return encodePssParameters(this.pssSpec); } @Override @@ -210,4 +225,375 @@ private boolean isDigestSupported(String digestAlg) { return false; } } + + + /** + * Encode PSS parameters to DER format (RFC 4055). + * + * @param spec The PSSParameterSpec to encode + * + * @return DER-encoded PSS parameters + * + * @throws IOException if encoding fails + */ + private byte[] encodePssParameters(PSSParameterSpec spec) + throws IOException { + + int saltLen, trailer; + String digestAlg, mgfDigest; + ByteArrayOutputStream seq = new ByteArrayOutputStream(); + MGF1ParameterSpec mgf1Spec; + byte[] seqBytes; + + /* Get hash algorithm name */ + digestAlg = spec.getDigestAlgorithm(); + saltLen = spec.getSaltLength(); + trailer = spec.getTrailerField(); + + /* Get MGF digest algorithm from MGF1ParameterSpec. + * Per RFC 4055, MGF1 is the only supported MGF algorithm. */ + if (!(spec.getMGFParameters() instanceof MGF1ParameterSpec)) { + throw new IOException( + "MGF parameters must be MGF1ParameterSpec"); + } + + mgf1Spec = (MGF1ParameterSpec) spec.getMGFParameters(); + mgfDigest = mgf1Spec.getDigestAlgorithm(); + + /* Encode hashAlgorithm [0] if not default (SHA-1). If default, + * we omit from encoding. */ + if (!digestAlg.equalsIgnoreCase("SHA-1")) { + byte[] hashAlgId = encodeAlgorithmIdentifier(digestAlg); + seq.write(WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_0); + seq.write(WolfCryptASN1Util.encodeDERLength(hashAlgId.length)); + seq.write(hashAlgId); + } + + /* Encode maskGenAlgorithm [1] if not default (mgf1SHA1). If default, + * we omit from encoding. */ + if (!mgfDigest.equalsIgnoreCase("SHA-1")) { + byte[] mgfHashAlgId = encodeAlgorithmIdentifier(mgfDigest); + byte[] mgfAlgId = encodeMGF1AlgorithmIdentifier(mgfHashAlgId); + seq.write(WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_1); + seq.write(WolfCryptASN1Util.encodeDERLength(mgfAlgId.length)); + seq.write(mgfAlgId); + } + + /* Encode saltLength [2] if not default (20) */ + if (saltLen != 20) { + byte[] saltLenBytes = WolfCryptASN1Util.encodeDERInteger(saltLen); + seq.write(WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_2); + seq.write(WolfCryptASN1Util.encodeDERLength(saltLenBytes.length)); + seq.write(saltLenBytes); + } + + /* Encode trailerField [3] if not default (1) */ + if (trailer != 1) { + byte[] trailerBytes = WolfCryptASN1Util.encodeDERInteger(trailer); + seq.write(WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_3); + seq.write(WolfCryptASN1Util.encodeDERLength(trailerBytes.length)); + seq.write(trailerBytes); + } + + /* Wrap in SEQUENCE */ + seqBytes = seq.toByteArray(); + + return WolfCryptASN1Util.encodeDERSequence(seqBytes); + } + + /** + * Decode DER-encoded PSS parameters (RFC 4055). + * + * @param params DER-encoded PSS parameters + * + * @return Decoded PSSParameterSpec + * + * @throws IOException if decoding fails + */ + private PSSParameterSpec decodePssParameters(byte[] params) + throws IOException { + + int idx = 0, seqLen = 0; + int[] lenInfo; + + /* Defaults, per RFC 4055 */ + String digestAlg = "SHA-1"; + String mgfDigest = "SHA-1"; + int saltLen = 20; + int trailer = 1; + + if (params == null || params.length < 2) { + throw new IOException("Invalid PSS parameters: too short"); + } + + /* Check SEQUENCE tag */ + if (params[idx++] != WolfCryptASN1Util.ASN1_SEQUENCE) { + throw new IOException( + "Invalid PSS parameters: expected SEQUENCE"); + } + + /* Get SEQUENCE length */ + lenInfo = WolfCryptASN1Util.decodeDERLengthWithOffset(params, idx); + seqLen = lenInfo[0]; + idx = lenInfo[1]; + + if (idx + seqLen != params.length) { + throw new IOException( + "Invalid PSS parameters: incorrect length"); + } + + /* Parse optional fields */ + while (idx < params.length) { + byte tag = params[idx++]; + + /* Get field length */ + lenInfo = + WolfCryptASN1Util.decodeDERLengthWithOffset(params, idx); + int fieldLen = lenInfo[0]; + idx = lenInfo[1]; + + if (idx + fieldLen > params.length) { + throw new IOException( + "Invalid PSS parameters: field extends beyond data"); + } + + byte[] fieldData = new byte[fieldLen]; + System.arraycopy(params, idx, fieldData, 0, fieldLen); + idx += fieldLen; + + if (tag == WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_0) { + /* hashAlgorithm [0] */ + digestAlg = decodeAlgorithmIdentifier(fieldData); + } + else if (tag == WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_1) { + /* maskGenAlgorithm [1] */ + mgfDigest = decodeMGF1AlgorithmIdentifier(fieldData); + } + else if (tag == WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_2) { + /* saltLength [2] */ + saltLen = WolfCryptASN1Util.decodeDERInteger(fieldData); + } + else if (tag == WolfCryptASN1Util.ASN1_CONTEXT_SPECIFIC_3) { + /* trailerField [3] */ + trailer = WolfCryptASN1Util.decodeDERInteger(fieldData); + } + else { + throw new IOException( + "Invalid PSS parameters: unknown tag 0x" + + Integer.toHexString(tag & 0xff)); + } + } + + /* Validate trailer field */ + if (trailer != 1) { + throw new IOException( + "Invalid PSS parameters: trailerField must be 1"); + } + + /* Create and return PSSParameterSpec */ + return new PSSParameterSpec(digestAlg, "MGF1", + new MGF1ParameterSpec(mgfDigest), saltLen, trailer + ); + } + + /** + * Encode AlgorithmIdentifier for a hash algorithm. + * + * @param digestAlg The digest algorithm name + * + * @return DER-encoded AlgorithmIdentifier + * + * @throws IOException if encoding fails + */ + private byte[] encodeAlgorithmIdentifier(String digestAlg) + throws IOException { + + byte[] oid = getHashOID(digestAlg); + + ByteArrayOutputStream seq = new ByteArrayOutputStream(); + + /* Write OID */ + seq.write(WolfCryptASN1Util.encodeDERObjectIdentifier(oid)); + + /* Write NULL parameters */ + seq.write(WolfCryptASN1Util.encodeDERNull()); + + /* Wrap in SEQUENCE */ + return WolfCryptASN1Util.encodeDERSequence(seq.toByteArray()); + } + + /** + * Decode AlgorithmIdentifier and extract the hash algorithm name. + * + * @param data DER-encoded AlgorithmIdentifier + * + * @return Hash algorithm name + * + * @throws IOException if decoding fails + */ + private String decodeAlgorithmIdentifier(byte[] data) throws IOException { + + int idx = 0, oidLen = 0; + int[] lenInfo; + byte[] oid; + + if (data == null || data.length < 2) { + throw new IOException("Invalid AlgorithmIdentifier: too short"); + } + + /* Check SEQUENCE tag */ + if (data[idx++] != WolfCryptASN1Util.ASN1_SEQUENCE) { + throw new IOException( + "Invalid AlgorithmIdentifier: expected SEQUENCE"); + } + + /* Skip SEQUENCE length */ + lenInfo = WolfCryptASN1Util.decodeDERLengthWithOffset(data, idx); + idx = lenInfo[1]; + + /* Check OBJECT IDENTIFIER tag */ + if (idx >= data.length || + data[idx++] != WolfCryptASN1Util.ASN1_OBJECT_IDENTIFIER) { + throw new IOException( + "Invalid AlgorithmIdentifier: expected OBJECT IDENTIFIER"); + } + + /* Get OID length */ + lenInfo = WolfCryptASN1Util.decodeDERLengthWithOffset(data, idx); + oidLen = lenInfo[0]; + idx = lenInfo[1]; + + if (idx + oidLen > data.length) { + throw new IOException( + "Invalid AlgorithmIdentifier: OID extends beyond data"); + } + + /* Extract OID bytes */ + oid = new byte[oidLen]; + System.arraycopy(data, idx, oid, 0, oidLen); + idx += oidLen; + + /* Verify NULL parameters follow OID (per RFC 4055). + * Hash AlgorithmIdentifiers have NULL parameters. */ + if (idx + 2 > data.length) { + throw new IOException( + "Invalid AlgorithmIdentifier: missing parameters"); + } + + if (data[idx] != WolfCryptASN1Util.ASN1_NULL || + data[idx + 1] != 0x00) { + throw new IOException( + "Invalid AlgorithmIdentifier: expected NULL parameters"); + } + + /* Map OID to hash algorithm name */ + return WolfCryptASN1Util.getHashAlgorithmName(oid); + } + + /** + * Encode MGF1 AlgorithmIdentifier with embedded hash AlgorithmIdentifier. + * + * @param hashAlgId DER-encoded hash AlgorithmIdentifier + * + * @return DER-encoded MGF1 AlgorithmIdentifier + * + * @throws IOException if encoding fails + */ + private byte[] encodeMGF1AlgorithmIdentifier(byte[] hashAlgId) + throws IOException { + + ByteArrayOutputStream seq = new ByteArrayOutputStream(); + + /* Write MGF1 OID */ + seq.write(WolfCryptASN1Util.encodeDERObjectIdentifier( + WolfCryptASN1Util.getMGF1OID())); + + /* Write hash AlgorithmIdentifier as parameters */ + seq.write(hashAlgId); + + /* Wrap in SEQUENCE */ + return WolfCryptASN1Util.encodeDERSequence(seq.toByteArray()); + } + + /** + * Decode MGF1 AlgorithmIdentifier and extract hash algorithm name. + * + * @param data DER-encoded MGF1 AlgorithmIdentifier + * + * @return Hash algorithm name for MGF + * + * @throws IOException if decoding fails + */ + private String decodeMGF1AlgorithmIdentifier(byte[] data) + throws IOException { + + int idx = 0, oidLen = 0; + int[] lenInfo; + byte[] oid, hashAlgId; + + if (data == null || data.length < 2) { + throw new IOException("Invalid MGF1 AlgorithmIdentifier"); + } + + /* Check SEQUENCE tag */ + if (data[idx++] != WolfCryptASN1Util.ASN1_SEQUENCE) { + throw new IOException( + "Invalid MGF1 AlgorithmIdentifier: expected SEQUENCE"); + } + + /* Skip SEQUENCE length */ + lenInfo = WolfCryptASN1Util.decodeDERLengthWithOffset(data, idx); + idx = lenInfo[1]; + + /* Check MGF1 OID */ + if (idx >= data.length || + data[idx++] != WolfCryptASN1Util.ASN1_OBJECT_IDENTIFIER) { + throw new IOException( + "Invalid MGF1 AlgorithmIdentifier: expected OID"); + } + + lenInfo = WolfCryptASN1Util.decodeDERLengthWithOffset(data, idx); + oidLen = lenInfo[0]; + idx = lenInfo[1]; + + if (idx + oidLen > data.length) { + throw new IOException("Invalid MGF1 AlgorithmIdentifier"); + } + + /* Verify it's the MGF1 OID */ + oid = new byte[oidLen]; + System.arraycopy(data, idx, oid, 0, oidLen); + idx += oidLen; + + if (!WolfCryptASN1Util.bytesEqual(oid, + WolfCryptASN1Util.getMGF1OID())) { + throw new IOException( + "Invalid MGF1 AlgorithmIdentifier: not MGF1 OID"); + } + + /* Decode embedded hash AlgorithmIdentifier */ + hashAlgId = new byte[data.length - idx]; + System.arraycopy(data, idx, hashAlgId, 0, hashAlgId.length); + + return decodeAlgorithmIdentifier(hashAlgId); + } + + /** + * Get OID bytes for a hash algorithm name. + * + * @param digestAlg The hash algorithm name + * + * @return OID bytes + * + * @throws IOException if algorithm not supported + */ + private byte[] getHashOID(String digestAlg) throws IOException { + try { + return WolfCryptASN1Util.getHashAlgorithmOID(digestAlg); + + } catch (IllegalArgumentException e) { + throw new IOException(e.getMessage(), e); + } + } } + diff --git a/src/test/java/com/wolfssl/provider/jce/test/WolfCryptAlgorithmParametersTest.java b/src/test/java/com/wolfssl/provider/jce/test/WolfCryptAlgorithmParametersTest.java index 13a89df6..49ba0752 100644 --- a/src/test/java/com/wolfssl/provider/jce/test/WolfCryptAlgorithmParametersTest.java +++ b/src/test/java/com/wolfssl/provider/jce/test/WolfCryptAlgorithmParametersTest.java @@ -41,6 +41,9 @@ import java.security.spec.AlgorithmParameterSpec; import java.util.Arrays; +import java.security.spec.PSSParameterSpec; +import java.security.spec.MGF1ParameterSpec; + import javax.crypto.spec.DHParameterSpec; import javax.crypto.spec.IvParameterSpec; @@ -595,5 +598,572 @@ public void testDHParametersRoundTrip() /* Verify encodings are identical */ assertTrue(Arrays.equals(encoded1, encoded2)); } + + @Test + public void testRSAPSSParametersGetInstance() + throws NoSuchProviderException, NoSuchAlgorithmException { + + AlgorithmParameters params; + params = AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + assertNotNull(params); + } + + @Test + public void testRSAPSSParametersInitWithDefaultSpec() + throws Exception { + + /* Create default PSS parameters (SHA-256, MGF1-SHA256, salt=32) */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + assertNotNull(params); + + params.init(spec); + + /* Retrieve parameters back */ + PSSParameterSpec retrievedSpec = + params.getParameterSpec(PSSParameterSpec.class); + assertNotNull(retrievedSpec); + assertEquals("SHA-256", retrievedSpec.getDigestAlgorithm()); + assertEquals("MGF1", retrievedSpec.getMGFAlgorithm()); + assertEquals(32, retrievedSpec.getSaltLength()); + assertEquals(1, retrievedSpec.getTrailerField()); + + /* Check MGF1 parameters */ + assertTrue( + retrievedSpec.getMGFParameters() instanceof MGF1ParameterSpec); + MGF1ParameterSpec mgf1Spec = + (MGF1ParameterSpec) retrievedSpec.getMGFParameters(); + assertEquals("SHA-256", mgf1Spec.getDigestAlgorithm()); + } + + @Test + public void testRSAPSSParametersInitWithSHA1Spec() + throws Exception { + + /* SHA-1 with MGF1-SHA1, salt=20 */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-1", "MGF1", MGF1ParameterSpec.SHA1, 20, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + PSSParameterSpec retrievedSpec = + params.getParameterSpec(PSSParameterSpec.class); + assertEquals("SHA-1", retrievedSpec.getDigestAlgorithm()); + MGF1ParameterSpec mgf1Spec = + (MGF1ParameterSpec) retrievedSpec.getMGFParameters(); + assertEquals("SHA-1", mgf1Spec.getDigestAlgorithm()); + assertEquals(20, retrievedSpec.getSaltLength()); + } + + @Test + public void testRSAPSSParametersInitWithSHA384Spec() + throws Exception { + + /* SHA-384 with MGF1-SHA384, salt=48 */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-384", "MGF1", MGF1ParameterSpec.SHA384, 48, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + PSSParameterSpec retrievedSpec = + params.getParameterSpec(PSSParameterSpec.class); + assertEquals("SHA-384", retrievedSpec.getDigestAlgorithm()); + MGF1ParameterSpec mgf1Spec = + (MGF1ParameterSpec) retrievedSpec.getMGFParameters(); + assertEquals("SHA-384", mgf1Spec.getDigestAlgorithm()); + assertEquals(48, retrievedSpec.getSaltLength()); + } + + @Test + public void testRSAPSSParametersInitWithSHA512Spec() + throws Exception { + + /* SHA-512 with MGF1-SHA512, salt=64 */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-512", "MGF1", MGF1ParameterSpec.SHA512, 64, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + PSSParameterSpec retrievedSpec = + params.getParameterSpec(PSSParameterSpec.class); + assertEquals("SHA-512", retrievedSpec.getDigestAlgorithm()); + MGF1ParameterSpec mgf1Spec = + (MGF1ParameterSpec) retrievedSpec.getMGFParameters(); + assertEquals("SHA-512", mgf1Spec.getDigestAlgorithm()); + assertEquals(64, retrievedSpec.getSaltLength()); + } + + @Test + public void testRSAPSSParametersEncodingDERWithDefaults() + throws Exception { + + /* RFC 4055 defaults: SHA-1, MGF1-SHA1, salt=20 */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-1", "MGF1", MGF1ParameterSpec.SHA1, 20, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + /* Get DER encoding */ + byte[] encoded = params.getEncoded(); + assertNotNull(encoded); + assertTrue(encoded.length > 0); + + /* With all defaults, should be minimal SEQUENCE */ + /* SEQUENCE { } = 0x30 0x00 */ + assertEquals(0x30, encoded[0] & 0xFF); + assertEquals(0x00, encoded[1] & 0xFF); + + /* Decode and verify */ + AlgorithmParameters params2 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params2.init(encoded); + + PSSParameterSpec spec2 = + params2.getParameterSpec(PSSParameterSpec.class); + assertEquals("SHA-1", spec2.getDigestAlgorithm()); + assertEquals(20, spec2.getSaltLength()); + } + + @Test + public void testRSAPSSParametersEncodingDERWithNonDefaults() + throws Exception { + + /* Non-default values: SHA-256, MGF1-SHA256, salt=32 */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + /* Get DER encoding */ + byte[] encoded = params.getEncoded(); + assertNotNull(encoded); + assertTrue(encoded.length > 0); + + /* Should start with SEQUENCE tag */ + assertEquals(0x30, encoded[0] & 0xFF); + + /* Should be longer than minimal since we have non-defaults */ + assertTrue(encoded.length > 2); + + /* Decode and verify round trip */ + AlgorithmParameters params2 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params2.init(encoded); + + PSSParameterSpec spec2 = + params2.getParameterSpec(PSSParameterSpec.class); + assertEquals("SHA-256", spec2.getDigestAlgorithm()); + assertEquals("MGF1", spec2.getMGFAlgorithm()); + assertEquals(32, spec2.getSaltLength()); + assertEquals(1, spec2.getTrailerField()); + + MGF1ParameterSpec mgf1Spec = + (MGF1ParameterSpec) spec2.getMGFParameters(); + assertEquals("SHA-256", mgf1Spec.getDigestAlgorithm()); + } + + @Test + public void testRSAPSSParametersEncodingAllHashAlgorithms() + throws Exception { + + /* Test all supported hash algorithms */ + String[] hashAlgs = { + "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512", + "SHA-512/224", "SHA-512/256" + }; + + MGF1ParameterSpec[] mgf1Specs = { + MGF1ParameterSpec.SHA1, + MGF1ParameterSpec.SHA224, + MGF1ParameterSpec.SHA256, + MGF1ParameterSpec.SHA384, + MGF1ParameterSpec.SHA512, + new MGF1ParameterSpec("SHA-512/224"), + new MGF1ParameterSpec("SHA-512/256") + }; + + for (int i = 0; i < hashAlgs.length; i++) { + PSSParameterSpec spec = new PSSParameterSpec( + hashAlgs[i], "MGF1", mgf1Specs[i], 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + byte[] encoded = params.getEncoded(); + assertNotNull(encoded); + + /* Decode and verify */ + AlgorithmParameters params2 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params2.init(encoded); + + PSSParameterSpec spec2 = + params2.getParameterSpec(PSSParameterSpec.class); + assertEquals(hashAlgs[i], spec2.getDigestAlgorithm()); + + MGF1ParameterSpec mgf1Retrieved = + (MGF1ParameterSpec) spec2.getMGFParameters(); + assertEquals(hashAlgs[i], mgf1Retrieved.getDigestAlgorithm()); + } + } + + @Test + public void testRSAPSSParametersEncodingVariousSaltLengths() + throws Exception { + + int[] saltLengths = {0, 1, 16, 20, 32, 48, 64, 128, 255}; + + for (int saltLen : saltLengths) { + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, saltLen, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + byte[] encoded = params.getEncoded(); + assertNotNull(encoded); + + /* Decode and verify salt length preserved */ + AlgorithmParameters params2 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params2.init(encoded); + + PSSParameterSpec spec2 = + params2.getParameterSpec(PSSParameterSpec.class); + assertEquals(saltLen, spec2.getSaltLength()); + } + } + + @Test + public void testRSAPSSParametersEncodingMixedHashAlgorithms() + throws Exception { + + /* Test with different hash for digest and MGF */ + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA1, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + byte[] encoded = params.getEncoded(); + + AlgorithmParameters params2 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params2.init(encoded); + + PSSParameterSpec spec2 = + params2.getParameterSpec(PSSParameterSpec.class); + assertEquals("SHA-256", spec2.getDigestAlgorithm()); + + MGF1ParameterSpec mgf1Spec = + (MGF1ParameterSpec) spec2.getMGFParameters(); + assertEquals("SHA-1", mgf1Spec.getDigestAlgorithm()); + } + + @Test + public void testRSAPSSParametersRoundTrip() + throws Exception { + + PSSParameterSpec originalSpec = new PSSParameterSpec( + "SHA-384", "MGF1", MGF1ParameterSpec.SHA384, 48, 1 + ); + + /* Round trip 1: spec -> params -> encoded */ + AlgorithmParameters params1 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params1.init(originalSpec); + byte[] encoded1 = params1.getEncoded(); + + /* Round trip 2: encoded -> params -> encoded */ + AlgorithmParameters params2 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params2.init(encoded1); + byte[] encoded2 = params2.getEncoded(); + + /* Round trip 3: encoded -> params -> spec */ + AlgorithmParameters params3 = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params3.init(encoded2); + PSSParameterSpec finalSpec = + params3.getParameterSpec(PSSParameterSpec.class); + + /* Verify all parameters match original */ + assertEquals(originalSpec.getDigestAlgorithm(), + finalSpec.getDigestAlgorithm()); + assertEquals(originalSpec.getMGFAlgorithm(), + finalSpec.getMGFAlgorithm()); + assertEquals(originalSpec.getSaltLength(), + finalSpec.getSaltLength()); + assertEquals(originalSpec.getTrailerField(), + finalSpec.getTrailerField()); + + /* Verify encodings are identical */ + assertTrue(Arrays.equals(encoded1, encoded2)); + } + + @Test + public void testRSAPSSParametersInitWithInvalidParameterSpec() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + /* Try to initialize with wrong type of ParameterSpec */ + try { + DHParameterSpec invalidSpec = + new DHParameterSpec( + BigInteger.valueOf(123), BigInteger.valueOf(2)); + params.init(invalidSpec); + + fail("AlgorithmParameters.init should throw " + + "InvalidParameterSpecException when given wrong spec type"); + + } catch (InvalidParameterSpecException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersInitWithNullBytes() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + try { + params.init((byte[]) null); + fail("Should throw IOException for null parameters"); + } catch (IOException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersInitWithEmptyBytes() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + try { + params.init(new byte[0]); + fail("Should throw IOException for empty parameters"); + } catch (IOException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersInitWithInvalidDERBytes() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + /* Invalid DER: not a SEQUENCE */ + byte[] invalidDER = new byte[] {0x02, 0x01, 0x00}; + + try { + params.init(invalidDER); + fail("Should throw IOException for invalid DER encoding"); + } catch (IOException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersInitWithTruncatedDER() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + /* SEQUENCE with length > actual data */ + byte[] truncatedDER = new byte[] {0x30, 0x10, 0x00}; + + try { + params.init(truncatedDER); + fail("Should throw IOException for truncated DER encoding"); + } catch (IOException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersGetEncodedBeforeInit() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + try { + params.getEncoded(); + fail("Should throw IOException when getEncoded " + + "called before init"); + } catch (IOException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersGetParameterSpecBeforeInit() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + try { + params.getParameterSpec(PSSParameterSpec.class); + fail("Should throw InvalidParameterSpecException " + + "when called before init"); + } catch (InvalidParameterSpecException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersGetParameterSpecWithWrongClass() + throws Exception { + + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + try { + params.getParameterSpec(DHParameterSpec.class); + fail("Should throw InvalidParameterSpecException for wrong class"); + } catch (InvalidParameterSpecException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersGetParameterSpecWithNull() + throws Exception { + + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + try { + params.getParameterSpec(null); + fail("Should throw InvalidParameterSpecException for null class"); + } catch (InvalidParameterSpecException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersToString() + throws Exception { + + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + String str = params.toString(); + assertNotNull(str); + assertTrue(str.contains("PSS Parameters")); + assertTrue(str.contains("SHA-256")); + } + + @Test + public void testRSAPSSParametersEncodingWithFormat() + throws Exception { + + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + /* Test ASN.1 format */ + byte[] asnEncoded = params.getEncoded("ASN.1"); + assertNotNull(asnEncoded); + + /* ASN.1 and default should be the same */ + byte[] defaultEncoded = params.getEncoded(); + assertTrue(Arrays.equals(asnEncoded, defaultEncoded)); + + /* Test case insensitivity */ + byte[] lowerEncoded = params.getEncoded("asn.1"); + assertTrue(Arrays.equals(asnEncoded, lowerEncoded)); + } + + @Test + public void testRSAPSSParametersEncodingUnsupportedFormat() + throws Exception { + + PSSParameterSpec spec = new PSSParameterSpec( + "SHA-256", "MGF1", MGF1ParameterSpec.SHA256, 32, 1 + ); + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + params.init(spec); + + try { + params.getEncoded("PEM"); + fail("Should throw IOException for unsupported format"); + } catch (IOException e) { + /* expected */ + } + } + + @Test + public void testRSAPSSParametersInitWithUnsupportedFormat() + throws Exception { + + AlgorithmParameters params = + AlgorithmParameters.getInstance("RSASSA-PSS", "wolfJCE"); + + byte[] data = new byte[] {0x30, 0x00}; + + try { + params.init(data, "PEM"); + fail("Should throw IOException for unsupported format"); + } catch (IOException e) { + /* expected */ + } + } }