Skip to content

Commit 58dd0ef

Browse files
committed
std.crypto.tls: rework for new std.Io API
1 parent f4dbfdc commit 58dd0ef

File tree

4 files changed

+498
-945
lines changed

4 files changed

+498
-945
lines changed

lib/std/Io/Reader.zig

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,31 +1315,6 @@ pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void {
13151315
r.end = data.len;
13161316
}
13171317

1318-
/// Advances the stream and decreases the size of the storage buffer by `n`,
1319-
/// returning the range of bytes no longer accessible by `r`.
1320-
///
1321-
/// This action can be undone by `restitute`.
1322-
///
1323-
/// Asserts there are at least `n` buffered bytes already.
1324-
///
1325-
/// Asserts that `r.seek` is zero, i.e. the buffer is in a rebased state.
1326-
pub fn steal(r: *Reader, n: usize) []u8 {
1327-
assert(r.seek == 0);
1328-
assert(n <= r.end);
1329-
const stolen = r.buffer[0..n];
1330-
r.buffer = r.buffer[n..];
1331-
r.end -= n;
1332-
return stolen;
1333-
}
1334-
1335-
/// Expands the storage buffer, undoing the effects of `steal`
1336-
/// Assumes that `n` does not exceed the total number of stolen bytes.
1337-
pub fn restitute(r: *Reader, n: usize) void {
1338-
r.buffer = (r.buffer.ptr - n)[0 .. r.buffer.len + n];
1339-
r.end += n;
1340-
r.seek += n;
1341-
}
1342-
13431318
test fixed {
13441319
var r: Reader = .fixed("a\x02");
13451320
try testing.expect((try r.takeByte()) == 'a');

lib/std/crypto/tls.zig

Lines changed: 106 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{
4949
};
5050

5151
pub const close_notify_alert = [_]u8{
52-
@intFromEnum(AlertLevel.warning),
53-
@intFromEnum(AlertDescription.close_notify),
52+
@intFromEnum(Alert.Level.warning),
53+
@intFromEnum(Alert.Description.close_notify),
5454
};
5555

5656
pub const ProtocolVersion = enum(u16) {
@@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) {
138138
_,
139139
};
140140

141-
pub const AlertLevel = enum(u8) {
142-
warning = 1,
143-
fatal = 2,
144-
_,
145-
};
141+
pub const Alert = struct {
142+
level: Level,
143+
description: Description,
146144

147-
pub const AlertDescription = enum(u8) {
148-
pub const Error = error{
149-
TlsAlertUnexpectedMessage,
150-
TlsAlertBadRecordMac,
151-
TlsAlertRecordOverflow,
152-
TlsAlertHandshakeFailure,
153-
TlsAlertBadCertificate,
154-
TlsAlertUnsupportedCertificate,
155-
TlsAlertCertificateRevoked,
156-
TlsAlertCertificateExpired,
157-
TlsAlertCertificateUnknown,
158-
TlsAlertIllegalParameter,
159-
TlsAlertUnknownCa,
160-
TlsAlertAccessDenied,
161-
TlsAlertDecodeError,
162-
TlsAlertDecryptError,
163-
TlsAlertProtocolVersion,
164-
TlsAlertInsufficientSecurity,
165-
TlsAlertInternalError,
166-
TlsAlertInappropriateFallback,
167-
TlsAlertMissingExtension,
168-
TlsAlertUnsupportedExtension,
169-
TlsAlertUnrecognizedName,
170-
TlsAlertBadCertificateStatusResponse,
171-
TlsAlertUnknownPskIdentity,
172-
TlsAlertCertificateRequired,
173-
TlsAlertNoApplicationProtocol,
174-
TlsAlertUnknown,
145+
pub const Level = enum(u8) {
146+
warning = 1,
147+
fatal = 2,
148+
_,
175149
};
176150

177-
close_notify = 0,
178-
unexpected_message = 10,
179-
bad_record_mac = 20,
180-
record_overflow = 22,
181-
handshake_failure = 40,
182-
bad_certificate = 42,
183-
unsupported_certificate = 43,
184-
certificate_revoked = 44,
185-
certificate_expired = 45,
186-
certificate_unknown = 46,
187-
illegal_parameter = 47,
188-
unknown_ca = 48,
189-
access_denied = 49,
190-
decode_error = 50,
191-
decrypt_error = 51,
192-
protocol_version = 70,
193-
insufficient_security = 71,
194-
internal_error = 80,
195-
inappropriate_fallback = 86,
196-
user_canceled = 90,
197-
missing_extension = 109,
198-
unsupported_extension = 110,
199-
unrecognized_name = 112,
200-
bad_certificate_status_response = 113,
201-
unknown_psk_identity = 115,
202-
certificate_required = 116,
203-
no_application_protocol = 120,
204-
_,
151+
pub const Description = enum(u8) {
152+
pub const Error = error{
153+
TlsAlertUnexpectedMessage,
154+
TlsAlertBadRecordMac,
155+
TlsAlertRecordOverflow,
156+
TlsAlertHandshakeFailure,
157+
TlsAlertBadCertificate,
158+
TlsAlertUnsupportedCertificate,
159+
TlsAlertCertificateRevoked,
160+
TlsAlertCertificateExpired,
161+
TlsAlertCertificateUnknown,
162+
TlsAlertIllegalParameter,
163+
TlsAlertUnknownCa,
164+
TlsAlertAccessDenied,
165+
TlsAlertDecodeError,
166+
TlsAlertDecryptError,
167+
TlsAlertProtocolVersion,
168+
TlsAlertInsufficientSecurity,
169+
TlsAlertInternalError,
170+
TlsAlertInappropriateFallback,
171+
TlsAlertMissingExtension,
172+
TlsAlertUnsupportedExtension,
173+
TlsAlertUnrecognizedName,
174+
TlsAlertBadCertificateStatusResponse,
175+
TlsAlertUnknownPskIdentity,
176+
TlsAlertCertificateRequired,
177+
TlsAlertNoApplicationProtocol,
178+
TlsAlertUnknown,
179+
};
205180

206-
pub fn toError(alert: AlertDescription) Error!void {
207-
switch (alert) {
208-
.close_notify => {}, // not an error
209-
.unexpected_message => return error.TlsAlertUnexpectedMessage,
210-
.bad_record_mac => return error.TlsAlertBadRecordMac,
211-
.record_overflow => return error.TlsAlertRecordOverflow,
212-
.handshake_failure => return error.TlsAlertHandshakeFailure,
213-
.bad_certificate => return error.TlsAlertBadCertificate,
214-
.unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
215-
.certificate_revoked => return error.TlsAlertCertificateRevoked,
216-
.certificate_expired => return error.TlsAlertCertificateExpired,
217-
.certificate_unknown => return error.TlsAlertCertificateUnknown,
218-
.illegal_parameter => return error.TlsAlertIllegalParameter,
219-
.unknown_ca => return error.TlsAlertUnknownCa,
220-
.access_denied => return error.TlsAlertAccessDenied,
221-
.decode_error => return error.TlsAlertDecodeError,
222-
.decrypt_error => return error.TlsAlertDecryptError,
223-
.protocol_version => return error.TlsAlertProtocolVersion,
224-
.insufficient_security => return error.TlsAlertInsufficientSecurity,
225-
.internal_error => return error.TlsAlertInternalError,
226-
.inappropriate_fallback => return error.TlsAlertInappropriateFallback,
227-
.user_canceled => {}, // not an error
228-
.missing_extension => return error.TlsAlertMissingExtension,
229-
.unsupported_extension => return error.TlsAlertUnsupportedExtension,
230-
.unrecognized_name => return error.TlsAlertUnrecognizedName,
231-
.bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
232-
.unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
233-
.certificate_required => return error.TlsAlertCertificateRequired,
234-
.no_application_protocol => return error.TlsAlertNoApplicationProtocol,
235-
_ => return error.TlsAlertUnknown,
181+
close_notify = 0,
182+
unexpected_message = 10,
183+
bad_record_mac = 20,
184+
record_overflow = 22,
185+
handshake_failure = 40,
186+
bad_certificate = 42,
187+
unsupported_certificate = 43,
188+
certificate_revoked = 44,
189+
certificate_expired = 45,
190+
certificate_unknown = 46,
191+
illegal_parameter = 47,
192+
unknown_ca = 48,
193+
access_denied = 49,
194+
decode_error = 50,
195+
decrypt_error = 51,
196+
protocol_version = 70,
197+
insufficient_security = 71,
198+
internal_error = 80,
199+
inappropriate_fallback = 86,
200+
user_canceled = 90,
201+
missing_extension = 109,
202+
unsupported_extension = 110,
203+
unrecognized_name = 112,
204+
bad_certificate_status_response = 113,
205+
unknown_psk_identity = 115,
206+
certificate_required = 116,
207+
no_application_protocol = 120,
208+
_,
209+
210+
pub fn toError(description: Description) Error!void {
211+
switch (description) {
212+
.close_notify => {}, // not an error
213+
.unexpected_message => return error.TlsAlertUnexpectedMessage,
214+
.bad_record_mac => return error.TlsAlertBadRecordMac,
215+
.record_overflow => return error.TlsAlertRecordOverflow,
216+
.handshake_failure => return error.TlsAlertHandshakeFailure,
217+
.bad_certificate => return error.TlsAlertBadCertificate,
218+
.unsupported_certificate => return error.TlsAlertUnsupportedCertificate,
219+
.certificate_revoked => return error.TlsAlertCertificateRevoked,
220+
.certificate_expired => return error.TlsAlertCertificateExpired,
221+
.certificate_unknown => return error.TlsAlertCertificateUnknown,
222+
.illegal_parameter => return error.TlsAlertIllegalParameter,
223+
.unknown_ca => return error.TlsAlertUnknownCa,
224+
.access_denied => return error.TlsAlertAccessDenied,
225+
.decode_error => return error.TlsAlertDecodeError,
226+
.decrypt_error => return error.TlsAlertDecryptError,
227+
.protocol_version => return error.TlsAlertProtocolVersion,
228+
.insufficient_security => return error.TlsAlertInsufficientSecurity,
229+
.internal_error => return error.TlsAlertInternalError,
230+
.inappropriate_fallback => return error.TlsAlertInappropriateFallback,
231+
.user_canceled => {}, // not an error
232+
.missing_extension => return error.TlsAlertMissingExtension,
233+
.unsupported_extension => return error.TlsAlertUnsupportedExtension,
234+
.unrecognized_name => return error.TlsAlertUnrecognizedName,
235+
.bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse,
236+
.unknown_psk_identity => return error.TlsAlertUnknownPskIdentity,
237+
.certificate_required => return error.TlsAlertCertificateRequired,
238+
.no_application_protocol => return error.TlsAlertNoApplicationProtocol,
239+
_ => return error.TlsAlertUnknown,
240+
}
236241
}
237-
}
242+
};
238243
};
239244

240245
pub const SignatureScheme = enum(u16) {
@@ -650,22 +655,24 @@ pub const Decoder = struct {
650655
}
651656

652657
/// Use this function to increase `their_end`.
653-
pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void {
658+
pub fn readAtLeast(d: *Decoder, stream: *std.io.Reader, their_amt: usize) !void {
654659
assert(!d.disable_reads);
655660
const existing_amt = d.cap - d.idx;
656661
d.their_end = d.idx + their_amt;
657662
if (their_amt <= existing_amt) return;
658663
const request_amt = their_amt - existing_amt;
659664
const dest = d.buf[d.cap..];
660665
if (request_amt > dest.len) return error.TlsRecordOverflow;
661-
const actual_amt = try stream.readAtLeast(dest, request_amt);
662-
if (actual_amt < request_amt) return error.TlsConnectionTruncated;
663-
d.cap += actual_amt;
666+
stream.readSlice(dest[0..request_amt]) catch |err| switch (err) {
667+
error.EndOfStream => return error.TlsConnectionTruncated,
668+
error.ReadFailed => return error.ReadFailed,
669+
};
670+
d.cap += request_amt;
664671
}
665672

666673
/// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`.
667674
/// Use when `our_amt` is calculated by us, not by them.
668-
pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void {
675+
pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.Reader, our_amt: usize) !void {
669676
assert(!d.disable_reads);
670677
try readAtLeast(d, stream, our_amt);
671678
d.our_end = d.idx + our_amt;

0 commit comments

Comments
 (0)