Skip to content

Commit aac26f3

Browse files
committed
TLS, HTTP, and package fetching fixes
* TLS: add missing assert for output buffer length requirement * TLS: add missing flushes * TLS: add flush implementation * TLS: finish drain implementation * HTTP: correct buffer sizes for TLS * HTTP: expose a getReadError method on Connection * HTTP: add missing flush on sendBodyComplete * Fetch: remove unwanted deinit * Fetch: improve error reporting
1 parent ac18b98 commit aac26f3

File tree

3 files changed

+84
-28
lines changed

3 files changed

+84
-28
lines changed

lib/std/crypto/tls/Client.zig

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ const mem = std.mem;
88
const crypto = std.crypto;
99
const assert = std.debug.assert;
1010
const Certificate = std.crypto.Certificate;
11-
const Reader = std.io.Reader;
12-
const Writer = std.io.Writer;
11+
const Reader = std.Io.Reader;
12+
const Writer = std.Io.Writer;
1313

1414
const max_ciphertext_len = tls.max_ciphertext_len;
1515
const hmacExpandLabel = tls.hmacExpandLabel;
@@ -27,6 +27,8 @@ reader: Reader,
2727

2828
/// The encrypted stream from the client to the server. Bytes are pushed here
2929
/// via `writer`.
30+
///
31+
/// The buffer is asserted to have capacity at least `min_buffer_len`.
3032
output: *Writer,
3133
/// The plaintext stream from the client to the server.
3234
writer: Writer,
@@ -122,7 +124,6 @@ pub const Options = struct {
122124
/// the amount of data expected, such as HTTP with the Content-Length header.
123125
allow_truncation_attacks: bool = false,
124126
write_buffer: []u8,
125-
/// Asserted to have capacity at least `min_buffer_len`.
126127
read_buffer: []u8,
127128
/// Populated when `error.TlsAlert` is returned from `init`.
128129
alert: ?*tls.Alert = null,
@@ -185,6 +186,7 @@ const InitError = error{
185186
/// `input` is asserted to have buffer capacity at least `min_buffer_len`.
186187
pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client {
187188
assert(input.buffer.len >= min_buffer_len);
189+
assert(output.buffer.len >= min_buffer_len);
188190
const host = switch (options.host) {
189191
.no_verification => "",
190192
.explicit => |host| host,
@@ -278,6 +280,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
278280
{
279281
var iovecs: [2][]const u8 = .{ cleartext_header, host };
280282
try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]);
283+
try output.flush();
281284
}
282285

283286
var tls_version: tls.ProtocolVersion = undefined;
@@ -763,6 +766,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
763766
&client_verify_msg,
764767
};
765768
try output.writeVecAll(&all_msgs_vec);
769+
try output.flush();
766770
},
767771
}
768772
write_seq += 1;
@@ -828,6 +832,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
828832
&finished_msg,
829833
};
830834
try output.writeVecAll(&all_msgs_vec);
835+
try output.flush();
831836

832837
const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
833838
const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length);
@@ -877,7 +882,7 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
877882
.buffer = options.write_buffer,
878883
.vtable = &.{
879884
.drain = drain,
880-
.sendFile = Writer.unimplementedSendFile,
885+
.flush = flush,
881886
},
882887
},
883888
.tls_version = tls_version,
@@ -911,31 +916,56 @@ pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client
911916

912917
fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize {
913918
const c: *Client = @alignCast(@fieldParentPtr("writer", w));
914-
if (true) @panic("update to use the buffer and flush");
915-
const sliced_data = if (splat == 0) data[0..data.len -| 1] else data;
916919
const output = c.output;
917920
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
918-
var total_clear: usize = 0;
919921
var ciphertext_end: usize = 0;
920-
for (sliced_data) |buf| {
921-
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
922-
total_clear += prepared.cleartext_len;
923-
ciphertext_end += prepared.ciphertext_end;
924-
if (total_clear < buf.len) break;
922+
var total_clear: usize = 0;
923+
done: {
924+
{
925+
const buf = w.buffered();
926+
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
927+
total_clear += prepared.cleartext_len;
928+
ciphertext_end += prepared.ciphertext_end;
929+
if (prepared.cleartext_len < buf.len) break :done;
930+
}
931+
for (data[0 .. data.len - 1]) |buf| {
932+
if (buf.len < min_buffer_len) break :done;
933+
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
934+
total_clear += prepared.cleartext_len;
935+
ciphertext_end += prepared.ciphertext_end;
936+
if (prepared.cleartext_len < buf.len) break :done;
937+
}
938+
const buf = data[data.len - 1];
939+
for (0..splat) |_| {
940+
if (buf.len < min_buffer_len) break :done;
941+
const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data);
942+
total_clear += prepared.cleartext_len;
943+
ciphertext_end += prepared.ciphertext_end;
944+
if (prepared.cleartext_len < buf.len) break :done;
945+
}
925946
}
926947
output.advance(ciphertext_end);
927-
return total_clear;
948+
return w.consume(total_clear);
949+
}
950+
951+
fn flush(w: *Writer) Writer.Error!void {
952+
const c: *Client = @alignCast(@fieldParentPtr("writer", w));
953+
const output = c.output;
954+
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
955+
const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data);
956+
output.advance(prepared.ciphertext_end);
957+
w.end = 0;
928958
}
929959

930960
/// Sends a `close_notify` alert, which is necessary for the server to
931961
/// distinguish between a properly finished TLS session, or a truncation
932962
/// attack.
933963
pub fn end(c: *Client) Writer.Error!void {
964+
try flush(&c.writer);
934965
const output = c.output;
935966
const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len);
936967
const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert);
937-
output.advance(prepared.cleartext_len);
938-
return prepared.ciphertext_end;
968+
output.advance(prepared.ciphertext_end);
939969
}
940970

941971
fn prepareCiphertextRecord(
@@ -1045,7 +1075,7 @@ pub fn eof(c: Client) bool {
10451075
return c.received_close_notify;
10461076
}
10471077

1048-
fn stream(r: *Reader, w: *Writer, limit: std.io.Limit) Reader.StreamError!usize {
1078+
fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
10491079
const c: *Client = @alignCast(@fieldParentPtr("reader", r));
10501080
if (c.eof()) return error.EndOfStream;
10511081
const input = c.input;

lib/std/http/Client.zig

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ connection_pool: ConnectionPool = .{},
4242
///
4343
/// If the entire HTTP header cannot fit in this amount of bytes,
4444
/// `error.HttpHeadersOversize` will be returned from `Request.wait`.
45-
read_buffer_size: usize = 4096,
45+
read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len,
4646
/// Each `Connection` allocates this amount for the writer buffer.
4747
write_buffer_size: usize = 1024,
4848

@@ -304,15 +304,16 @@ pub const Connection = struct {
304304
const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len];
305305
const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size];
306306
const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size];
307-
const socket_write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
308-
assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len);
307+
const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size];
308+
const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size];
309+
assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len);
309310
@memcpy(host_buffer, remote_host);
310311
const tls: *Tls = @ptrCast(base);
311312
tls.* = .{
312313
.connection = .{
313314
.client = client,
314-
.stream_writer = stream.writer(socket_write_buffer),
315-
.stream_reader = stream.reader(&.{}),
315+
.stream_writer = stream.writer(tls_write_buffer),
316+
.stream_reader = stream.reader(tls_read_buffer),
316317
.pool_node = .{},
317318
.port = port,
318319
.host_len = @intCast(remote_host.len),
@@ -328,8 +329,8 @@ pub const Connection = struct {
328329
.host = .{ .explicit = remote_host },
329330
.ca = .{ .bundle = client.ca_bundle },
330331
.ssl_key_log = client.ssl_key_log,
331-
.read_buffer = tls_read_buffer,
332-
.write_buffer = tls_write_buffer,
332+
.read_buffer = read_buffer,
333+
.write_buffer = write_buffer,
333334
// This is appropriate for HTTPS because the HTTP headers contain
334335
// the content length which is used to detect truncation attacks.
335336
.allow_truncation_attacks = true,
@@ -347,7 +348,8 @@ pub const Connection = struct {
347348
}
348349

349350
fn allocLen(client: *Client, host_len: usize) usize {
350-
return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + client.write_buffer_size;
351+
return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size +
352+
client.write_buffer_size + client.read_buffer_size;
351353
}
352354

353355
fn host(tls: *Tls) []u8 {
@@ -356,6 +358,21 @@ pub const Connection = struct {
356358
}
357359
};
358360

361+
pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError;
362+
363+
pub fn getReadError(c: *const Connection) ?ReadError {
364+
return switch (c.protocol) {
365+
.tls => {
366+
if (disable_tls) unreachable;
367+
const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c));
368+
return tls.client.read_err orelse c.stream_reader.getError();
369+
},
370+
.plain => {
371+
return c.stream_reader.getError();
372+
},
373+
};
374+
}
375+
359376
fn getStream(c: *Connection) net.Stream {
360377
return c.stream_reader.getStream();
361378
}
@@ -434,7 +451,6 @@ pub const Connection = struct {
434451
if (disable_tls) unreachable;
435452
const tls: *Tls = @alignCast(@fieldParentPtr("connection", c));
436453
try tls.client.end();
437-
try tls.client.writer.flush();
438454
}
439455
try c.stream_writer.interface.flush();
440456
}
@@ -874,6 +890,7 @@ pub const Request = struct {
874890
var bw = try sendBodyUnflushed(r, body);
875891
bw.writer.end = body.len;
876892
try bw.end();
893+
try r.connection.?.flush();
877894
}
878895

879896
/// Transfers the HTTP head over the connection, which is not flushed until
@@ -1063,6 +1080,9 @@ pub const Request = struct {
10631080
/// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize`
10641081
/// is returned instead. This buffer may be empty if no redirects are to be
10651082
/// handled.
1083+
///
1084+
/// If this fails with `error.ReadFailed` then the `Connection.getReadError`
1085+
/// method of `r.connection` can be used to get more detailed information.
10661086
pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response {
10671087
var aux_buf = redirect_buffer;
10681088
while (true) {

src/Package/Fetch.zig

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -998,15 +998,21 @@ fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u
998998
.buffer = reader_buffer,
999999
} };
10001000
const request = &resource.http_request.request;
1001-
defer request.deinit();
1001+
errdefer request.deinit();
10021002

10031003
request.sendBodiless() catch |err|
10041004
return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err}));
10051005

10061006
var redirect_buffer: [1024]u8 = undefined;
10071007
const response = &resource.http_request.response;
1008-
response.* = request.receiveHead(&redirect_buffer) catch |err|
1009-
return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{err}));
1008+
response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) {
1009+
error.ReadFailed => {
1010+
return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{
1011+
request.connection.?.getReadError().?,
1012+
}));
1013+
},
1014+
else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})),
1015+
};
10101016

10111017
if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString(
10121018
"bad HTTP response code: '{d} {s}'",

0 commit comments

Comments
 (0)