diff --git a/lib/compiler/resinator/cli.zig b/lib/compiler/resinator/cli.zig index 6d4a368b4e80..bfc67d87916d 100644 --- a/lib/compiler/resinator/cli.zig +++ b/lib/compiler/resinator/cli.zig @@ -1141,6 +1141,8 @@ pub fn parse(allocator: Allocator, args: []const []const u8, diagnostics: *Diagn } output_format = .res; } + } else { + output_format_source = .output_format_arg; } options.output_source = .{ .filename = try filepathWithExtension(allocator, options.input_source.filename, output_format.?.extension()) }; } else { @@ -1529,21 +1531,21 @@ fn testParseOutput(args: []const []const u8, expected_output: []const u8) !?Opti var diagnostics = Diagnostics.init(std.testing.allocator); defer diagnostics.deinit(); - var output = std.ArrayList(u8).init(std.testing.allocator); + var output: std.io.Writer.Allocating = .init(std.testing.allocator); defer output.deinit(); var options = parse(std.testing.allocator, args, &diagnostics) catch |err| switch (err) { error.ParseError => { - try diagnostics.renderToWriter(args, output.writer(), .no_color); - try std.testing.expectEqualStrings(expected_output, output.items); + try diagnostics.renderToWriter(args, &output.writer, .no_color); + try std.testing.expectEqualStrings(expected_output, output.getWritten()); return null; }, else => |e| return e, }; errdefer options.deinit(); - try diagnostics.renderToWriter(args, output.writer(), .no_color); - try std.testing.expectEqualStrings(expected_output, output.items); + try diagnostics.renderToWriter(args, &output.writer, .no_color); + try std.testing.expectEqualStrings(expected_output, output.getWritten()); return options; } diff --git a/lib/compiler/resinator/compile.zig b/lib/compiler/resinator/compile.zig index 3515421ff0fc..18d142eb9643 100644 --- a/lib/compiler/resinator/compile.zig +++ b/lib/compiler/resinator/compile.zig @@ -550,7 +550,7 @@ pub const Compiler = struct { // so get it here to simplify future usage. 
const filename_token = node.filename.getFirstToken(); - const file = self.searchForFile(filename_utf8) catch |err| switch (err) { + const file_handle = self.searchForFile(filename_utf8) catch |err| switch (err) { error.OutOfMemory => |e| return e, else => |e| { const filename_string_index = try self.diagnostics.putString(filename_utf8); @@ -564,13 +564,15 @@ pub const Compiler = struct { }); }, }; - defer file.close(); + defer file_handle.close(); + var file_buffer: [2048]u8 = undefined; + var file_reader = file_handle.reader(&file_buffer); if (maybe_predefined_type) |predefined_type| { switch (predefined_type) { .GROUP_ICON, .GROUP_CURSOR => { // Check for animated icon first - if (ani.isAnimatedIcon(file.deprecatedReader())) { + if (ani.isAnimatedIcon(file_reader.interface.adaptToOldInterface())) { // Animated icons are just put into the resource unmodified, // and the resource type changes to ANIICON/ANICURSOR @@ -582,18 +584,18 @@ pub const Compiler = struct { header.type_value.ordinal = @intFromEnum(new_predefined_type); header.memory_flags = MemoryFlags.defaults(new_predefined_type); header.applyMemoryFlags(node.common_resource_attributes, self.source); - header.data_size = @intCast(try file.getEndPos()); + header.data_size = @intCast(try file_reader.getSize()); try header.write(writer, self.errContext(node.id)); - try file.seekTo(0); - try writeResourceData(writer, file.deprecatedReader(), header.data_size); + try file_reader.seekTo(0); + try writeResourceData(writer, &file_reader.interface, header.data_size); return; } // isAnimatedIcon moved the file cursor so reset to the start - try file.seekTo(0); + try file_reader.seekTo(0); - const icon_dir = ico.read(self.allocator, file.deprecatedReader(), try file.getEndPos()) catch |err| switch (err) { + const icon_dir = ico.read(self.allocator, file_reader.interface.adaptToOldInterface(), try file_reader.getSize()) catch |err| switch (err) { error.OutOfMemory => |e| return e, else => |e| { return self.iconReadError( @@ -671,15 +673,15 @@ pub const Compiler = struct { try writer.writeInt(u16, entry.type_specific_data.cursor.hotspot_y, .little); } - try file.seekTo(entry.data_offset_from_start_of_file); - var header_bytes = file.deprecatedReader().readBytesNoEof(16) catch { + try file_reader.seekTo(entry.data_offset_from_start_of_file); + var header_bytes = (file_reader.interface.takeArray(16) catch { return self.iconReadError( error.UnexpectedEOF, filename_utf8, filename_token, predefined_type, ); - }; + }).*; const image_format = ico.ImageFormat.detect(&header_bytes); if (!image_format.validate(&header_bytes)) { @@ -802,8 +804,8 @@ pub const Compiler = struct { }, } - try file.seekTo(entry.data_offset_from_start_of_file); - try writeResourceDataNoPadding(writer, file.deprecatedReader(), entry.data_size_in_bytes); + try file_reader.seekTo(entry.data_offset_from_start_of_file); + try writeResourceDataNoPadding(writer, &file_reader.interface, entry.data_size_in_bytes); try writeDataPadding(writer, full_data_size); if (self.state.icon_id == std.math.maxInt(u16)) { @@ -857,9 +859,9 @@ pub const Compiler = struct { }, .BITMAP => { header.applyMemoryFlags(node.common_resource_attributes, self.source); - const file_size = try file.getEndPos(); + const file_size = try file_reader.getSize(); - const bitmap_info = bmp.read(file.deprecatedReader(), file_size) catch |err| { + const bitmap_info = bmp.read(file_reader.interface.adaptToOldInterface(), file_size) catch |err| { const filename_string_index = try self.diagnostics.putString(filename_utf8); 
return self.addErrorDetailsAndFail(.{ .err = .bmp_read_error, @@ -921,18 +923,17 @@ pub const Compiler = struct { header.data_size = bmp_bytes_to_write; try header.write(writer, self.errContext(node.id)); - try file.seekTo(bmp.file_header_len); - const file_reader = file.deprecatedReader(); - try writeResourceDataNoPadding(writer, file_reader, bitmap_info.dib_header_size); + try file_reader.seekTo(bmp.file_header_len); + try writeResourceDataNoPadding(writer, &file_reader.interface, bitmap_info.dib_header_size); if (bitmap_info.getBitmasksByteLen() > 0) { - try writeResourceDataNoPadding(writer, file_reader, bitmap_info.getBitmasksByteLen()); + try writeResourceDataNoPadding(writer, &file_reader.interface, bitmap_info.getBitmasksByteLen()); } if (bitmap_info.getExpectedPaletteByteLen() > 0) { - try writeResourceDataNoPadding(writer, file_reader, @intCast(bitmap_info.getActualPaletteByteLen())); + try writeResourceDataNoPadding(writer, &file_reader.interface, @intCast(bitmap_info.getActualPaletteByteLen())); } - try file.seekTo(bitmap_info.pixel_data_offset); + try file_reader.seekTo(bitmap_info.pixel_data_offset); const pixel_bytes: u32 = @intCast(file_size - bitmap_info.pixel_data_offset); - try writeResourceDataNoPadding(writer, file_reader, pixel_bytes); + try writeResourceDataNoPadding(writer, &file_reader.interface, pixel_bytes); try writeDataPadding(writer, bmp_bytes_to_write); return; }, @@ -956,7 +957,7 @@ pub const Compiler = struct { return; } header.applyMemoryFlags(node.common_resource_attributes, self.source); - const file_size = try file.getEndPos(); + const file_size = try file_reader.getSize(); if (file_size > std.math.maxInt(u32)) { return self.addErrorDetailsAndFail(.{ .err = .resource_data_size_exceeds_max, @@ -968,8 +969,9 @@ pub const Compiler = struct { header.data_size = @intCast(file_size); try header.write(writer, self.errContext(node.id)); - var header_slurping_reader = headerSlurpingReader(148, file.deprecatedReader()); - try writeResourceData(writer, header_slurping_reader.reader(), header.data_size); + var header_slurping_reader = headerSlurpingReader(148, file_reader.interface.adaptToOldInterface()); + var adapter = header_slurping_reader.reader().adaptToNewApi(&.{}); + try writeResourceData(writer, &adapter.new_interface, header.data_size); try self.state.font_dir.add(self.arena, FontDir.Font{ .id = header.name_value.ordinal, @@ -992,7 +994,7 @@ pub const Compiler = struct { } // Fallback to just writing out the entire contents of the file - const data_size = try file.getEndPos(); + const data_size = try file_reader.getSize(); if (data_size > std.math.maxInt(u32)) { return self.addErrorDetailsAndFail(.{ .err = .resource_data_size_exceeds_max, @@ -1002,7 +1004,7 @@ pub const Compiler = struct { // We now know that the data size will fit in a u32 header.data_size = @intCast(data_size); try header.write(writer, self.errContext(node.id)); - try writeResourceData(writer, file.deprecatedReader(), header.data_size); + try writeResourceData(writer, &file_reader.interface, header.data_size); } fn iconReadError( @@ -1250,8 +1252,8 @@ pub const Compiler = struct { const data_len: u32 = @intCast(data_buffer.items.len); try self.writeResourceHeader(writer, node.id, node.type, data_len, node.common_resource_attributes, self.state.language); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try writeResourceData(writer, data_fbs.reader(), data_len); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try writeResourceData(writer, &data_fbs, data_len); 
} pub fn writeResourceHeader(self: *Compiler, writer: anytype, id_token: Token, type_token: Token, data_size: u32, common_resource_attributes: []Token, language: res.Language) !void { @@ -1266,15 +1268,15 @@ pub const Compiler = struct { try header.write(writer, self.errContext(id_token)); } - pub fn writeResourceDataNoPadding(writer: anytype, data_reader: anytype, data_size: u32) !void { - var limited_reader = std.io.limitedReader(data_reader, data_size); - - const FifoBuffer = std.fifo.LinearFifo(u8, .{ .Static = 4096 }); - var fifo = FifoBuffer.init(); - try fifo.pump(limited_reader.reader(), writer); + pub fn writeResourceDataNoPadding(writer: anytype, data_reader: *std.Io.Reader, data_size: u32) !void { + var adapted = writer.adaptToNewApi(); + var buffer: [128]u8 = undefined; + adapted.new_interface.buffer = &buffer; + try data_reader.streamExact(&adapted.new_interface, data_size); + try adapted.new_interface.flush(); } - pub fn writeResourceData(writer: anytype, data_reader: anytype, data_size: u32) !void { + pub fn writeResourceData(writer: anytype, data_reader: *std.Io.Reader, data_size: u32) !void { try writeResourceDataNoPadding(writer, data_reader, data_size); try writeDataPadding(writer, data_size); } @@ -1339,8 +1341,8 @@ pub const Compiler = struct { try header.write(writer, self.errContext(node.id)); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try writeResourceData(writer, data_fbs.reader(), data_size); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try writeResourceData(writer, &data_fbs, data_size); } /// Expects `data_writer` to be a LimitedWriter limited to u32, meaning all writes to @@ -1732,8 +1734,8 @@ pub const Compiler = struct { try header.write(writer, self.errContext(node.id)); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try writeResourceData(writer, data_fbs.reader(), data_size); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try writeResourceData(writer, &data_fbs, data_size); } fn writeDialogHeaderAndStrings( @@ -2046,8 +2048,8 @@ pub const Compiler = struct { try header.write(writer, self.errContext(node.id)); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try writeResourceData(writer, data_fbs.reader(), data_size); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try writeResourceData(writer, &data_fbs, data_size); } /// Weight and italic carry over from previous FONT statements within a single resource, @@ -2121,8 +2123,8 @@ pub const Compiler = struct { try header.write(writer, self.errContext(node.id)); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try writeResourceData(writer, data_fbs.reader(), data_size); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try writeResourceData(writer, &data_fbs, data_size); } /// Expects `data_writer` to be a LimitedWriter limited to u32, meaning all writes to @@ -2386,8 +2388,8 @@ pub const Compiler = struct { try header.write(writer, self.errContext(node.id)); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try writeResourceData(writer, data_fbs.reader(), data_size); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try writeResourceData(writer, &data_fbs, data_size); } /// Expects writer to be a LimitedWriter limited to u16, meaning all writes to @@ -3321,8 +3323,8 @@ pub const StringTable = struct { // we fully control and know are numbers, so they have a fixed size. 
try header.writeAssertNoOverflow(writer); - var data_fbs = std.io.fixedBufferStream(data_buffer.items); - try Compiler.writeResourceData(writer, data_fbs.reader(), data_size); + var data_fbs: std.Io.Reader = .fixed(data_buffer.items); + try Compiler.writeResourceData(writer, &data_fbs, data_size); } }; diff --git a/lib/compiler/resinator/cvtres.zig b/lib/compiler/resinator/cvtres.zig index 21574f2704d0..27e14ae9a36c 100644 --- a/lib/compiler/resinator/cvtres.zig +++ b/lib/compiler/resinator/cvtres.zig @@ -65,7 +65,7 @@ pub const ParseResOptions = struct { }; /// The returned ParsedResources should be freed by calling its `deinit` function. -pub fn parseRes(allocator: Allocator, reader: anytype, options: ParseResOptions) !ParsedResources { +pub fn parseRes(allocator: Allocator, reader: *std.Io.Reader, options: ParseResOptions) !ParsedResources { var resources = ParsedResources.init(allocator); errdefer resources.deinit(); @@ -74,7 +74,7 @@ pub fn parseRes(allocator: Allocator, reader: anytype, options: ParseResOptions) return resources; } -pub fn parseResInto(resources: *ParsedResources, reader: anytype, options: ParseResOptions) !void { +pub fn parseResInto(resources: *ParsedResources, reader: *std.Io.Reader, options: ParseResOptions) !void { const allocator = resources.allocator; var bytes_remaining: u64 = options.max_size; { @@ -103,43 +103,38 @@ pub const ResourceAndSize = struct { total_size: u64, }; -pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !ResourceAndSize { - var header_counting_reader = std.io.countingReader(reader); - const header_reader = header_counting_reader.reader(); - const data_size = try header_reader.readInt(u32, .little); - const header_size = try header_reader.readInt(u32, .little); +pub fn parseResource(allocator: Allocator, reader: *std.Io.Reader, max_size: u64) !ResourceAndSize { + const data_size = try reader.takeInt(u32, .little); + const header_size = try reader.takeInt(u32, .little); const total_size: u64 = @as(u64, header_size) + data_size; if (total_size > max_size) return error.ImpossibleSize; - var header_bytes_available = header_size -| 8; - var type_reader = std.io.limitedReader(header_reader, header_bytes_available); - const type_value = try parseNameOrOrdinal(allocator, type_reader.reader()); + const remaining_header_bytes = try reader.take(header_size -| 8); + var remaining_header_reader: std.Io.Reader = .fixed(remaining_header_bytes); + const type_value = try parseNameOrOrdinal(allocator, &remaining_header_reader); errdefer type_value.deinit(allocator); - header_bytes_available -|= @intCast(type_value.byteLen()); - var name_reader = std.io.limitedReader(header_reader, header_bytes_available); - const name_value = try parseNameOrOrdinal(allocator, name_reader.reader()); + const name_value = try parseNameOrOrdinal(allocator, &remaining_header_reader); errdefer name_value.deinit(allocator); - const padding_after_name = numPaddingBytesNeeded(@intCast(header_counting_reader.bytes_read)); - try header_reader.skipBytes(padding_after_name, .{ .buf_size = 3 }); + const padding_after_name = numPaddingBytesNeeded(@intCast(remaining_header_reader.seek)); + try remaining_header_reader.discardAll(padding_after_name); - std.debug.assert(header_counting_reader.bytes_read % 4 == 0); - const data_version = try header_reader.readInt(u32, .little); - const memory_flags: MemoryFlags = @bitCast(try header_reader.readInt(u16, .little)); - const language: Language = @bitCast(try header_reader.readInt(u16, .little)); - const version = try 
header_reader.readInt(u32, .little); - const characteristics = try header_reader.readInt(u32, .little); + std.debug.assert(remaining_header_reader.seek % 4 == 0); + const data_version = try remaining_header_reader.takeInt(u32, .little); + const memory_flags: MemoryFlags = @bitCast(try remaining_header_reader.takeInt(u16, .little)); + const language: Language = @bitCast(try remaining_header_reader.takeInt(u16, .little)); + const version = try remaining_header_reader.takeInt(u32, .little); + const characteristics = try remaining_header_reader.takeInt(u32, .little); - const header_bytes_read = header_counting_reader.bytes_read; - if (header_size != header_bytes_read) return error.HeaderSizeMismatch; + if (remaining_header_reader.seek != remaining_header_reader.end) return error.HeaderSizeMismatch; const data = try allocator.alloc(u8, data_size); errdefer allocator.free(data); - try reader.readNoEof(data); + try reader.readSliceAll(data); const padding_after_data = numPaddingBytesNeeded(@intCast(data_size)); - try reader.skipBytes(padding_after_data, .{ .buf_size = 3 }); + try reader.discardAll(padding_after_data); return .{ .resource = .{ @@ -156,10 +151,10 @@ pub fn parseResource(allocator: Allocator, reader: anytype, max_size: u64) !Reso }; } -pub fn parseNameOrOrdinal(allocator: Allocator, reader: anytype) !NameOrOrdinal { - const first_code_unit = try reader.readInt(u16, .little); +pub fn parseNameOrOrdinal(allocator: Allocator, reader: *std.Io.Reader) !NameOrOrdinal { + const first_code_unit = try reader.takeInt(u16, .little); if (first_code_unit == 0xFFFF) { - const ordinal_value = try reader.readInt(u16, .little); + const ordinal_value = try reader.takeInt(u16, .little); return .{ .ordinal = ordinal_value }; } var name_buf = try std.ArrayListUnmanaged(u16).initCapacity(allocator, 16); @@ -167,7 +162,7 @@ pub fn parseNameOrOrdinal(allocator: Allocator, reader: anytype) !NameOrOrdinal var code_unit = first_code_unit; while (code_unit != 0) { try name_buf.append(allocator, std.mem.nativeToLittle(u16, code_unit)); - code_unit = try reader.readInt(u16, .little); + code_unit = try reader.takeInt(u16, .little); } return .{ .name = try name_buf.toOwnedSliceSentinel(allocator, 0) }; } diff --git a/lib/compiler/resinator/errors.zig b/lib/compiler/resinator/errors.zig index 14a001894e4c..4bc443c4e792 100644 --- a/lib/compiler/resinator/errors.zig +++ b/lib/compiler/resinator/errors.zig @@ -1078,11 +1078,9 @@ const CorrespondingLines = struct { at_eof: bool = false, span: SourceMappings.CorrespondingSpan, file: std.fs.File, - buffered_reader: BufferedReaderType, + buffered_reader: std.fs.File.Reader, code_page: SupportedCodePage, - const BufferedReaderType = std.io.BufferedReader(512, std.fs.File.DeprecatedReader); - pub fn init(cwd: std.fs.Dir, err_details: ErrorDetails, line_for_comparison: []const u8, corresponding_span: SourceMappings.CorrespondingSpan, corresponding_file: []const u8) !CorrespondingLines { // We don't do line comparison for this error, so don't print the note if the line // number is different @@ -1101,9 +1099,7 @@ const CorrespondingLines = struct { .buffered_reader = undefined, .code_page = err_details.code_page, }; - corresponding_lines.buffered_reader = BufferedReaderType{ - .unbuffered_reader = corresponding_lines.file.deprecatedReader(), - }; + corresponding_lines.buffered_reader = corresponding_lines.file.reader(&.{}); errdefer corresponding_lines.deinit(); var fbs = std.io.fixedBufferStream(&corresponding_lines.line_buf); @@ -1111,7 +1107,7 @@ const CorrespondingLines 
= struct { try corresponding_lines.writeLineFromStreamVerbatim( writer, - corresponding_lines.buffered_reader.reader(), + corresponding_lines.buffered_reader.interface.adaptToOldInterface(), corresponding_span.start_line, ); @@ -1154,7 +1150,7 @@ const CorrespondingLines = struct { try self.writeLineFromStreamVerbatim( writer, - self.buffered_reader.reader(), + self.buffered_reader.interface.adaptToOldInterface(), self.line_num, ); diff --git a/lib/compiler/resinator/ico.zig b/lib/compiler/resinator/ico.zig index e6de1d469e91..dca74fc8574d 100644 --- a/lib/compiler/resinator/ico.zig +++ b/lib/compiler/resinator/ico.zig @@ -14,8 +14,9 @@ pub fn read(allocator: std.mem.Allocator, reader: anytype, max_size: u64) ReadEr // Some Reader implementations have an empty ReadError error set which would // cause 'unreachable else' if we tried to use an else in the switch, so we // need to detect this case and not try to translate to ReadError + const anyerror_reader_errorset = @TypeOf(reader).Error == anyerror; const empty_reader_errorset = @typeInfo(@TypeOf(reader).Error).error_set == null or @typeInfo(@TypeOf(reader).Error).error_set.?.len == 0; - if (empty_reader_errorset) { + if (empty_reader_errorset and !anyerror_reader_errorset) { return readAnyError(allocator, reader, max_size) catch |err| switch (err) { error.EndOfStream => error.UnexpectedEOF, else => |e| return e, diff --git a/lib/compiler/resinator/main.zig b/lib/compiler/resinator/main.zig index 30e9c825bb89..3187f038b90b 100644 --- a/lib/compiler/resinator/main.zig +++ b/lib/compiler/resinator/main.zig @@ -325,8 +325,8 @@ pub fn main() !void { std.debug.assert(options.output_format == .coff); // TODO: Maybe use a buffered file reader instead of reading file into memory -> fbs - var fbs = std.io.fixedBufferStream(res_data.bytes); - break :resources cvtres.parseRes(allocator, fbs.reader(), .{ .max_size = res_data.bytes.len }) catch |err| { + var res_reader: std.Io.Reader = .fixed(res_data.bytes); + break :resources cvtres.parseRes(allocator, &res_reader, .{ .max_size = res_data.bytes.len }) catch |err| { // TODO: Better errors try error_handler.emitMessage(allocator, .err, "unable to parse res from '{s}': {s}", .{ res_stream.name, @errorName(err) }); std.process.exit(1); diff --git a/lib/docs/wasm/markdown.zig b/lib/docs/wasm/markdown.zig index 3293b680c98f..32e32b41042d 100644 --- a/lib/docs/wasm/markdown.zig +++ b/lib/docs/wasm/markdown.zig @@ -145,13 +145,12 @@ fn mainImpl() !void { var parser = try Parser.init(gpa); defer parser.deinit(); - var stdin_buf = std.io.bufferedReader(std.fs.File.stdin().deprecatedReader()); - var line_buf = std.ArrayList(u8).init(gpa); - defer line_buf.deinit(); - while (stdin_buf.reader().streamUntilDelimiter(line_buf.writer(), '\n', null)) { - if (line_buf.getLastOrNull() == '\r') _ = line_buf.pop(); - try parser.feedLine(line_buf.items); - line_buf.clearRetainingCapacity(); + var stdin_buffer: [1024]u8 = undefined; + var stdin_reader = std.fs.File.stdin().reader(&stdin_buffer); + + while (stdin_reader.interface.takeDelimiterExclusive('\n')) |line| { + const trimmed = std.mem.trimRight(u8, line, "\r"); + try parser.feedLine(trimmed); } else |err| switch (err) { error.EndOfStream => {}, else => |e| return e, diff --git a/lib/std/Build/Fuzz.zig b/lib/std/Build/Fuzz.zig index a25b50175597..bc10f7907a2e 100644 --- a/lib/std/Build/Fuzz.zig +++ b/lib/std/Build/Fuzz.zig @@ -234,7 +234,7 @@ pub const Previous = struct { }; pub fn sendUpdate( fuzz: *Fuzz, - socket: *std.http.WebSocket, + socket: *std.http.Server.WebSocket,
prev: *Previous, ) !void { fuzz.coverage_mutex.lock(); @@ -263,36 +263,36 @@ pub fn sendUpdate( .string_bytes_len = @intCast(coverage_map.coverage.string_bytes.items.len), .start_timestamp = coverage_map.start_timestamp, }; - const iovecs: [5]std.posix.iovec_const = .{ - makeIov(@ptrCast(&header)), - makeIov(@ptrCast(coverage_map.coverage.directories.keys())), - makeIov(@ptrCast(coverage_map.coverage.files.keys())), - makeIov(@ptrCast(coverage_map.source_locations)), - makeIov(coverage_map.coverage.string_bytes.items), + var iovecs: [5][]const u8 = .{ + @ptrCast(&header), + @ptrCast(coverage_map.coverage.directories.keys()), + @ptrCast(coverage_map.coverage.files.keys()), + @ptrCast(coverage_map.source_locations), + coverage_map.coverage.string_bytes.items, }; - try socket.writeMessagev(&iovecs, .binary); + try socket.writeMessageVec(&iovecs, .binary); } const header: abi.CoverageUpdateHeader = .{ .n_runs = n_runs, .unique_runs = unique_runs, }; - const iovecs: [2]std.posix.iovec_const = .{ - makeIov(@ptrCast(&header)), - makeIov(@ptrCast(seen_pcs)), + var iovecs: [2][]const u8 = .{ + @ptrCast(&header), + @ptrCast(seen_pcs), }; - try socket.writeMessagev(&iovecs, .binary); + try socket.writeMessageVec(&iovecs, .binary); prev.unique_runs = unique_runs; } if (prev.entry_points != coverage_map.entry_points.items.len) { const header: abi.EntryPointHeader = .init(@intCast(coverage_map.entry_points.items.len)); - const iovecs: [2]std.posix.iovec_const = .{ - makeIov(@ptrCast(&header)), - makeIov(@ptrCast(coverage_map.entry_points.items)), + var iovecs: [2][]const u8 = .{ + @ptrCast(&header), + @ptrCast(coverage_map.entry_points.items), }; - try socket.writeMessagev(&iovecs, .binary); + try socket.writeMessageVec(&iovecs, .binary); prev.entry_points = coverage_map.entry_points.items.len; } @@ -448,10 +448,3 @@ fn addEntryPoint(fuzz: *Fuzz, coverage_id: u64, addr: u64) error{ AlreadyReporte } try coverage_map.entry_points.append(fuzz.ws.gpa, @intCast(index)); } - -fn makeIov(s: []const u8) std.posix.iovec_const { - return .{ - .base = s.ptr, - .len = s.len, - }; -} diff --git a/lib/std/Build/WebServer.zig b/lib/std/Build/WebServer.zig index 9264d7473c61..868aabe67e22 100644 --- a/lib/std/Build/WebServer.zig +++ b/lib/std/Build/WebServer.zig @@ -251,48 +251,44 @@ pub fn now(s: *const WebServer) i64 { fn accept(ws: *WebServer, connection: std.net.Server.Connection) void { defer connection.stream.close(); - var read_buf: [0x4000]u8 = undefined; - var server: std.http.Server = .init(connection, &read_buf); + var send_buffer: [4096]u8 = undefined; + var recv_buffer: [4096]u8 = undefined; + var connection_reader = connection.stream.reader(&recv_buffer); + var connection_writer = connection.stream.writer(&send_buffer); + var server: http.Server = .init(connection_reader.interface(), &connection_writer.interface); while (true) { var request = server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => return, - else => { - log.err("failed to receive http request: {s}", .{@errorName(err)}); - return; - }, + else => return log.err("failed to receive http request: {t}", .{err}), }; - var ws_send_buf: [0x4000]u8 = undefined; - var ws_recv_buf: [0x4000]u8 align(4) = undefined; - if (std.http.WebSocket.init(&request, &ws_send_buf, &ws_recv_buf) catch |err| { - log.err("failed to initialize websocket connection: {s}", .{@errorName(err)}); - return; - }) |ws_init| { - var web_socket = ws_init; - ws.serveWebSocket(&web_socket) catch |err| { - log.err("failed to serve websocket: {s}", 
.{@errorName(err)}); - return; - }; - comptime unreachable; - } else { - ws.serveRequest(&request) catch |err| switch (err) { - error.AlreadyReported => return, - else => { - log.err("failed to serve '{s}': {s}", .{ request.head.target, @errorName(err) }); + switch (request.upgradeRequested()) { + .websocket => |opt_key| { + const key = opt_key orelse return log.err("missing websocket key", .{}); + var web_socket = request.respondWebSocket(.{ .key = key }) catch { + return log.err("failed to respond web socket: {t}", .{connection_writer.err.?}); + }; + ws.serveWebSocket(&web_socket) catch |err| { + log.err("failed to serve websocket: {t}", .{err}); return; - }, - }; + }; + comptime unreachable; + }, + .other => |name| return log.err("unknown upgrade request: {s}", .{name}), + .none => { + ws.serveRequest(&request) catch |err| switch (err) { + error.AlreadyReported => return, + else => { + log.err("failed to serve '{s}': {t}", .{ request.head.target, err }); + return; + }, + }; + }, } } } -fn makeIov(s: []const u8) std.posix.iovec_const { - return .{ - .base = s.ptr, - .len = s.len, - }; -} -fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn { +fn serveWebSocket(ws: *WebServer, sock: *http.Server.WebSocket) !noreturn { var prev_build_status = ws.build_status.load(.monotonic); const prev_step_status_bits = try ws.gpa.alloc(u8, ws.step_status_bits.len); @@ -312,11 +308,8 @@ fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn { .timestamp = ws.now(), .steps_len = @intCast(ws.all_steps.len), }; - try sock.writeMessagev(&.{ - makeIov(@ptrCast(&hello_header)), - makeIov(ws.step_names_trailing), - makeIov(prev_step_status_bits), - }, .binary); + var bufs: [3][]const u8 = .{ @ptrCast(&hello_header), ws.step_names_trailing, prev_step_status_bits }; + try sock.writeMessageVec(&bufs, .binary); } var prev_fuzz: Fuzz.Previous = .init; @@ -380,7 +373,7 @@ fn serveWebSocket(ws: *WebServer, sock: *std.http.WebSocket) !noreturn { std.Thread.Futex.timedWait(&ws.update_id, start_update_id, std.time.ns_per_ms * default_update_interval_ms) catch {}; } } -fn recvWebSocketMessages(ws: *WebServer, sock: *std.http.WebSocket) void { +fn recvWebSocketMessages(ws: *WebServer, sock: *http.Server.WebSocket) void { while (true) { const msg = sock.readSmallMessage() catch return; if (msg.opcode != .binary) continue; @@ -402,7 +395,7 @@ fn recvWebSocketMessages(ws: *WebServer, sock: *std.http.WebSocket) void { } } -fn serveRequest(ws: *WebServer, req: *std.http.Server.Request) !void { +fn serveRequest(ws: *WebServer, req: *http.Server.Request) !void { // Strip an optional leading '/debug' component from the request. 
const target: []const u8, const debug: bool = target: { if (mem.eql(u8, req.head.target, "/debug")) break :target .{ "/", true }; @@ -431,7 +424,7 @@ fn serveRequest(ws: *WebServer, req: *std.http.Server.Request) !void { fn serveLibFile( ws: *WebServer, - request: *std.http.Server.Request, + request: *http.Server.Request, sub_path: []const u8, content_type: []const u8, ) !void { @@ -442,7 +435,7 @@ fn serveLibFile( } fn serveClientWasm( ws: *WebServer, - req: *std.http.Server.Request, + req: *http.Server.Request, optimize_mode: std.builtin.OptimizeMode, ) !void { var arena_state: std.heap.ArenaAllocator = .init(ws.gpa); @@ -456,12 +449,12 @@ fn serveClientWasm( pub fn serveFile( ws: *WebServer, - request: *std.http.Server.Request, + request: *http.Server.Request, path: Cache.Path, content_type: []const u8, ) !void { const gpa = ws.gpa; - // The desired API is actually sendfile, which will require enhancing std.http.Server. + // The desired API is actually sendfile, which will require enhancing http.Server. // We load the file with every request so that the user can make changes to the file // and refresh the HTML page without restarting this server. const file_contents = path.root_dir.handle.readFileAlloc(gpa, path.sub_path, 10 * 1024 * 1024) catch |err| { @@ -478,14 +471,13 @@ pub fn serveFile( } pub fn serveTarFile( ws: *WebServer, - request: *std.http.Server.Request, + request: *http.Server.Request, paths: []const Cache.Path, ) !void { const gpa = ws.gpa; - var send_buf: [0x4000]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buf, + var send_buffer: [0x4000]u8 = undefined; + var response = try request.respondStreaming(&send_buffer, .{ .respond_options = .{ .extra_headers = &.{ .{ .name = "Content-Type", .value = "application/x-tar" }, @@ -497,10 +489,7 @@ pub fn serveTarFile( var cached_cwd_path: ?[]const u8 = null; defer if (cached_cwd_path) |p| gpa.free(p); - var response_buf: [1024]u8 = undefined; - var adapter = response.writer().adaptToNewApi(); - adapter.new_interface.buffer = &response_buf; - var archiver: std.tar.Writer = .{ .underlying_writer = &adapter.new_interface }; + var archiver: std.tar.Writer = .{ .underlying_writer = &response.writer }; for (paths) |path| { var file = path.root_dir.handle.openFile(path.sub_path, .{}) catch |err| { @@ -526,7 +515,6 @@ pub fn serveTarFile( } // intentionally not calling `archiver.finishPedantically` - try adapter.new_interface.flush(); try response.end(); } @@ -804,7 +792,7 @@ pub fn wait(ws: *WebServer) RunnerRequest { } } -const cache_control_header: std.http.Header = .{ +const cache_control_header: http.Header = .{ .name = "Cache-Control", .value = "max-age=0, must-revalidate", }; @@ -819,5 +807,6 @@ const Build = std.Build; const Cache = Build.Cache; const Fuzz = Build.Fuzz; const abi = Build.abi; +const http = std.http; const WebServer = @This(); diff --git a/lib/std/Io.zig b/lib/std/Io.zig index ab309c87201e..746a457c5bb6 100644 --- a/lib/std/Io.zig +++ b/lib/std/Io.zig @@ -428,19 +428,9 @@ pub const BufferedWriter = @import("Io/buffered_writer.zig").BufferedWriter; /// Deprecated in favor of `Writer`. pub const bufferedWriter = @import("Io/buffered_writer.zig").bufferedWriter; /// Deprecated in favor of `Reader`. -pub const BufferedReader = @import("Io/buffered_reader.zig").BufferedReader; -/// Deprecated in favor of `Reader`. -pub const bufferedReader = @import("Io/buffered_reader.zig").bufferedReader; -/// Deprecated in favor of `Reader`. 
-pub const bufferedReaderSize = @import("Io/buffered_reader.zig").bufferedReaderSize; -/// Deprecated in favor of `Reader`. pub const FixedBufferStream = @import("Io/fixed_buffer_stream.zig").FixedBufferStream; /// Deprecated in favor of `Reader`. pub const fixedBufferStream = @import("Io/fixed_buffer_stream.zig").fixedBufferStream; -/// Deprecated in favor of `Reader.Limited`. -pub const LimitedReader = @import("Io/limited_reader.zig").LimitedReader; -/// Deprecated in favor of `Reader.Limited`. -pub const limitedReader = @import("Io/limited_reader.zig").limitedReader; /// Deprecated with no replacement; inefficient pattern pub const CountingWriter = @import("Io/counting_writer.zig").CountingWriter; /// Deprecated with no replacement; inefficient pattern @@ -926,7 +916,6 @@ pub fn PollFiles(comptime StreamEnum: type) type { test { _ = Reader; _ = Writer; - _ = BufferedReader; _ = BufferedWriter; _ = CountingWriter; _ = CountingReader; diff --git a/lib/std/Io/Reader.zig b/lib/std/Io/Reader.zig index 6485e2edffd8..b8fa6f2313d6 100644 --- a/lib/std/Io/Reader.zig +++ b/lib/std/Io/Reader.zig @@ -367,8 +367,11 @@ pub fn appendRemainingUnlimited( const buffer_contents = r.buffer[r.seek..r.end]; try list.ensureUnusedCapacity(gpa, buffer_contents.len + bump); list.appendSliceAssumeCapacity(buffer_contents); - r.seek = 0; - r.end = 0; + // If statement protects `ending`. + if (r.end != 0) { + r.seek = 0; + r.end = 0; + } // From here, we leave `buffer` empty, appending directly to `list`. var writer: Writer = .{ .buffer = undefined, @@ -1306,31 +1309,6 @@ pub fn defaultRebase(r: *Reader, capacity: usize) RebaseError!void { r.end = data.len; } -/// Advances the stream and decreases the size of the storage buffer by `n`, -/// returning the range of bytes no longer accessible by `r`. -/// -/// This action can be undone by `restitute`. -/// -/// Asserts there are at least `n` buffered bytes already. -/// -/// Asserts that `r.seek` is zero, i.e. the buffer is in a rebased state. -pub fn steal(r: *Reader, n: usize) []u8 { - assert(r.seek == 0); - assert(n <= r.end); - const stolen = r.buffer[0..n]; - r.buffer = r.buffer[n..]; - r.end -= n; - return stolen; -} - -/// Expands the storage buffer, undoing the effects of `steal` -/// Assumes that `n` does not exceed the total number of stolen bytes. -pub fn restitute(r: *Reader, n: usize) void { - r.buffer = (r.buffer.ptr - n)[0 .. r.buffer.len + n]; - r.end += n; - r.seek += n; -} - test fixed { var r: Reader = .fixed("a\x02"); try testing.expect((try r.takeByte()) == 'a'); diff --git a/lib/std/Io/Writer.zig b/lib/std/Io/Writer.zig index a84077f8f313..797a69914c3b 100644 --- a/lib/std/Io/Writer.zig +++ b/lib/std/Io/Writer.zig @@ -191,29 +191,87 @@ pub fn writeSplatHeader( data: []const []const u8, splat: usize, ) Error!usize { - const new_end = w.end + header.len; - if (new_end <= w.buffer.len) { - @memcpy(w.buffer[w.end..][0..header.len], header); - w.end = new_end; - return header.len + try writeSplat(w, data, splat); - } - var vecs: [8][]const u8 = undefined; // Arbitrarily chosen size. - var i: usize = 1; - vecs[0] = header; - for (data[0 .. data.len - 1]) |buf| { - if (buf.len == 0) continue; - vecs[i] = buf; - i += 1; - if (vecs.len - i == 0) break; + return writeSplatHeaderLimit(w, header, data, splat, .unlimited); +} + +/// Equivalent to `writeSplatHeader` but writes at most `limit` bytes. 
+pub fn writeSplatHeaderLimit( + w: *Writer, + header: []const u8, + data: []const []const u8, + splat: usize, + limit: Limit, +) Error!usize { + var remaining = @intFromEnum(limit); + { + const copy_len = @min(header.len, w.buffer.len - w.end, remaining); + if (header.len - copy_len != 0) return writeSplatHeaderLimitFinish(w, header, data, splat, remaining); + @memcpy(w.buffer[w.end..][0..copy_len], header[0..copy_len]); + w.end += copy_len; + remaining -= copy_len; + } + for (data[0 .. data.len - 1], 0..) |buf, i| { + const copy_len = @min(buf.len, w.buffer.len - w.end, remaining); + if (buf.len - copy_len != 0) return @intFromEnum(limit) - remaining + + try writeSplatHeaderLimitFinish(w, &.{}, data[i..], splat, remaining); + @memcpy(w.buffer[w.end..][0..copy_len], buf[0..copy_len]); + w.end += copy_len; + remaining -= copy_len; } const pattern = data[data.len - 1]; - const new_splat = s: { - if (pattern.len == 0 or vecs.len - i == 0) break :s 1; + const splat_n = pattern.len * splat; + if (splat_n > @min(w.buffer.len - w.end, remaining)) { + const buffered_n = @intFromEnum(limit) - remaining; + const written = try writeSplatHeaderLimitFinish(w, &.{}, data[data.len - 1 ..][0..1], splat, remaining); + return buffered_n + written; + } + + for (0..splat) |_| { + @memcpy(w.buffer[w.end..][0..pattern.len], pattern); + w.end += pattern.len; + } + + remaining -= splat_n; + return @intFromEnum(limit) - remaining; +} + +fn writeSplatHeaderLimitFinish( + w: *Writer, + header: []const u8, + data: []const []const u8, + splat: usize, + limit: usize, +) Error!usize { + var remaining = limit; + var vecs: [8][]const u8 = undefined; + var i: usize = 0; + v: { + if (header.len != 0) { + const copy_len = @min(header.len, remaining); + vecs[i] = header[0..copy_len]; + i += 1; + remaining -= copy_len; + if (remaining == 0) break :v; + } + for (data[0 .. data.len - 1]) |buf| if (buf.len != 0) { + const copy_len = @min(buf.len, remaining); + vecs[i] = buf[0..copy_len]; + i += 1; + remaining -= copy_len; + if (remaining == 0) break :v; + if (vecs.len - i == 0) break :v; + }; + const pattern = data[data.len - 1]; + if (splat == 1) { + vecs[i] = pattern[0..@min(remaining, pattern.len)]; + i += 1; + break :v; + } vecs[i] = pattern; i += 1; - break :s splat; - }; - return w.vtable.drain(w, vecs[0..i], new_splat); + return w.vtable.drain(w, (&vecs)[0..i], @min(remaining / pattern.len, splat)); + } + return w.vtable.drain(w, (&vecs)[0..i], 1); } test "writeSplatHeader splatting avoids buffer aliasing temptation" { diff --git a/lib/std/Io/buffered_reader.zig b/lib/std/Io/buffered_reader.zig deleted file mode 100644 index 548dd92f7362..000000000000 --- a/lib/std/Io/buffered_reader.zig +++ /dev/null @@ -1,201 +0,0 @@ -const std = @import("../std.zig"); -const io = std.io; -const mem = std.mem; -const assert = std.debug.assert; -const testing = std.testing; - -pub fn BufferedReader(comptime buffer_size: usize, comptime ReaderType: type) type { - return struct { - unbuffered_reader: ReaderType, - buf: [buffer_size]u8 = undefined, - start: usize = 0, - end: usize = 0, - - pub const Error = ReaderType.Error; - pub const Reader = io.GenericReader(*Self, Error, read); - - const Self = @This(); - - pub fn read(self: *Self, dest: []u8) Error!usize { - // First try reading from the already buffered data onto the destination.
- const current = self.buf[self.start..self.end]; - if (current.len != 0) { - const to_transfer = @min(current.len, dest.len); - @memcpy(dest[0..to_transfer], current[0..to_transfer]); - self.start += to_transfer; - return to_transfer; - } - - // If dest is large, read from the unbuffered reader directly into the destination. - if (dest.len >= buffer_size) { - return self.unbuffered_reader.read(dest); - } - - // If dest is small, read from the unbuffered reader into our own internal buffer, - // and then transfer to destination. - self.end = try self.unbuffered_reader.read(&self.buf); - const to_transfer = @min(self.end, dest.len); - @memcpy(dest[0..to_transfer], self.buf[0..to_transfer]); - self.start = to_transfer; - return to_transfer; - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - }; -} - -pub fn bufferedReader(reader: anytype) BufferedReader(4096, @TypeOf(reader)) { - return .{ .unbuffered_reader = reader }; -} - -pub fn bufferedReaderSize(comptime size: usize, reader: anytype) BufferedReader(size, @TypeOf(reader)) { - return .{ .unbuffered_reader = reader }; -} - -test "OneByte" { - const OneByteReadReader = struct { - str: []const u8, - curr: usize, - - const Error = error{NoError}; - const Self = @This(); - const Reader = io.GenericReader(*Self, Error, read); - - fn init(str: []const u8) Self { - return Self{ - .str = str, - .curr = 0, - }; - } - - fn read(self: *Self, dest: []u8) Error!usize { - if (self.str.len <= self.curr or dest.len == 0) - return 0; - - dest[0] = self.str[self.curr]; - self.curr += 1; - return 1; - } - - fn reader(self: *Self) Reader { - return .{ .context = self }; - } - }; - - const str = "This is a test"; - var one_byte_stream = OneByteReadReader.init(str); - var buf_reader = bufferedReader(one_byte_stream.reader()); - const stream = buf_reader.reader(); - - const res = try stream.readAllAlloc(testing.allocator, str.len + 1); - defer testing.allocator.free(res); - try testing.expectEqualSlices(u8, str, res); -} - -fn smallBufferedReader(underlying_stream: anytype) BufferedReader(8, @TypeOf(underlying_stream)) { - return .{ .unbuffered_reader = underlying_stream }; -} -test "Block" { - const BlockReader = struct { - block: []const u8, - reads_allowed: usize, - curr_read: usize, - - const Error = error{NoError}; - const Self = @This(); - const Reader = io.GenericReader(*Self, Error, read); - - fn init(block: []const u8, reads_allowed: usize) Self { - return Self{ - .block = block, - .reads_allowed = reads_allowed, - .curr_read = 0, - }; - } - - fn read(self: *Self, dest: []u8) Error!usize { - if (self.curr_read >= self.reads_allowed) return 0; - @memcpy(dest[0..self.block.len], self.block); - - self.curr_read += 1; - return self.block.len; - } - - fn reader(self: *Self) Reader { - return .{ .context = self }; - } - }; - - const block = "0123"; - - // len out == block - { - var test_buf_reader: BufferedReader(4, BlockReader) = .{ - .unbuffered_reader = BlockReader.init(block, 2), - }; - const reader = test_buf_reader.reader(); - var out_buf: [4]u8 = undefined; - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, block); - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try reader.readAll(&out_buf), 0); - } - - // len out < block - { - var test_buf_reader: BufferedReader(4, BlockReader) = .{ - .unbuffered_reader = BlockReader.init(block, 2), - }; - const reader = test_buf_reader.reader(); - var out_buf: [3]u8 = undefined; - _ = try 
reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, "012"); - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, "301"); - const n = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, out_buf[0..n], "23"); - try testing.expectEqual(try reader.readAll(&out_buf), 0); - } - - // len out > block - { - var test_buf_reader: BufferedReader(4, BlockReader) = .{ - .unbuffered_reader = BlockReader.init(block, 2), - }; - const reader = test_buf_reader.reader(); - var out_buf: [5]u8 = undefined; - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, "01230"); - const n = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, out_buf[0..n], "123"); - try testing.expectEqual(try reader.readAll(&out_buf), 0); - } - - // len out == 0 - { - var test_buf_reader: BufferedReader(4, BlockReader) = .{ - .unbuffered_reader = BlockReader.init(block, 2), - }; - const reader = test_buf_reader.reader(); - var out_buf: [0]u8 = undefined; - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, ""); - } - - // len bufreader buf > block - { - var test_buf_reader: BufferedReader(5, BlockReader) = .{ - .unbuffered_reader = BlockReader.init(block, 2), - }; - const reader = test_buf_reader.reader(); - var out_buf: [4]u8 = undefined; - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, block); - _ = try reader.readAll(&out_buf); - try testing.expectEqualSlices(u8, &out_buf, block); - try testing.expectEqual(try reader.readAll(&out_buf), 0); - } -} diff --git a/lib/std/Io/limited_reader.zig b/lib/std/Io/limited_reader.zig deleted file mode 100644 index b6b555f76dec..000000000000 --- a/lib/std/Io/limited_reader.zig +++ /dev/null @@ -1,45 +0,0 @@ -const std = @import("../std.zig"); -const io = std.io; -const assert = std.debug.assert; -const testing = std.testing; - -pub fn LimitedReader(comptime ReaderType: type) type { - return struct { - inner_reader: ReaderType, - bytes_left: u64, - - pub const Error = ReaderType.Error; - pub const Reader = io.GenericReader(*Self, Error, read); - - const Self = @This(); - - pub fn read(self: *Self, dest: []u8) Error!usize { - const max_read = @min(self.bytes_left, dest.len); - const n = try self.inner_reader.read(dest[0..max_read]); - self.bytes_left -= n; - return n; - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - }; -} - -/// Returns an initialised `LimitedReader`. 
-/// `bytes_left` is a `u64` to be able to take 64 bit file offsets -pub fn limitedReader(inner_reader: anytype, bytes_left: u64) LimitedReader(@TypeOf(inner_reader)) { - return .{ .inner_reader = inner_reader, .bytes_left = bytes_left }; -} - -test "basic usage" { - const data = "hello world"; - var fbs = std.io.fixedBufferStream(data); - var early_stream = limitedReader(fbs.reader(), 3); - - var buf: [5]u8 = undefined; - try testing.expectEqual(@as(usize, 3), try early_stream.reader().read(&buf)); - try testing.expectEqualSlices(u8, data[0..3], buf[0..3]); - try testing.expectEqual(@as(usize, 0), try early_stream.reader().read(&buf)); - try testing.expectError(error.EndOfStream, early_stream.reader().skipBytes(10, .{})); -} diff --git a/lib/std/Io/test.zig b/lib/std/Io/test.zig index e0fcea7674d9..9733e6f044cb 100644 --- a/lib/std/Io/test.zig +++ b/lib/std/Io/test.zig @@ -45,9 +45,9 @@ test "write a file, read it, then delete it" { const expected_file_size: u64 = "begin".len + data.len + "end".len; try expectEqual(expected_file_size, file_size); - var buf_stream = io.bufferedReader(file.deprecatedReader()); - const st = buf_stream.reader(); - const contents = try st.readAllAlloc(std.testing.allocator, 2 * 1024); + var file_buffer: [1024]u8 = undefined; + var file_reader = file.reader(&file_buffer); + const contents = try file_reader.interface.allocRemaining(std.testing.allocator, .limited(2 * 1024)); defer std.testing.allocator.free(contents); try expect(mem.eql(u8, contents[0.."begin".len], "begin")); diff --git a/lib/std/Uri.zig b/lib/std/Uri.zig index 19af1512c2a8..7244c9595b26 100644 --- a/lib/std/Uri.zig +++ b/lib/std/Uri.zig @@ -4,6 +4,8 @@ const std = @import("std.zig"); const testing = std.testing; const Uri = @This(); +const Allocator = std.mem.Allocator; +const Writer = std.Io.Writer; scheme: []const u8, user: ?Component = null, @@ -14,6 +16,32 @@ path: Component = Component.empty, query: ?Component = null, fragment: ?Component = null, +pub const host_name_max = 255; + +/// Returned value may point into `buffer` or be the original string. +/// +/// Suggested buffer length: `host_name_max`. +/// +/// See also: +/// * `getHostAlloc` +pub fn getHost(uri: Uri, buffer: []u8) error{ UriMissingHost, UriHostTooLong }![]const u8 { + const component = uri.host orelse return error.UriMissingHost; + return component.toRaw(buffer) catch |err| switch (err) { + error.NoSpaceLeft => return error.UriHostTooLong, + }; +} + +/// Returned value may point into `buffer` or be the original string. +/// +/// See also: +/// * `getHost` +pub fn getHostAlloc(uri: Uri, arena: Allocator) error{ UriMissingHost, UriHostTooLong, OutOfMemory }![]const u8 { + const component = uri.host orelse return error.UriMissingHost; + const result = try component.toRawMaybeAlloc(arena); + if (result.len > host_name_max) return error.UriHostTooLong; + return result; +} + pub const Component = union(enum) { /// Invalid characters in this component must be percent encoded /// before being printed as part of a URI. @@ -30,11 +58,19 @@ pub const Component = union(enum) { }; } + /// Returned value may point into `buffer` or be the original string. 
+ pub fn toRaw(component: Component, buffer: []u8) error{NoSpaceLeft}![]const u8 { + return switch (component) { + .raw => |raw| raw, + .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_| + try std.fmt.bufPrint(buffer, "{f}", .{std.fmt.alt(component, .formatRaw)}) + else + percent_encoded, + }; + } + /// Allocates the result with `arena` only if needed, so the result should not be freed. - pub fn toRawMaybeAlloc( - component: Component, - arena: std.mem.Allocator, - ) std.mem.Allocator.Error![]const u8 { + pub fn toRawMaybeAlloc(component: Component, arena: Allocator) Allocator.Error![]const u8 { return switch (component) { .raw => |raw| raw, .percent_encoded => |percent_encoded| if (std.mem.indexOfScalar(u8, percent_encoded, '%')) |_| @@ -44,7 +80,7 @@ pub const Component = union(enum) { }; } - pub fn formatRaw(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatRaw(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try w.writeAll(raw), .percent_encoded => |percent_encoded| { @@ -67,56 +103,56 @@ pub const Component = union(enum) { } } - pub fn formatEscaped(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatEscaped(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isUnreserved), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatUser(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatUser(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isUserChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatPassword(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatPassword(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isPasswordChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatHost(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatHost(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isHostChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatPath(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatPath(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isPathChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatQuery(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatQuery(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isQueryChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn formatFragment(component: Component, w: *std.io.Writer) std.io.Writer.Error!void { + pub fn formatFragment(component: Component, w: *Writer) Writer.Error!void { switch (component) { .raw => |raw| try percentEncode(w, raw, isFragmentChar), .percent_encoded => |percent_encoded| try w.writeAll(percent_encoded), } } - pub fn percentEncode(w: *std.io.Writer, raw: []const u8, comptime isValidChar: fn (u8) bool) std.io.Writer.Error!void { + pub fn percentEncode(w: *Writer, raw: []const u8, 
comptime isValidChar: fn (u8) bool) Writer.Error!void { var start: usize = 0; for (raw, 0..) |char, index| { if (isValidChar(char)) continue; @@ -165,17 +201,15 @@ pub const ParseError = error{ UnexpectedCharacter, InvalidFormat, InvalidPort }; /// The return value will contain strings pointing into the original `text`. /// Each component that is provided, will be non-`null`. pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { - var reader = SliceReader{ .slice = text }; - var uri: Uri = .{ .scheme = scheme, .path = undefined }; + var i: usize = 0; - if (reader.peekPrefix("//")) a: { // authority part - std.debug.assert(reader.get().? == '/'); - std.debug.assert(reader.get().? == '/'); - - const authority = reader.readUntil(isAuthoritySeparator); + if (std.mem.startsWith(u8, text, "//")) a: { + i = std.mem.indexOfAnyPos(u8, text, 2, &authority_sep) orelse text.len; + const authority = text[2..i]; if (authority.len == 0) { - if (reader.peekPrefix("/")) break :a else return error.InvalidFormat; + if (!std.mem.startsWith(u8, text[2..], "/")) return error.InvalidFormat; + break :a; } var start_of_host: usize = 0; @@ -225,26 +259,28 @@ pub fn parseAfterScheme(scheme: []const u8, text: []const u8) ParseError!Uri { uri.host = .{ .percent_encoded = authority[start_of_host..end_of_host] }; } - uri.path = .{ .percent_encoded = reader.readUntil(isPathSeparator) }; + const path_start = i; + i = std.mem.indexOfAnyPos(u8, text, path_start, &path_sep) orelse text.len; + uri.path = .{ .percent_encoded = text[path_start..i] }; - if ((reader.peek() orelse 0) == '?') { // query part - std.debug.assert(reader.get().? == '?'); - uri.query = .{ .percent_encoded = reader.readUntil(isQuerySeparator) }; + if (std.mem.startsWith(u8, text[i..], "?")) { + const query_start = i + 1; + i = std.mem.indexOfScalarPos(u8, text, query_start, '#') orelse text.len; + uri.query = .{ .percent_encoded = text[query_start..i] }; } - if ((reader.peek() orelse 0) == '#') { // fragment part - std.debug.assert(reader.get().? == '#'); - uri.fragment = .{ .percent_encoded = reader.readUntilEof() }; + if (std.mem.startsWith(u8, text[i..], "#")) { + uri.fragment = .{ .percent_encoded = text[i + 1 ..] }; } return uri; } -pub fn format(uri: *const Uri, writer: *std.io.Writer) std.io.Writer.Error!void { +pub fn format(uri: *const Uri, writer: *Writer) Writer.Error!void { return writeToStream(uri, writer, .all); } -pub fn writeToStream(uri: *const Uri, writer: *std.io.Writer, flags: Format.Flags) std.io.Writer.Error!void { +pub fn writeToStream(uri: *const Uri, writer: *Writer, flags: Format.Flags) Writer.Error!void { if (flags.scheme) { try writer.print("{s}:", .{uri.scheme}); if (flags.authority and uri.host != null) { @@ -318,7 +354,7 @@ pub const Format = struct { }; }; - pub fn default(f: Format, writer: *std.io.Writer) std.io.Writer.Error!void { + pub fn default(f: Format, writer: *Writer) Writer.Error!void { return writeToStream(f.uri, writer, f.flags); } }; @@ -327,41 +363,34 @@ pub fn fmt(uri: *const Uri, flags: Format.Flags) std.fmt.Formatter(Format, Forma return .{ .data = .{ .uri = uri, .flags = flags } }; } -/// Parses the URI or returns an error. -/// The return value will contain strings pointing into the -/// original `text`. Each component that is provided, will be non-`null`. +/// The return value will contain strings pointing into the original `text`. +/// Each component that is provided will be non-`null`. 
pub fn parse(text: []const u8) ParseError!Uri { - var reader: SliceReader = .{ .slice = text }; - const scheme = reader.readWhile(isSchemeChar); - - // after the scheme, a ':' must appear - if (reader.get()) |c| { - if (c != ':') - return error.UnexpectedCharacter; - } else { - return error.InvalidFormat; - } - - return parseAfterScheme(scheme, reader.readUntilEof()); + const end = for (text, 0..) |byte, i| { + if (!isSchemeChar(byte)) break i; + } else text.len; + // After the scheme, a ':' must appear. + if (end >= text.len) return error.InvalidFormat; + if (text[end] != ':') return error.UnexpectedCharacter; + return parseAfterScheme(text[0..end], text[end + 1 ..]); } pub const ResolveInPlaceError = ParseError || error{NoSpaceLeft}; -/// Resolves a URI against a base URI, conforming to RFC 3986, Section 5. -/// Copies `new` to the beginning of `aux_buf.*`, allowing the slices to overlap, -/// then parses `new` as a URI, and then resolves the path in place. +/// Resolves a URI against a base URI, conforming to +/// [RFC 3986, Section 5](https://www.rfc-editor.org/rfc/rfc3986#section-5) +/// +/// Assumes new location is already copied to the beginning of `aux_buf.*`. +/// Parses that new location as a URI, and then resolves the path in place. +/// /// If a merge needs to take place, the newly constructed path will be stored -/// in `aux_buf.*` just after the copied `new`, and `aux_buf.*` will be modified -/// to only contain the remaining unused space. -pub fn resolve_inplace(base: Uri, new: []const u8, aux_buf: *[]u8) ResolveInPlaceError!Uri { - std.mem.copyForwards(u8, aux_buf.*, new); - // At this point, new is an invalid pointer. - const new_mut = aux_buf.*[0..new.len]; - aux_buf.* = aux_buf.*[new.len..]; - - const new_parsed = parse(new_mut) catch |err| - (parseAfterScheme("", new_mut) catch return err); - // As you can see above, `new_mut` is not a const pointer. +/// in `aux_buf.*` just after the copied location, and `aux_buf.*` will be +/// modified to only contain the remaining unused space. +pub fn resolveInPlace(base: Uri, new_len: usize, aux_buf: *[]u8) ResolveInPlaceError!Uri { + const new = aux_buf.*[0..new_len]; + const new_parsed = parse(new) catch |err| (parseAfterScheme("", new) catch return err); + aux_buf.* = aux_buf.*[new_len..]; + // As you can see above, `new` is not a const pointer. const new_path: []u8 = @constCast(new_parsed.path.percent_encoded); if (new_parsed.scheme.len > 0) return .{ @@ -461,7 +490,7 @@ test remove_dot_segments { /// 5.2.3. 
Merge Paths fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Component { - var aux: std.io.Writer = .fixed(aux_buf.*); + var aux: Writer = .fixed(aux_buf.*); if (!base.isEmpty()) { base.formatPath(&aux) catch return error.NoSpaceLeft; aux.end = std.mem.lastIndexOfScalar(u8, aux.buffered(), '/') orelse return remove_dot_segments(new); @@ -472,59 +501,6 @@ fn merge_paths(base: Component, new: []u8, aux_buf: *[]u8) error{NoSpaceLeft}!Co return merged_path; } -const SliceReader = struct { - const Self = @This(); - - slice: []const u8, - offset: usize = 0, - - fn get(self: *Self) ?u8 { - if (self.offset >= self.slice.len) - return null; - const c = self.slice[self.offset]; - self.offset += 1; - return c; - } - - fn peek(self: Self) ?u8 { - if (self.offset >= self.slice.len) - return null; - return self.slice[self.offset]; - } - - fn readWhile(self: *Self, comptime predicate: fn (u8) bool) []const u8 { - const start = self.offset; - var end = start; - while (end < self.slice.len and predicate(self.slice[end])) { - end += 1; - } - self.offset = end; - return self.slice[start..end]; - } - - fn readUntil(self: *Self, comptime predicate: fn (u8) bool) []const u8 { - const start = self.offset; - var end = start; - while (end < self.slice.len and !predicate(self.slice[end])) { - end += 1; - } - self.offset = end; - return self.slice[start..end]; - } - - fn readUntilEof(self: *Self) []const u8 { - const start = self.offset; - self.offset = self.slice.len; - return self.slice[start..]; - } - - fn peekPrefix(self: Self, prefix: []const u8) bool { - if (self.offset + prefix.len > self.slice.len) - return false; - return std.mem.eql(u8, self.slice[self.offset..][0..prefix.len], prefix); - } -}; - /// scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ) fn isSchemeChar(c: u8) bool { return switch (c) { @@ -533,19 +509,6 @@ fn isSchemeChar(c: u8) bool { }; } -/// reserved = gen-delims / sub-delims -fn isReserved(c: u8) bool { - return isGenLimit(c) or isSubLimit(c); -} - -/// gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" -fn isGenLimit(c: u8) bool { - return switch (c) { - ':', ',', '?', '#', '[', ']', '@' => true, - else => false, - }; -} - /// sub-delims = "!" 
/ "$" / "&" / "'" / "(" / ")" /// / "*" / "+" / "," / ";" / "=" fn isSubLimit(c: u8) bool { @@ -585,26 +548,8 @@ fn isQueryChar(c: u8) bool { const isFragmentChar = isQueryChar; -fn isAuthoritySeparator(c: u8) bool { - return switch (c) { - '/', '?', '#' => true, - else => false, - }; -} - -fn isPathSeparator(c: u8) bool { - return switch (c) { - '?', '#' => true, - else => false, - }; -} - -fn isQuerySeparator(c: u8) bool { - return switch (c) { - '#' => true, - else => false, - }; -} +const authority_sep: [3]u8 = .{ '/', '?', '#' }; +const path_sep: [2]u8 = .{ '?', '#' }; test "basic" { const parsed = try parse("https://ziglang.org/download"); diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index da6a431840d2..e647a7710e75 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -49,8 +49,8 @@ pub const hello_retry_request_sequence = [32]u8{ }; pub const close_notify_alert = [_]u8{ - @intFromEnum(AlertLevel.warning), - @intFromEnum(AlertDescription.close_notify), + @intFromEnum(Alert.Level.warning), + @intFromEnum(Alert.Description.close_notify), }; pub const ProtocolVersion = enum(u16) { @@ -138,103 +138,108 @@ pub const ExtensionType = enum(u16) { _, }; -pub const AlertLevel = enum(u8) { - warning = 1, - fatal = 2, - _, -}; +pub const Alert = struct { + level: Level, + description: Description, -pub const AlertDescription = enum(u8) { - pub const Error = error{ - TlsAlertUnexpectedMessage, - TlsAlertBadRecordMac, - TlsAlertRecordOverflow, - TlsAlertHandshakeFailure, - TlsAlertBadCertificate, - TlsAlertUnsupportedCertificate, - TlsAlertCertificateRevoked, - TlsAlertCertificateExpired, - TlsAlertCertificateUnknown, - TlsAlertIllegalParameter, - TlsAlertUnknownCa, - TlsAlertAccessDenied, - TlsAlertDecodeError, - TlsAlertDecryptError, - TlsAlertProtocolVersion, - TlsAlertInsufficientSecurity, - TlsAlertInternalError, - TlsAlertInappropriateFallback, - TlsAlertMissingExtension, - TlsAlertUnsupportedExtension, - TlsAlertUnrecognizedName, - TlsAlertBadCertificateStatusResponse, - TlsAlertUnknownPskIdentity, - TlsAlertCertificateRequired, - TlsAlertNoApplicationProtocol, - TlsAlertUnknown, + pub const Level = enum(u8) { + warning = 1, + fatal = 2, + _, }; - close_notify = 0, - unexpected_message = 10, - bad_record_mac = 20, - record_overflow = 22, - handshake_failure = 40, - bad_certificate = 42, - unsupported_certificate = 43, - certificate_revoked = 44, - certificate_expired = 45, - certificate_unknown = 46, - illegal_parameter = 47, - unknown_ca = 48, - access_denied = 49, - decode_error = 50, - decrypt_error = 51, - protocol_version = 70, - insufficient_security = 71, - internal_error = 80, - inappropriate_fallback = 86, - user_canceled = 90, - missing_extension = 109, - unsupported_extension = 110, - unrecognized_name = 112, - bad_certificate_status_response = 113, - unknown_psk_identity = 115, - certificate_required = 116, - no_application_protocol = 120, - _, + pub const Description = enum(u8) { + pub const Error = error{ + TlsAlertUnexpectedMessage, + TlsAlertBadRecordMac, + TlsAlertRecordOverflow, + TlsAlertHandshakeFailure, + TlsAlertBadCertificate, + TlsAlertUnsupportedCertificate, + TlsAlertCertificateRevoked, + TlsAlertCertificateExpired, + TlsAlertCertificateUnknown, + TlsAlertIllegalParameter, + TlsAlertUnknownCa, + TlsAlertAccessDenied, + TlsAlertDecodeError, + TlsAlertDecryptError, + TlsAlertProtocolVersion, + TlsAlertInsufficientSecurity, + TlsAlertInternalError, + TlsAlertInappropriateFallback, + TlsAlertMissingExtension, + 
TlsAlertUnsupportedExtension, + TlsAlertUnrecognizedName, + TlsAlertBadCertificateStatusResponse, + TlsAlertUnknownPskIdentity, + TlsAlertCertificateRequired, + TlsAlertNoApplicationProtocol, + TlsAlertUnknown, + }; - pub fn toError(alert: AlertDescription) Error!void { - switch (alert) { - .close_notify => {}, // not an error - .unexpected_message => return error.TlsAlertUnexpectedMessage, - .bad_record_mac => return error.TlsAlertBadRecordMac, - .record_overflow => return error.TlsAlertRecordOverflow, - .handshake_failure => return error.TlsAlertHandshakeFailure, - .bad_certificate => return error.TlsAlertBadCertificate, - .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, - .certificate_revoked => return error.TlsAlertCertificateRevoked, - .certificate_expired => return error.TlsAlertCertificateExpired, - .certificate_unknown => return error.TlsAlertCertificateUnknown, - .illegal_parameter => return error.TlsAlertIllegalParameter, - .unknown_ca => return error.TlsAlertUnknownCa, - .access_denied => return error.TlsAlertAccessDenied, - .decode_error => return error.TlsAlertDecodeError, - .decrypt_error => return error.TlsAlertDecryptError, - .protocol_version => return error.TlsAlertProtocolVersion, - .insufficient_security => return error.TlsAlertInsufficientSecurity, - .internal_error => return error.TlsAlertInternalError, - .inappropriate_fallback => return error.TlsAlertInappropriateFallback, - .user_canceled => {}, // not an error - .missing_extension => return error.TlsAlertMissingExtension, - .unsupported_extension => return error.TlsAlertUnsupportedExtension, - .unrecognized_name => return error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, - .certificate_required => return error.TlsAlertCertificateRequired, - .no_application_protocol => return error.TlsAlertNoApplicationProtocol, - _ => return error.TlsAlertUnknown, + close_notify = 0, + unexpected_message = 10, + bad_record_mac = 20, + record_overflow = 22, + handshake_failure = 40, + bad_certificate = 42, + unsupported_certificate = 43, + certificate_revoked = 44, + certificate_expired = 45, + certificate_unknown = 46, + illegal_parameter = 47, + unknown_ca = 48, + access_denied = 49, + decode_error = 50, + decrypt_error = 51, + protocol_version = 70, + insufficient_security = 71, + internal_error = 80, + inappropriate_fallback = 86, + user_canceled = 90, + missing_extension = 109, + unsupported_extension = 110, + unrecognized_name = 112, + bad_certificate_status_response = 113, + unknown_psk_identity = 115, + certificate_required = 116, + no_application_protocol = 120, + _, + + pub fn toError(description: Description) Error!void { + switch (description) { + .close_notify => {}, // not an error + .unexpected_message => return error.TlsAlertUnexpectedMessage, + .bad_record_mac => return error.TlsAlertBadRecordMac, + .record_overflow => return error.TlsAlertRecordOverflow, + .handshake_failure => return error.TlsAlertHandshakeFailure, + .bad_certificate => return error.TlsAlertBadCertificate, + .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, + .certificate_revoked => return error.TlsAlertCertificateRevoked, + .certificate_expired => return error.TlsAlertCertificateExpired, + .certificate_unknown => return error.TlsAlertCertificateUnknown, + .illegal_parameter => return error.TlsAlertIllegalParameter, + .unknown_ca => return error.TlsAlertUnknownCa, 
+ .access_denied => return error.TlsAlertAccessDenied, + .decode_error => return error.TlsAlertDecodeError, + .decrypt_error => return error.TlsAlertDecryptError, + .protocol_version => return error.TlsAlertProtocolVersion, + .insufficient_security => return error.TlsAlertInsufficientSecurity, + .internal_error => return error.TlsAlertInternalError, + .inappropriate_fallback => return error.TlsAlertInappropriateFallback, + .user_canceled => {}, // not an error + .missing_extension => return error.TlsAlertMissingExtension, + .unsupported_extension => return error.TlsAlertUnsupportedExtension, + .unrecognized_name => return error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, + .certificate_required => return error.TlsAlertCertificateRequired, + .no_application_protocol => return error.TlsAlertNoApplicationProtocol, + _ => return error.TlsAlertUnknown, + } } - } + }; }; pub const SignatureScheme = enum(u16) { @@ -650,7 +655,7 @@ pub const Decoder = struct { } /// Use this function to increase `their_end`. - pub fn readAtLeast(d: *Decoder, stream: anytype, their_amt: usize) !void { + pub fn readAtLeast(d: *Decoder, stream: *std.io.Reader, their_amt: usize) !void { assert(!d.disable_reads); const existing_amt = d.cap - d.idx; d.their_end = d.idx + their_amt; @@ -658,14 +663,16 @@ pub const Decoder = struct { const request_amt = their_amt - existing_amt; const dest = d.buf[d.cap..]; if (request_amt > dest.len) return error.TlsRecordOverflow; - const actual_amt = try stream.readAtLeast(dest, request_amt); - if (actual_amt < request_amt) return error.TlsConnectionTruncated; - d.cap += actual_amt; + stream.readSliceAll(dest[0..request_amt]) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + d.cap += request_amt; } /// Same as `readAtLeast` but also increases `our_end` by exactly `our_amt`. /// Use when `our_amt` is calculated by us, not by them. - pub fn readAtLeastOurAmt(d: *Decoder, stream: anytype, our_amt: usize) !void { + pub fn readAtLeastOurAmt(d: *Decoder, stream: *std.io.Reader, our_amt: usize) !void { assert(!d.disable_reads); try readAtLeast(d, stream, our_amt); d.our_end = d.idx + our_amt; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 3fa7b73d0631..5e89c071c62b 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1,11 +1,15 @@ +const builtin = @import("builtin"); +const native_endian = builtin.cpu.arch.endian(); + const std = @import("../../std.zig"); const tls = std.crypto.tls; const Client = @This(); -const net = std.net; const mem = std.mem; const crypto = std.crypto; const assert = std.debug.assert; const Certificate = std.crypto.Certificate; +const Reader = std.Io.Reader; +const Writer = std.Io.Writer; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; @@ -13,44 +17,60 @@ const hkdfExpandLabel = tls.hkdfExpandLabel; const int = tls.int; const array = tls.array; +/// The encrypted stream from the server to the client. Bytes are pulled from +/// here via `reader`. +/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. +input: *Reader, +/// Decrypted stream from the server to the client. +reader: Reader, + +/// The encrypted stream from the client to the server. Bytes are pushed here +/// via `writer`.
+/// +/// The buffer is asserted to have capacity at least `min_buffer_len`. +output: *Writer, +/// The plaintext stream from the client to the server. +writer: Writer, + +/// Populated when `error.TlsAlert` is returned. +alert: ?tls.Alert = null, +read_err: ?ReadError = null, tls_version: tls.ProtocolVersion, read_seq: u64, write_seq: u64, -/// The starting index of cleartext bytes inside `partially_read_buffer`. -partial_cleartext_idx: u15, -/// The ending index of cleartext bytes inside `partially_read_buffer` as well -/// as the starting index of ciphertext bytes. -partial_ciphertext_idx: u15, -/// The ending index of ciphertext bytes inside `partially_read_buffer`. -partial_ciphertext_end: u15, /// When this is true, the stream may still not be at the end because there -/// may be data in `partially_read_buffer`. +/// may be data in the input buffer. received_close_notify: bool, -/// By default, reaching the end-of-stream when reading from the server will -/// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify -/// message has been received. By setting this flag to `true`, instead, the -/// end-of-stream will be forwarded to the application layer above TLS. -/// This makes the application vulnerable to truncation attacks unless the -/// application layer itself verifies that the amount of data received equals -/// the amount of data expected, such as HTTP with the Content-Length header. allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, -/// The size is enough to contain exactly one TLSCiphertext record. -/// This buffer is segmented into four parts: -/// 0. unused -/// 1. cleartext -/// 2. ciphertext -/// 3. unused -/// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and -/// `partial_ciphertext_end` describe the span of the segments. -partially_read_buffer: [tls.max_ciphertext_record_len]u8, -/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other -/// programs with access to that file to decrypt all traffic over this connection. -ssl_key_log: ?struct { + +/// If non-null, ssl secrets are logged to a stream. Creating such a log file +/// allows other programs with access to that file to decrypt all traffic over +/// this connection. +ssl_key_log: ?*SslKeyLog, + +pub const ReadError = error{ + /// The alert description will be stored in `alert`. + TlsAlert, + TlsBadLength, + TlsBadRecordMac, + TlsConnectionTruncated, + TlsDecodeError, + TlsRecordOverflow, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsSequenceOverflow, + /// The buffer provided to the read function was not at least + /// `min_buffer_len`. + OutputBufferUndersize, +}; + +pub const SslKeyLog = struct { client_key_seq: u64, server_key_seq: u64, client_random: [32]u8, - file: std.fs.File, + writer: *Writer, fn clientCounter(key_log: *@This()) u64 { defer key_log.client_key_seq += 1; @@ -61,51 +81,12 @@ ssl_key_log: ?struct { defer key_log.server_key_seq += 1; return key_log.server_key_seq; } -}, - -/// This is an example of the type that is needed by the read and write -/// functions. It can have any fields but it must at least have these -/// functions. -/// -/// Note that `std.net.Stream` conforms to this interface. -/// -/// This declaration serves as documentation only. -pub const StreamInterface = struct { - /// Can be any error set. - pub const ReadError = error{}; - - /// Returns the number of bytes read. The number read may be less than the - /// buffer space provided. 
End-of-stream is indicated by a return value of 0. - /// - /// The `iovecs` parameter is mutable because so that function may to - /// mutate the fields in order to handle partial reads from the underlying - /// stream layer. - pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Can be any error set. - pub const WriteError = error{}; - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided. A short read does not indicate end-of-stream. - pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize { - _ = .{ this, iovecs }; - @panic("unimplemented"); - } - - /// Returns the number of bytes read, which may be less than the buffer - /// space provided, indicating end-of-stream. - /// The `iovecs` parameter is mutable in case this function needs to mutate - /// the fields in order to handle partial writes from the underlying layer. - pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize { - // This can be implemented in terms of writev, or specialized if desired. - _ = .{ this, iovecs }; - @panic("unimplemented"); - } }; +/// The `Reader` and `Writer` supplied to `init` require a buffer capacity of +/// at least this amount. +pub const min_buffer_len = tls.max_ciphertext_record_len; + pub const Options = struct { /// How to perform host verification of server certificates. host: union(enum) { @@ -127,64 +108,85 @@ pub const Options = struct { /// Verify that the server certificate is authorized by a given ca bundle. bundle: Certificate.Bundle, }, - /// If non-null, ssl secrets are logged to this file. Creating such a log file allows + /// If non-null, ssl secrets are logged to this stream. Creating such a log file allows /// other programs with access to that file to decrypt all traffic over this connection. - ssl_key_log_file: ?std.fs.File = null, + /// + /// Only the `writer` field is observed during the handshake (`init`). + /// After that, the other fields are populated. + ssl_key_log: ?*SslKeyLog = null, + /// By default, reaching the end-of-stream when reading from the server will + /// cause `error.TlsConnectionTruncated` to be returned, unless a close_notify + /// message has been received. By setting this flag to `true`, instead, the + /// end-of-stream will be forwarded to the application layer above TLS. + /// + /// This makes the application vulnerable to truncation attacks unless the + /// application layer itself verifies that the amount of data received equals + /// the amount of data expected, such as HTTP with the Content-Length header. + allow_truncation_attacks: bool = false, + write_buffer: []u8, + read_buffer: []u8, + /// Populated when `error.TlsAlert` is returned from `init`.
+ alert: ?*tls.Alert = null, }; -pub fn InitError(comptime Stream: type) type { - return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ - InsufficientEntropy, - DiskQuota, - LockViolation, - NotOpenForWriting, - TlsUnexpectedMessage, - TlsIllegalParameter, - TlsDecryptFailure, - TlsRecordOverflow, - TlsBadRecordMac, - CertificateFieldHasInvalidLength, - CertificateHostMismatch, - CertificatePublicKeyInvalid, - CertificateExpired, - CertificateFieldHasWrongDataType, - CertificateIssuerMismatch, - CertificateNotYetValid, - CertificateSignatureAlgorithmMismatch, - CertificateSignatureAlgorithmUnsupported, - CertificateSignatureInvalid, - CertificateSignatureInvalidLength, - CertificateSignatureNamedCurveUnsupported, - CertificateSignatureUnsupportedBitCount, - TlsCertificateNotVerified, - TlsBadSignatureScheme, - TlsBadRsaSignatureBitCount, - InvalidEncoding, - IdentityElement, - SignatureVerificationFailed, - TlsDecryptError, - TlsConnectionTruncated, - TlsDecodeError, - UnsupportedCertificateVersion, - CertificateTimeInvalid, - CertificateHasUnrecognizedObjectId, - CertificateHasInvalidBitString, - MessageTooLong, - NegativeIntoUnsigned, - TargetTooSmall, - BufferTooSmall, - InvalidSignature, - NotSquare, - NonCanonical, - WeakPublicKey, - }; -} +const InitError = error{ + WriteFailed, + ReadFailed, + InsufficientEntropy, + DiskQuota, + LockViolation, + NotOpenForWriting, + /// The alert description will be stored in `alert`. + TlsAlert, + TlsUnexpectedMessage, + TlsIllegalParameter, + TlsDecryptFailure, + TlsRecordOverflow, + TlsBadRecordMac, + CertificateFieldHasInvalidLength, + CertificateHostMismatch, + CertificatePublicKeyInvalid, + CertificateExpired, + CertificateFieldHasWrongDataType, + CertificateIssuerMismatch, + CertificateNotYetValid, + CertificateSignatureAlgorithmMismatch, + CertificateSignatureAlgorithmUnsupported, + CertificateSignatureInvalid, + CertificateSignatureInvalidLength, + CertificateSignatureNamedCurveUnsupported, + CertificateSignatureUnsupportedBitCount, + TlsCertificateNotVerified, + TlsBadSignatureScheme, + TlsBadRsaSignatureBitCount, + InvalidEncoding, + IdentityElement, + SignatureVerificationFailed, + TlsDecryptError, + TlsConnectionTruncated, + TlsDecodeError, + UnsupportedCertificateVersion, + CertificateTimeInvalid, + CertificateHasUnrecognizedObjectId, + CertificateHasInvalidBitString, + MessageTooLong, + NegativeIntoUnsigned, + TargetTooSmall, + BufferTooSmall, + InvalidSignature, + NotSquare, + NonCanonical, + WeakPublicKey, +}; -/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which -/// must conform to `StreamInterface`. +/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client { +/// +/// `input` and `output` are asserted to have buffer capacity at least `min_buffer_len`.
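A minimal usage sketch of the new handshake entry point (illustrative, not part of the patch): `transport_reader`/`transport_writer` stand for any `*std.Io.Reader`/`*std.Io.Writer` pair whose buffers satisfy the asserts below, e.g. from a socket, `ca_bundle` is assumed to be a loaded `std.crypto.Certificate.Bundle`, and `.ca` is an assumed name for the verification field shown truncated above:

```zig
// Plaintext-side buffers for the client's own reader/writer.
var read_buf: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined;
var write_buf: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined;
var alert: std.crypto.tls.Alert = undefined;

var client = try std.crypto.tls.Client.init(transport_reader, transport_writer, .{
    .host = .{ .explicit = "example.com" },
    .ca = .{ .bundle = ca_bundle },
    .read_buffer = &read_buf,
    .write_buffer = &write_buf,
    .alert = &alert, // populated if init fails with error.TlsAlert
});
```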
+pub fn init(input: *Reader, output: *Writer, options: Options) InitError!Client { + assert(input.buffer.len >= min_buffer_len); + assert(output.buffer.len >= min_buffer_len); const host = switch (options.host) { .no_verification => "", .explicit => |host| host, @@ -276,11 +278,9 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; { - var iovecs = [_]std.posix.iovec_const{ - .{ .base = cleartext_header.ptr, .len = cleartext_header.len }, - .{ .base = host.ptr, .len = host.len }, - }; - try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]); + var iovecs: [2][]const u8 = .{ cleartext_header, host }; + try output.writeVecAll(iovecs[0..if (host.len == 0) 1 else 2]); + try output.flush(); } var tls_version: tls.ProtocolVersion = undefined; @@ -329,20 +329,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client var cleartext_fragment_start: usize = 0; var cleartext_fragment_end: usize = 0; var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; - var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; fragment: while (true) { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..tls.record_header_len]; - const record_ct = d.decode(tls.ContentType); - d.skip(2); // legacy_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - var record_decoder = try d.sub(record_len); + // Ensure the input buffer pointer is stable in this scope. + input.rebase(tls.max_ciphertext_record_len) catch |err| switch (err) { + error.EndOfStream => {}, // We have assurance the remainder of stream can be buffered. + }; + const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + const record_ct = input.takeEnumNonexhaustive(tls.ContentType, .big) catch unreachable; // already peeked + input.toss(2); // legacy_version + const record_len = input.takeInt(u16, .big) catch unreachable; // already peeked + if (record_len > tls.max_ciphertext_len) return error.TlsRecordOverflow; + const record_buffer = input.take(record_len) catch |err| switch (err) { + error.EndOfStream => return error.TlsConnectionTruncated, + error.ReadFailed => return error.ReadFailed, + }; + var record_decoder: tls.Decoder = .fromTheirSlice(record_buffer); var ctd, const ct = content: switch (cipher_state) { .cleartext => .{ record_decoder, record_ct }, .handshake => { - std.debug.assert(tls_version == .tls_1_3); + assert(tls_version == .tls_1_3); if (record_ct != .application_data) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; @@ -374,7 +382,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct }; }, .application => { - std.debug.assert(tls_version == .tls_1_2); + assert(tls_version == .tls_1_2); if (record_ct != .handshake) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; @@ -412,14 +420,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client switch (ct) { .alert => { ctd.ensure(2) catch continue :fragment; - const level = 
ctd.decode(tls.AlertLevel); - const desc = ctd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; + if (options.alert) |a| a.* = .{ + .level = ctd.decode(tls.Alert.Level), + .description = ctd.decode(tls.Alert.Description), + }; + return error.TlsAlert; }, .change_cipher_spec => { ctd.ensure(1) catch continue :fragment; @@ -533,7 +538,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .client_random = &client_hello_rand, }, .{ .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret, @@ -707,7 +712,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client &client_hello_rand, &server_hello_rand, }, 48); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .client_random = &client_hello_rand, }, .{ .CLIENT_RANDOM = &master_secret, @@ -755,11 +760,13 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client nonce, pv.app_cipher.client_write_key, ); - const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [3][]const u8 = .{ + &client_key_exchange_msg, + &client_change_cipher_spec_msg, + &client_verify_msg, }; - try stream.writevAll(&all_msgs_vec); + try output.writeVecAll(&all_msgs_vec); + try output.flush(); }, } write_seq += 1; @@ -820,15 +827,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const nonce = pv.client_handshake_iv; P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); - const all_msgs = client_change_cipher_spec_msg ++ finished_msg; - var all_msgs_vec = [_]std.posix.iovec_const{ - .{ .base = &all_msgs, .len = all_msgs.len }, + var all_msgs_vec: [2][]const u8 = .{ + &client_change_cipher_spec_msg, + &finished_msg, }; - try stream.writevAll(&all_msgs_vec); + try output.writeVecAll(&all_msgs_vec); + try output.flush(); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + if (options.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ .counter = key_seq, .client_random = &client_hello_rand, }, .{ @@ -855,8 +863,28 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client else => unreachable, }, }; - const leftover = d.rest(); - var client: Client = .{ + if (options.ssl_key_log) |ssl_key_log| ssl_key_log.* = .{ + .client_key_seq = key_seq, + .server_key_seq = key_seq, + .client_random = client_hello_rand, + .writer = ssl_key_log.writer, + }; + return .{ + .input = input, + .reader = 
.{ + .buffer = options.read_buffer, + .vtable = &.{ .stream = stream }, + .seek = 0, + .end = 0, + }, + .output = output, + .writer = .{ + .buffer = options.write_buffer, + .vtable = &.{ + .drain = drain, + .flush = flush, + }, + }, .tls_version = tls_version, .read_seq = switch (tls_version) { .tls_1_3 => 0, @@ -868,22 +896,11 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client .tls_1_2 => write_seq, else => unreachable, }, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), .received_close_notify = false, - .allow_truncation_attacks = false, + .allow_truncation_attacks = options.allow_truncation_attacks, .application_cipher = app_cipher, - .partially_read_buffer = undefined, - .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ - .client_key_seq = key_seq, - .server_key_seq = key_seq, - .client_random = client_hello_rand, - .file = key_log_file, - } else null, + .ssl_key_log = options.ssl_key_log, }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; }, else => return error.TlsUnexpectedMessage, } @@ -897,94 +914,73 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client } } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { - return writeEnd(c, stream, bytes, false); -} - -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -pub fn writeAll(c: *Client, stream: anytype, bytes: []const u8) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.write(stream, bytes[index..]); +fn drain(w: *Writer, data: []const []const u8, splat: usize) Writer.Error!usize { + const c: *Client = @alignCast(@fieldParentPtr("writer", w)); + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + var ciphertext_end: usize = 0; + var total_clear: usize = 0; + done: { + { + const buf = w.buffered(); + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } + for (data[0 .. data.len - 1]) |buf| { + if (buf.len < min_buffer_len) break :done; + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } + const buf = data[data.len - 1]; + for (0..splat) |_| { + if (buf.len < min_buffer_len) break :done; + const prepared = prepareCiphertextRecord(c, ciphertext_buf[ciphertext_end..], buf, .application_data); + total_clear += prepared.cleartext_len; + ciphertext_end += prepared.ciphertext_end; + if (prepared.cleartext_len < buf.len) break :done; + } } + output.advance(ciphertext_end); + return w.consume(total_clear); } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. 
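With the vectored `drain` above, a typical send looks like the following hedged sketch (not part of the patch). Note that `drain` and `flush` only move ciphertext into `output`'s buffer, so the transport writer must still be flushed by the caller:

```zig
// `client` as initialized earlier; `transport_writer` is the *std.Io.Writer
// that was passed to init as `output`.
try client.writer.writeAll("GET / HTTP/1.1\r\nhost: example.com\r\n\r\n");
try client.writer.flush(); // encrypt buffered plaintext into `output`
try client.end(); // see below: appends the close_notify alert record
try transport_writer.flush(); // put the ciphertext on the wire
```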
-pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !void { - var index: usize = 0; - while (index < bytes.len) { - index += try c.writeEnd(stream, bytes[index..], end); - } +fn flush(w: *Writer) Writer.Error!void { + const c: *Client = @alignCast(@fieldParentPtr("writer", w)); + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, w.buffered(), .application_data); + output.advance(prepared.ciphertext_end); + w.end = 0; } -/// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. -/// If `end` is true, then this function additionally sends a `close_notify` alert, -/// which is necessary for the server to distinguish between a properly finished -/// TLS session, or a truncation attack. -pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize { - var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined; - var iovecs_buf: [6]std.posix.iovec_const = undefined; - var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data); - if (end) { - prepared.iovec_end += prepareCiphertextRecord( - c, - iovecs_buf[prepared.iovec_end..], - ciphertext_buf[prepared.ciphertext_end..], - &tls.close_notify_alert, - .alert, - ).iovec_end; - } - - const iovec_end = prepared.iovec_end; - const overhead_len = prepared.overhead_len; - - // Ideally we would call writev exactly once here, however, we must ensure - // that we don't return with a record partially written. - var i: usize = 0; - var total_amt: usize = 0; - while (true) { - var amt = try stream.writev(iovecs_buf[i..iovec_end]); - while (amt >= iovecs_buf[i].len) { - const encrypted_amt = iovecs_buf[i].len; - total_amt += encrypted_amt - overhead_len; - amt -= encrypted_amt; - i += 1; - // Rely on the property that iovecs delineate records, meaning that - // if amt equals zero here, we have fortunately found ourselves - // with a short read that aligns at the record boundary. - if (i >= iovec_end) return total_amt; - // We also cannot return on a vector boundary if the final close_notify is - // not sent; otherwise the caller would not know to retry the call. - if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt; - } - iovecs_buf[i].base += amt; - iovecs_buf[i].len -= amt; - } +/// Sends a `close_notify` alert, which is necessary for the server to +/// distinguish between a properly finished TLS session, or a truncation +/// attack. +pub fn end(c: *Client) Writer.Error!void { + try flush(&c.writer); + const output = c.output; + const ciphertext_buf = try output.writableSliceGreedy(min_buffer_len); + const prepared = prepareCiphertextRecord(c, ciphertext_buf, &tls.close_notify_alert, .alert); + output.advance(prepared.ciphertext_end); } fn prepareCiphertextRecord( c: *Client, - iovecs: []std.posix.iovec_const, ciphertext_buf: []u8, bytes: []const u8, inner_content_type: tls.ContentType, ) struct { - iovec_end: usize, ciphertext_end: usize, - /// How many bytes are taken up by overhead per record. - overhead_len: usize, + cleartext_len: usize, } { // Due to the trailing inner content type byte in the ciphertext, we need // an additional buffer for storing the cleartext into before encrypting. 
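    // (Illustrative aside, not part of this change: in TLS 1.3 the encrypted
    // payload is TLSInnerPlaintext = content || ContentType || zero padding,
    // per RFC 8446, Section 5.2. That trailing ContentType byte is why the
    // tls_1_3 branch below computes its overhead_len with a "+ 1".)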
var cleartext_buf: [max_ciphertext_len]u8 = undefined; var ciphertext_end: usize = 0; - var iovec_end: usize = 0; var bytes_i: usize = 0; switch (c.application_cipher) { inline else => |*p| switch (c.tls_version) { @@ -992,18 +988,15 @@ fn prepareCiphertextRecord( const pv = &p.tls_1_3; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const encrypted_content_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); @@ -1012,7 +1005,6 @@ fn prepareCiphertextRecord( const ciphertext_len = encrypted_content_len + 1; const cleartext = cleartext_buf[0..ciphertext_len]; - const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -1030,38 +1022,27 @@ fn prepareCiphertextRecord( }; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; } }, .tls_1_2 => { const pv = &p.tls_1_2; const P = @TypeOf(p.*); const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; while (true) { const message_len: u16 = @min( bytes.len - bytes_i, tls.max_ciphertext_inner_record_len, - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), + ciphertext_buf.len -| (overhead_len + ciphertext_end), ); if (message_len == 0) return .{ - .iovec_end = iovec_end, .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, + .cleartext_len = bytes_i, }; @memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]); bytes_i += message_len; const cleartext = cleartext_buf[0..message_len]; - const record_start = ciphertext_end; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ciphertext_end += tls.record_header_len; record_header.* = .{@intFromEnum(inner_content_type)} ++ @@ -1083,13 +1064,6 @@ fn prepareCiphertextRecord( ciphertext_end += P.mac_length; P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key); c.write_seq += 1; // TODO send key_update on overflow - - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; } }, else => unreachable, @@ -1098,421 +1072,194 @@ fn prepareCiphertextRecord( } pub fn eof(c: Client) bool { - return c.received_close_notify and - c.partial_cleartext_idx >= c.partial_ciphertext_idx and - c.partial_ciphertext_idx >= c.partial_ciphertext_end; -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. 
-/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the buffer has at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -pub fn readAtLeast(c: *Client, stream: anytype, buffer: []u8, len: usize) !usize { - var iovecs = [1]std.posix.iovec{.{ .base = buffer.ptr, .len = buffer.len }}; - return readvAtLeast(c, stream, &iovecs, len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -pub fn read(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, 1); + return c.received_close_notify; } -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is smaller than -/// `buffer.len`, it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize { - return readAtLeast(c, stream, buffer, buffer.len); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read. If the number read is less than the space -/// provided it means the stream reached the end. Reaching the end of the -/// stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize { - return readvAtLeast(c, stream, iovecs, 1); -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns the number of bytes read, calling the underlying read function the -/// minimal number of times until the iovecs have at least `len` bytes filled. -/// If the number read is less than `len` it means the stream reached the end. -/// Reaching the end of the stream is not an error condition. -/// The `iovecs` parameter is mutable because this function needs to mutate the fields in -/// order to handle partial reads from the underlying stream layer. -pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize { - if (c.eof()) return 0; - - var off_i: usize = 0; - var vec_i: usize = 0; - while (true) { - var amt = try c.readvAdvanced(stream, iovecs[vec_i..]); - off_i += amt; - if (c.eof() or off_i >= len) return off_i; - while (amt >= iovecs[vec_i].len) { - amt -= iovecs[vec_i].len; - vec_i += 1; - } - iovecs[vec_i].base += amt; - iovecs[vec_i].len -= amt; - } -} - -/// Receives TLS-encrypted data from `stream`, which must conform to `StreamInterface`. -/// Returns number of bytes that have been read, populated inside `iovecs`. A -/// return value of zero bytes does not mean end of stream. Instead, check the `eof()` -/// for the end of stream. The `eof()` may be true after any call to -/// `read`, including when greater than zero bytes are returned, and this -/// function asserts that `eof()` is `false`. -/// See `readv` for a higher level function that has the same, familiar API as -/// other read functions, such as `std.fs.File.read`. -pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize { - var vp: VecPut = .{ .iovecs = iovecs }; - - // Give away the buffered cleartext we have, if any. 
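For orientation, the five-byte record header that the new `stream` implementation peeks follows RFC 8446, Section 5.1: type (1 byte), legacy_version (2 bytes), length (2 bytes, big-endian). A small illustrative test (not part of the patch):

```zig
const std = @import("std");

test "TLS record header layout" {
    const header = [5]u8{ 0x17, 0x03, 0x03, 0x00, 0x2a };
    try std.testing.expectEqual(@as(u8, 0x17), header[0]); // application_data
    try std.testing.expectEqual(@as(u16, 42), std.mem.readInt(u16, header[3..5], .big));
}
```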
- const partial_cleartext = c.partially_read_buffer[c.partial_cleartext_idx..c.partial_ciphertext_idx]; - if (partial_cleartext.len > 0) { - const amt: u15 = @intCast(vp.put(partial_cleartext)); - c.partial_cleartext_idx += amt; - - if (c.partial_cleartext_idx == c.partial_ciphertext_idx and - c.partial_ciphertext_end == c.partial_ciphertext_idx) - { - // The buffer is now empty. - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = 0; - } - - if (c.received_close_notify) { - c.partial_ciphertext_end = 0; - assert(vp.total == amt); - return amt; - } else if (amt > 0) { - // We don't need more data, so don't call read. - assert(vp.total == amt); - return amt; - } +fn stream(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize { + const c: *Client = @alignCast(@fieldParentPtr("reader", r)); + if (c.eof()) return error.EndOfStream; + const input = c.input; + // If at least one full encrypted record is not buffered, read once. + const record_header = input.peek(tls.record_header_len) catch |err| switch (err) { + error.EndOfStream => { + // This is either a truncation attack, a bug in the server, or an + // intentional omission of the close_notify message due to truncation + // detection handled above the TLS layer. + if (c.allow_truncation_attacks) { + c.received_close_notify = true; + return error.EndOfStream; + } else { + return failRead(c, error.TlsConnectionTruncated); + } + }, + error.ReadFailed => return error.ReadFailed, + }; + const ct: tls.ContentType = @enumFromInt(record_header[0]); + const legacy_version = mem.readInt(u16, record_header[1..][0..2], .big); + _ = legacy_version; + const record_len = mem.readInt(u16, record_header[3..][0..2], .big); + if (record_len > max_ciphertext_len) return failRead(c, error.TlsRecordOverflow); + const record_end = 5 + record_len; + if (record_end > input.buffered().len) { + input.fillMore() catch |err| switch (err) { + error.EndOfStream => return failRead(c, error.TlsConnectionTruncated), + error.ReadFailed => return error.ReadFailed, + }; + if (record_end > input.buffered().len) return 0; } - assert(!c.received_close_notify); - - // Ideally, this buffer would never be used. It is needed when `iovecs` are - // too small to fit the cleartext, which may be as large as `max_ciphertext_len`. var cleartext_stack_buffer: [max_ciphertext_len]u8 = undefined; - // Temporarily stores ciphertext before decrypting it and giving it to `iovecs`. - var in_stack_buffer: [max_ciphertext_len * 4]u8 = undefined; - // How many bytes left in the user's buffer. - const free_size = vp.freeSize(); - // The amount of the user's buffer that we need to repurpose for storing - // ciphertext. The end of the buffer will be used for such purposes. - const ciphertext_buf_len = (free_size / 2) -| in_stack_buffer.len; - // The amount of the user's buffer that will be used to give cleartext. The - // beginning of the buffer will be used for such purposes. - const cleartext_buf_len = free_size - ciphertext_buf_len; - - // Recoup `partially_read_buffer` space. This is necessary because it is assumed - // below that `frag0` is big enough to hold at least one record. 
- limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); - c.partial_ciphertext_end -= c.partial_ciphertext_idx; - c.partial_ciphertext_idx = 0; - c.partial_cleartext_idx = 0; - const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..]; - - var ask_iovecs_buf: [2]std.posix.iovec = .{ - .{ - .base = first_iov.ptr, - .len = first_iov.len, - }, - .{ - .base = &in_stack_buffer, - .len = in_stack_buffer.len, + const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const ad = input.take(tls.record_header_len) catch unreachable; // already peeked + const ciphertext_len = record_len - P.AEAD.tag_length; + const ciphertext = input.take(ciphertext_len) catch unreachable; // already peeked + const auth_tag = (input.takeArray(P.AEAD.tag_length) catch unreachable).*; // already peeked + const nonce = nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); + break :nonce @as(V, pv.server_iv) ^ operand; + }; + const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch + return failRead(c, error.TlsBadRecordMac); + const msg = mem.trimRight(u8, cleartext, "\x00"); + break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + const ad_header = input.take(tls.record_header_len) catch unreachable; // already peeked + const ad = std.mem.toBytes(big(c.read_seq)) ++ + ad_header[0 .. 1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = (input.takeArray(P.record_iv_length) catch unreachable).*; // already peeked + const masked_read_seq = c.read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = input.take(message_len) catch unreachable; // already peeked + const auth_tag = (input.takeArray(P.mac_length) catch unreachable).*; // already peeked + const cleartext = cleartext_stack_buffer[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch + return failRead(c, error.TlsBadRecordMac); + break :cleartext .{ cleartext, ct }; + }, + else => unreachable, }, }; - - // Cleartext capacity of output buffer, in records. Minimum one full record. - const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1); - const wanted_read_len = buf_cap * (max_ciphertext_len + tls.record_header_len); - const ask_len = @max(wanted_read_len, cleartext_stack_buffer.len) - c.partial_ciphertext_end; - const ask_iovecs = limitVecs(&ask_iovecs_buf, ask_len); - const actual_read_len = try stream.readv(ask_iovecs); - if (actual_read_len == 0) { - // This is either a truncation attack, a bug in the server, or an - // intentional omission of the close_notify message due to truncation - // detection handled above the TLS layer. 
- if (c.allow_truncation_attacks) { - c.received_close_notify = true; - } else { - return error.TlsConnectionTruncated; - } - } - - // There might be more bytes inside `in_stack_buffer` that need to be processed, - // but at least frag0 will have one complete ciphertext record. - const frag0_end = @min(c.partially_read_buffer.len, c.partial_ciphertext_end + actual_read_len); - const frag0 = c.partially_read_buffer[c.partial_ciphertext_idx..frag0_end]; - var frag1 = in_stack_buffer[0..actual_read_len -| first_iov.len]; - // We need to decipher frag0 and frag1 but there may be a ciphertext record - // straddling the boundary. We can handle this with two memcpy() calls to - // assemble the straddling record in between handling the two sides. - var frag = frag0; - var in: usize = 0; - while (true) { - if (in == frag.len) { - // Perfect split. - if (frag.ptr == frag1.ptr) { - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - frag = frag1; - in = 0; - continue; - } - - if (in + tls.record_header_len > frag.len) { - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - const first = frag[in..]; - - if (frag1.len < tls.record_header_len) - return finishRead2(c, first, frag1, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. - const record_len_byte_0: u16 = straddleByte(frag, frag1, in + 3); - const record_len_byte_1: u16 = straddleByte(frag, frag1, in + 4); - const record_len = (record_len_byte_0 << 8) | record_len_byte_1; - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const ct: tls.ContentType = @enumFromInt(frag[in]); - in += 1; - const legacy_version = mem.readInt(u16, frag[in..][0..2], .big); - in += 2; - _ = legacy_version; - const record_len = mem.readInt(u16, frag[in..][0..2], .big); - if (record_len > max_ciphertext_len) return error.TlsRecordOverflow; - in += 2; - const end = in + record_len; - if (end > frag.len) { - // We need the record header on the next iteration of the loop. - in -= tls.record_header_len; - - if (frag.ptr == frag1.ptr) - return finishRead(c, frag, in, vp.total); - - // A record straddles the two fragments. Copy into the now-empty first fragment. 
- const first = frag[in..]; - const full_record_len = record_len + tls.record_header_len; - const second_len = full_record_len - first.len; - if (frag1.len < second_len) - return finishRead2(c, first, frag1, vp.total); - - limitedOverlapCopy(frag, in); - @memcpy(frag[first.len..][0..second_len], frag1[0..second_len]); - frag = frag[0..full_record_len]; - frag1 = frag1[second_len..]; - in = 0; - continue; - } - const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { - inline else => |*p| switch (c.tls_version) { - .tls_1_3 => { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); - break :nonce @as(V, pv.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch - return error.TlsBadRecordMac; - const msg = mem.trimEnd(u8, cleartext, "\x00"); - break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + c.read_seq = std.math.add(u64, c.read_seq, 1) catch return failRead(c, error.TlsSequenceOverflow); + switch (inner_ct) { + .alert => { + if (cleartext.len != 2) return failRead(c, error.TlsDecodeError); + const alert: tls.Alert = .{ + .level = @enumFromInt(cleartext[0]), + .description = @enumFromInt(cleartext[1]), + }; + switch (alert.description) { + .close_notify => { + c.received_close_notify = true; + return 0; }, - .tls_1_2 => { - const pv = &p.tls_1_2; - const P = @TypeOf(p.*); - const message_len: u16 = record_len - P.record_iv_length - P.mac_length; - const ad = std.mem.toBytes(big(c.read_seq)) ++ - frag[in - tls.record_header_len ..][0 .. 
1 + 2] ++ - std.mem.toBytes(big(message_len)); - const record_iv = frag[in..][0..P.record_iv_length].*; - in += P.record_iv_length; - const masked_read_seq = c.read_seq & - comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); - const nonce: [P.AEAD.nonce_length]u8 = nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); - break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; - }; - const ciphertext = frag[in..][0..message_len]; - in += message_len; - const auth_tag = frag[in..][0..P.mac_length].*; - in += P.mac_length; - const out_buf = vp.peek(); - const cleartext_buf = if (message_len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch - return error.TlsBadRecordMac; - break :cleartext .{ cleartext, ct }; + .user_canceled => { + // TODO: handle server-side closures + return failRead(c, error.TlsUnexpectedMessage); }, - else => unreachable, - }, - }; - c.read_seq = try std.math.add(u64, c.read_seq, 1); - switch (inner_ct) { - .alert => { - if (cleartext.len != 2) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. 
- }, - .key_update => { - switch (c.application_cipher) { - inline else => |*p| { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); - if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ - .counter = key_log.serverCounter(), - .client_random = &key_log.client_random, - }, .{ - .SERVER_TRAFFIC_SECRET = &server_secret, - }); - pv.server_secret = server_secret; - pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const pv = &p.tls_1_3; - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); - if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ - .counter = key_log.clientCounter(), - .client_random = &key_log.client_random, - }, .{ - .CLIENT_TRAFFIC_SECRET = &client_secret, - }); - pv.client_secret = client_secret; - pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len], - cleartext, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len); - } else { - const amt = vp.put(cleartext); - if (amt < cleartext.len) { - const rest = cleartext[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); + else => { + c.alert = alert; + return failRead(c, error.TlsAlert); + }, + } + }, + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); + ct_i += 1; + const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len) return failRead(c, error.TlsBadLength); + const handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + .new_session_ticket => { + // This client implementation ignores new session tickets. 
+ }, + .key_update => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ + .counter = key_log.serverCounter(), + .client_random = &key_log.client_random, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + }); + pv.server_secret = server_secret; + pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.read_seq = 0; + + switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { + .update_requested => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |key_log| logSecrets(key_log.writer, .{ + .counter = key_log.clientCounter(), + .client_random = &key_log.client_random, + }, .{ + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); + pv.client_secret = client_secret; + pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + }, + } + c.write_seq = 0; + }, + .update_not_requested => {}, + _ => return failRead(c, error.TlsIllegalParameter), } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len); + }, + else => return failRead(c, error.TlsUnexpectedMessage), } - }, - else => return error.TlsUnexpectedMessage, - } - in = end; + ct_i = next_handshake_i; + if (ct_i >= cleartext.len) break; + } + return 0; + }, + .application_data => { + if (@intFromEnum(limit) < cleartext.len) return failRead(c, error.OutputBufferUndersize); + try w.writeAll(cleartext); + return cleartext.len; + }, + else => return failRead(c, error.TlsUnexpectedMessage), } } -fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void { - const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; - defer if (locked) key_log_file.unlock(); - key_log_file.seekFromEnd(0) catch {}; - inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.deprecatedWriter().print("{s}" ++ +fn failRead(c: *Client, err: ReadError) error{ReadFailed} { + c.read_err = err; + return error.ReadFailed; +} + +fn logSecrets(w: *Writer, context: anytype, secrets: anytype) void { + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| w.print("{s}" ++ (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {x} {x}\n", .{field.name} ++ (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ context.client_random, @@ -1520,62 +1267,6 @@ fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) voi }) catch {}; } -fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { - const saved_buf = frag[in..]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. 
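The `key_update` handling above re-derives the server (and, on `update_requested`, the client) traffic secret with the "traffic upd" label from RFC 8446 section 7.2, then resets the corresponding sequence number. A sketch of that derivation, assuming `std.crypto.tls.hkdfExpandLabel` is public with the signature the code above uses:

    const std = @import("std");

    test "TLS 1.3 key update secret derivation" {
        const Hkdf = std.crypto.kdf.hkdf.HkdfSha256;
        const old_secret = [_]u8{0x42} ** Hkdf.prk_length;
        // application_traffic_secret_N+1 =
        //     HKDF-Expand-Label(application_traffic_secret_N, "traffic upd", "", Hash.length)
        const new_secret = std.crypto.tls.hkdfExpandLabel(Hkdf, old_secret, "traffic upd", "", Hkdf.prk_length);
        try std.testing.expect(!std.mem.eql(u8, &old_secret, &new_secret));
    }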
- c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + saved_buf.len); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx..][0..saved_buf.len], saved_buf); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(saved_buf.len); - @memcpy(c.partially_read_buffer[0..saved_buf.len], saved_buf); - } - return out; -} - -/// Note that `first` usually overlaps with `c.partially_read_buffer`. -fn finishRead2(c: *Client, first: []const u8, frag1: []const u8, out: usize) usize { - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // There is cleartext at the beginning already which we need to preserve. - c.partial_ciphertext_end = @intCast(c.partial_ciphertext_idx + first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[c.partial_ciphertext_idx..][0..first.len], first); - @memcpy(c.partially_read_buffer[c.partial_ciphertext_idx + first.len ..][0..frag1.len], frag1); - } else { - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = 0; - c.partial_ciphertext_end = @intCast(first.len + frag1.len); - // TODO: eliminate this call to copyForwards - std.mem.copyForwards(u8, c.partially_read_buffer[0..first.len], first); - @memcpy(c.partially_read_buffer[first.len..][0..frag1.len], frag1); - } - return out; -} - -fn limitedOverlapCopy(frag: []u8, in: usize) void { - const first = frag[in..]; - if (first.len <= in) { - // A single, non-overlapping memcpy suffices. - @memcpy(frag[0..first.len], first); - } else { - // One memcpy call would overlap, so just do this instead. - std.mem.copyForwards(u8, frag, first); - } -} - -fn straddleByte(s1: []const u8, s2: []const u8, index: usize) u8 { - if (index < s1.len) { - return s1[index]; - } else { - return s2[index - s1.len]; - } -} - -const builtin = @import("builtin"); -const native_endian = builtin.cpu.arch.endian(); - fn big(x: anytype) @TypeOf(x) { return switch (native_endian) { .big => x, @@ -1836,81 +1527,6 @@ const CertificatePublicKey = struct { } }; -/// Abstraction for sending multiple byte buffers to a slice of iovecs. -const VecPut = struct { - iovecs: []const std.posix.iovec, - idx: usize = 0, - off: usize = 0, - total: usize = 0, - - /// Returns the amount actually put which is always equal to bytes.len - /// unless the vectors ran out of space. - fn put(vp: *VecPut, bytes: []const u8) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var bytes_i: usize = 0; - while (true) { - const v = vp.iovecs[vp.idx]; - const dest = v.base[vp.off..v.len]; - const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)]; - @memcpy(dest[0..src.len], src); - bytes_i += src.len; - vp.off += src.len; - if (vp.off >= v.len) { - vp.off = 0; - vp.idx += 1; - if (vp.idx >= vp.iovecs.len) { - vp.total += bytes_i; - return bytes_i; - } - } - if (bytes_i >= bytes.len) { - vp.total += bytes_i; - return bytes_i; - } - } - } - - /// Returns the next buffer that consecutive bytes can go into. - fn peek(vp: VecPut) []u8 { - if (vp.idx >= vp.iovecs.len) return &.{}; - const v = vp.iovecs[vp.idx]; - return v.base[vp.off..v.len]; - } - - // After writing to the result of peek(), one can call next() to - // advance the cursor. 
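The `VecPut` helper being deleted in this hunk existed to scatter decrypted cleartext into caller-supplied posix iovecs. Under the new interface the client streams cleartext into a single `*std.Io.Writer` instead, as in the `w.writeAll(cleartext)` call earlier in this diff. A sketch of the common caller using a fixed-buffer writer, assuming the 0.15-era `std.Io.Writer.fixed` and `buffered` APIs:

    const std = @import("std");

    test "fixed writer in place of iovec scatter" {
        var buf: [16]u8 = undefined;
        // Writes land directly in `buf`; overflowing it fails with
        // error.WriteFailed instead of silently truncating.
        var w: std.Io.Writer = .fixed(&buf);
        try w.writeAll("hello ");
        try w.writeAll("world");
        try std.testing.expectEqualStrings("hello world", w.buffered());
    }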
- fn next(vp: *VecPut, len: usize) void { - vp.total += len; - vp.off += len; - if (vp.off >= vp.iovecs[vp.idx].len) { - vp.off = 0; - vp.idx += 1; - } - } - - fn freeSize(vp: VecPut) usize { - if (vp.idx >= vp.iovecs.len) return 0; - var total: usize = 0; - total += vp.iovecs[vp.idx].len - vp.off; - if (vp.idx + 1 >= vp.iovecs.len) return total; - for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len; - return total; - } -}; - -/// Limit iovecs to a specific byte size. -fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { - var bytes_left: usize = len; - for (iovecs, 0..) |*iovec, vec_i| { - if (bytes_left <= iovec.len) { - iovec.len = bytes_left; - return iovecs[0 .. vec_i + 1]; - } - bytes_left -= iovec.len; - } - return iovecs; -} - /// The priority order here is chosen based on what crypto algorithms Zig has /// available in the standard library as well as what is faster. Following are /// a few data points on the relative performance of these algorithms. @@ -1954,7 +1570,3 @@ else .AES_256_GCM_SHA384, .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }); - -test { - _ = StreamInterface; -} diff --git a/lib/std/fifo.zig b/lib/std/fifo.zig deleted file mode 100644 index e18b5edb0122..000000000000 --- a/lib/std/fifo.zig +++ /dev/null @@ -1,548 +0,0 @@ -// FIFO of fixed size items -// Usually used for e.g. byte buffers - -const std = @import("std"); -const math = std.math; -const mem = std.mem; -const Allocator = mem.Allocator; -const assert = std.debug.assert; -const testing = std.testing; - -pub const LinearFifoBufferType = union(enum) { - /// The buffer is internal to the fifo; it is of the specified size. - Static: usize, - - /// The buffer is passed as a slice to the initialiser. - Slice, - - /// The buffer is managed dynamically using a `mem.Allocator`. - Dynamic, -}; - -pub fn LinearFifo( - comptime T: type, - comptime buffer_type: LinearFifoBufferType, -) type { - const autoalign = false; - - const powers_of_two = switch (buffer_type) { - .Static => std.math.isPowerOfTwo(buffer_type.Static), - .Slice => false, // Any size slice could be passed in - .Dynamic => true, // This could be configurable in future - }; - - return struct { - allocator: if (buffer_type == .Dynamic) Allocator else void, - buf: if (buffer_type == .Static) [buffer_type.Static]T else []T, - head: usize, - count: usize, - - const Self = @This(); - pub const Reader = std.io.GenericReader(*Self, error{}, readFn); - pub const Writer = std.io.GenericWriter(*Self, error{OutOfMemory}, appendWrite); - - // Type of Self argument for slice operations. 
- // If buffer is inline (Static) then we need to ensure we haven't - // returned a slice into a copy on the stack - const SliceSelfArg = if (buffer_type == .Static) *Self else Self; - - pub const init = switch (buffer_type) { - .Static => initStatic, - .Slice => initSlice, - .Dynamic => initDynamic, - }; - - fn initStatic() Self { - comptime assert(buffer_type == .Static); - return .{ - .allocator = {}, - .buf = undefined, - .head = 0, - .count = 0, - }; - } - - fn initSlice(buf: []T) Self { - comptime assert(buffer_type == .Slice); - return .{ - .allocator = {}, - .buf = buf, - .head = 0, - .count = 0, - }; - } - - fn initDynamic(allocator: Allocator) Self { - comptime assert(buffer_type == .Dynamic); - return .{ - .allocator = allocator, - .buf = &.{}, - .head = 0, - .count = 0, - }; - } - - pub fn deinit(self: Self) void { - if (buffer_type == .Dynamic) self.allocator.free(self.buf); - } - - pub fn realign(self: *Self) void { - if (self.buf.len - self.head >= self.count) { - mem.copyForwards(T, self.buf[0..self.count], self.buf[self.head..][0..self.count]); - self.head = 0; - } else { - var tmp: [4096 / 2 / @sizeOf(T)]T = undefined; - - while (self.head != 0) { - const n = @min(self.head, tmp.len); - const m = self.buf.len - n; - @memcpy(tmp[0..n], self.buf[0..n]); - mem.copyForwards(T, self.buf[0..m], self.buf[n..][0..m]); - @memcpy(self.buf[m..][0..n], tmp[0..n]); - self.head -= n; - } - } - { // set unused area to undefined - const unused = mem.sliceAsBytes(self.buf[self.count..]); - @memset(unused, undefined); - } - } - - /// Reduce allocated capacity to `size`. - pub fn shrink(self: *Self, size: usize) void { - assert(size >= self.count); - if (buffer_type == .Dynamic) { - self.realign(); - self.buf = self.allocator.realloc(self.buf, size) catch |e| switch (e) { - error.OutOfMemory => return, // no problem, capacity is still correct then. - }; - } - } - - /// Ensure that the buffer can fit at least `size` items - pub fn ensureTotalCapacity(self: *Self, size: usize) !void { - if (self.buf.len >= size) return; - if (buffer_type == .Dynamic) { - self.realign(); - const new_size = if (powers_of_two) math.ceilPowerOfTwo(usize, size) catch return error.OutOfMemory else size; - self.buf = try self.allocator.realloc(self.buf, new_size); - } else { - return error.OutOfMemory; - } - } - - /// Makes sure at least `size` items are unused - pub fn ensureUnusedCapacity(self: *Self, size: usize) error{OutOfMemory}!void { - if (self.writableLength() >= size) return; - - return try self.ensureTotalCapacity(math.add(usize, self.count, size) catch return error.OutOfMemory); - } - - /// Returns number of items currently in fifo - pub fn readableLength(self: Self) usize { - return self.count; - } - - /// Returns a writable slice from the 'read' end of the fifo - fn readableSliceMut(self: SliceSelfArg, offset: usize) []T { - if (offset > self.count) return &[_]T{}; - - var start = self.head + offset; - if (start >= self.buf.len) { - start -= self.buf.len; - return self.buf[start .. 
start + (self.count - offset)]; - } else { - const end = @min(self.head + self.count, self.buf.len); - return self.buf[start..end]; - } - } - - /// Returns a readable slice from `offset` - pub fn readableSlice(self: SliceSelfArg, offset: usize) []const T { - return self.readableSliceMut(offset); - } - - pub fn readableSliceOfLen(self: *Self, len: usize) []const T { - assert(len <= self.count); - const buf = self.readableSlice(0); - if (buf.len >= len) { - return buf[0..len]; - } else { - self.realign(); - return self.readableSlice(0)[0..len]; - } - } - - /// Discard first `count` items in the fifo - pub fn discard(self: *Self, count: usize) void { - assert(count <= self.count); - { // set old range to undefined. Note: may be wrapped around - const slice = self.readableSliceMut(0); - if (slice.len >= count) { - const unused = mem.sliceAsBytes(slice[0..count]); - @memset(unused, undefined); - } else { - const unused = mem.sliceAsBytes(slice[0..]); - @memset(unused, undefined); - const unused2 = mem.sliceAsBytes(self.readableSliceMut(slice.len)[0 .. count - slice.len]); - @memset(unused2, undefined); - } - } - if (autoalign and self.count == count) { - self.head = 0; - self.count = 0; - } else { - var head = self.head + count; - if (powers_of_two) { - // Note it is safe to do a wrapping subtract as - // bitwise & with all 1s is a noop - head &= self.buf.len -% 1; - } else { - head %= self.buf.len; - } - self.head = head; - self.count -= count; - } - } - - /// Read the next item from the fifo - pub fn readItem(self: *Self) ?T { - if (self.count == 0) return null; - - const c = self.buf[self.head]; - self.discard(1); - return c; - } - - /// Read data from the fifo into `dst`, returns number of items copied. - pub fn read(self: *Self, dst: []T) usize { - var dst_left = dst; - - while (dst_left.len > 0) { - const slice = self.readableSlice(0); - if (slice.len == 0) break; - const n = @min(slice.len, dst_left.len); - @memcpy(dst_left[0..n], slice[0..n]); - self.discard(n); - dst_left = dst_left[n..]; - } - - return dst.len - dst_left.len; - } - - /// Same as `read` except it returns an error union - /// The purpose of this function existing is to match `std.io.GenericReader` API. - fn readFn(self: *Self, dest: []u8) error{}!usize { - return self.read(dest); - } - - pub fn reader(self: *Self) Reader { - return .{ .context = self }; - } - - /// Returns number of items available in fifo - pub fn writableLength(self: Self) usize { - return self.buf.len - self.count; - } - - /// Returns the first section of writable buffer. - /// Note that this may be of length 0 - pub fn writableSlice(self: SliceSelfArg, offset: usize) []T { - if (offset > self.buf.len) return &[_]T{}; - - const tail = self.head + offset + self.count; - if (tail < self.buf.len) { - return self.buf[tail..]; - } else { - return self.buf[tail - self.buf.len ..][0 .. self.writableLength() - offset]; - } - } - - /// Returns a writable buffer of at least `size` items, allocating memory as needed. - /// Use `fifo.update` once you've written data to it. 
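`discard` above (and `writeItemAssumeCapacity` below) uses the classic ring-buffer trick: when the capacity is a power of two, wrapping an index is a bitwise AND with `len - 1` rather than a modulo, which is why `.Dynamic` fifos rounded their capacity up to a power of two. A quick check of the equivalence:

    const std = @import("std");

    test "mask equals modulo when capacity is a power of two" {
        const cap: usize = 16;
        for (0..100) |i| {
            try std.testing.expectEqual(i % cap, i & (cap - 1));
        }
    }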
- pub fn writableWithSize(self: *Self, size: usize) ![]T { - try self.ensureUnusedCapacity(size); - - // try to avoid realigning buffer - var slice = self.writableSlice(0); - if (slice.len < size) { - self.realign(); - slice = self.writableSlice(0); - } - return slice; - } - - /// Update the tail location of the buffer (usually follows use of writable/writableWithSize) - pub fn update(self: *Self, count: usize) void { - assert(self.count + count <= self.buf.len); - self.count += count; - } - - /// Appends the data in `src` to the fifo. - /// You must have ensured there is enough space. - pub fn writeAssumeCapacity(self: *Self, src: []const T) void { - assert(self.writableLength() >= src.len); - - var src_left = src; - while (src_left.len > 0) { - const writable_slice = self.writableSlice(0); - assert(writable_slice.len != 0); - const n = @min(writable_slice.len, src_left.len); - @memcpy(writable_slice[0..n], src_left[0..n]); - self.update(n); - src_left = src_left[n..]; - } - } - - /// Write a single item to the fifo - pub fn writeItem(self: *Self, item: T) !void { - try self.ensureUnusedCapacity(1); - return self.writeItemAssumeCapacity(item); - } - - pub fn writeItemAssumeCapacity(self: *Self, item: T) void { - var tail = self.head + self.count; - if (powers_of_two) { - tail &= self.buf.len - 1; - } else { - tail %= self.buf.len; - } - self.buf[tail] = item; - self.update(1); - } - - /// Appends the data in `src` to the fifo. - /// Allocates more memory as necessary - pub fn write(self: *Self, src: []const T) !void { - try self.ensureUnusedCapacity(src.len); - - return self.writeAssumeCapacity(src); - } - - /// Same as `write` except it returns the number of bytes written, which is always the same - /// as `bytes.len`. The purpose of this function existing is to match `std.io.GenericWriter` API. - fn appendWrite(self: *Self, bytes: []const u8) error{OutOfMemory}!usize { - try self.write(bytes); - return bytes.len; - } - - pub fn writer(self: *Self) Writer { - return .{ .context = self }; - } - - /// Make `count` items available before the current read location - fn rewind(self: *Self, count: usize) void { - assert(self.writableLength() >= count); - - var head = self.head + (self.buf.len - count); - if (powers_of_two) { - head &= self.buf.len - 1; - } else { - head %= self.buf.len; - } - self.head = head; - self.count += count; - } - - /// Place data back into the read stream - pub fn unget(self: *Self, src: []const T) !void { - try self.ensureUnusedCapacity(src.len); - - self.rewind(src.len); - - const slice = self.readableSliceMut(0); - if (src.len < slice.len) { - @memcpy(slice[0..src.len], src); - } else { - @memcpy(slice, src[0..slice.len]); - const slice2 = self.readableSliceMut(slice.len); - @memcpy(slice2[0 .. src.len - slice.len], src[slice.len..]); - } - } - - /// Returns the item at `offset`. - /// Asserts offset is within bounds. - pub fn peekItem(self: Self, offset: usize) T { - assert(offset < self.count); - - var index = self.head + offset; - if (powers_of_two) { - index &= self.buf.len - 1; - } else { - index %= self.buf.len; - } - return self.buf[index]; - } - - /// Pump data from a reader into a writer. - /// Stops when reader returns 0 bytes (EOF). - /// Buffer size must be set before calling; a buffer length of 0 is invalid. 
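`pump` below was `LinearFifo`'s bridge from a reader to a writer, looping fill and drain through the ring buffer. Its replacement in the new I/O model is direct streaming between a `std.Io.Reader` and a `std.Io.Writer`; a sketch assuming the 0.15-era `fixed` constructors and `streamRemaining`:

    const std = @import("std");

    test "streaming replaces LinearFifo.pump" {
        var out_buf: [32]u8 = undefined;
        var in: std.Io.Reader = .fixed("pumped without a fifo");
        var out: std.Io.Writer = .fixed(&out_buf);
        // Moves everything from `in` to `out`, returning the byte count.
        const n = try in.streamRemaining(&out);
        try std.testing.expectEqual(21, n);
        try std.testing.expectEqualStrings("pumped without a fifo", out.buffered());
    }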
- pub fn pump(self: *Self, src_reader: anytype, dest_writer: anytype) !void { - assert(self.buf.len > 0); - while (true) { - if (self.writableLength() > 0) { - const n = try src_reader.read(self.writableSlice(0)); - if (n == 0) break; // EOF - self.update(n); - } - self.discard(try dest_writer.write(self.readableSlice(0))); - } - // flush remaining data - while (self.readableLength() > 0) { - self.discard(try dest_writer.write(self.readableSlice(0))); - } - } - - pub fn toOwnedSlice(self: *Self) Allocator.Error![]T { - if (self.head != 0) self.realign(); - assert(self.head == 0); - assert(self.count <= self.buf.len); - const allocator = self.allocator; - if (allocator.resize(self.buf, self.count)) { - const result = self.buf[0..self.count]; - self.* = Self.init(allocator); - return result; - } - const new_memory = try allocator.dupe(T, self.buf[0..self.count]); - allocator.free(self.buf); - self.* = Self.init(allocator); - return new_memory; - } - }; -} - -test "LinearFifo(u8, .Dynamic) discard(0) from empty buffer should not error on overflow" { - var fifo = LinearFifo(u8, .Dynamic).init(testing.allocator); - defer fifo.deinit(); - - // If overflow is not explicitly allowed this will crash in debug / safe mode - fifo.discard(0); -} - -test "LinearFifo(u8, .Dynamic)" { - var fifo = LinearFifo(u8, .Dynamic).init(testing.allocator); - defer fifo.deinit(); - - try fifo.write("HELLO"); - try testing.expectEqual(@as(usize, 5), fifo.readableLength()); - try testing.expectEqualSlices(u8, "HELLO", fifo.readableSlice(0)); - - { - var i: usize = 0; - while (i < 5) : (i += 1) { - try fifo.write(&[_]u8{fifo.peekItem(i)}); - } - try testing.expectEqual(@as(usize, 10), fifo.readableLength()); - try testing.expectEqualSlices(u8, "HELLOHELLO", fifo.readableSlice(0)); - } - - { - try testing.expectEqual(@as(u8, 'H'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'E'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'L'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'L'), fifo.readItem().?); - try testing.expectEqual(@as(u8, 'O'), fifo.readItem().?); - } - try testing.expectEqual(@as(usize, 5), fifo.readableLength()); - - { // Writes that wrap around - try testing.expectEqual(@as(usize, 11), fifo.writableLength()); - try testing.expectEqual(@as(usize, 6), fifo.writableSlice(0).len); - fifo.writeAssumeCapacity("6 FifoType.init(), - .Slice => FifoType.init(buf[0..]), - .Dynamic => FifoType.init(testing.allocator), - }; - defer fifo.deinit(); - - try fifo.write(&[_]T{ 0, 1, 1, 0, 1 }); - try testing.expectEqual(@as(usize, 5), fifo.readableLength()); - - { - try testing.expectEqual(@as(T, 0), fifo.readItem().?); - try testing.expectEqual(@as(T, 1), fifo.readItem().?); - try testing.expectEqual(@as(T, 1), fifo.readItem().?); - try testing.expectEqual(@as(T, 0), fifo.readItem().?); - try testing.expectEqual(@as(T, 1), fifo.readItem().?); - try testing.expectEqual(@as(usize, 0), fifo.readableLength()); - } - - { - try fifo.writeItem(1); - try fifo.writeItem(1); - try fifo.writeItem(1); - try testing.expectEqual(@as(usize, 3), fifo.readableLength()); - } - - { - var readBuf: [3]T = undefined; - const n = fifo.read(&readBuf); - try testing.expectEqual(@as(usize, 3), n); // NOTE: It should be the number of items. 
- } - } - } -} diff --git a/lib/std/fs/File.zig b/lib/std/fs/File.zig index 2791642ac714..7ad71ad274cc 100644 --- a/lib/std/fs/File.zig +++ b/lib/std/fs/File.zig @@ -1351,8 +1351,7 @@ pub const Reader = struct { } r.pos += n; if (n > data_size) { - io_reader.seek = 0; - io_reader.end = n - data_size; + io_reader.end += n - data_size; return data_size; } return n; @@ -1386,8 +1385,7 @@ pub const Reader = struct { } r.pos += n; if (n > data_size) { - io_reader.seek = 0; - io_reader.end = n - data_size; + io_reader.end += n - data_size; return data_size; } return n; diff --git a/lib/std/http.zig b/lib/std/http.zig index 5bf12a187622..6822af88c988 100644 --- a/lib/std/http.zig +++ b/lib/std/http.zig @@ -1,14 +1,14 @@ const builtin = @import("builtin"); const std = @import("std.zig"); const assert = std.debug.assert; +const Writer = std.Io.Writer; +const File = std.fs.File; pub const Client = @import("http/Client.zig"); pub const Server = @import("http/Server.zig"); -pub const protocol = @import("http/protocol.zig"); pub const HeadParser = @import("http/HeadParser.zig"); pub const ChunkParser = @import("http/ChunkParser.zig"); pub const HeaderIterator = @import("http/HeaderIterator.zig"); -pub const WebSocket = @import("http/WebSocket.zig"); pub const Version = enum { @"HTTP/1.0", @@ -20,51 +20,32 @@ pub const Version = enum { /// https://datatracker.ietf.org/doc/html/rfc7231#section-4 Initial definition /// /// https://datatracker.ietf.org/doc/html/rfc5789#section-2 PATCH -pub const Method = enum(u64) { - GET = parse("GET"), - HEAD = parse("HEAD"), - POST = parse("POST"), - PUT = parse("PUT"), - DELETE = parse("DELETE"), - CONNECT = parse("CONNECT"), - OPTIONS = parse("OPTIONS"), - TRACE = parse("TRACE"), - PATCH = parse("PATCH"), - - _, - - /// Converts `s` into a type that may be used as a `Method` field. - /// Asserts that `s` is 24 or fewer bytes. 
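The removed `Method.parse` below packed a method name's bytes straight into the `u64` enum tag, so matching a parsed method against `.GET` was a single integer compare; the cost was that the enum was non-exhaustive and every switch needed an `else` arm, which the new plain enum eliminates. (The "24 or fewer bytes" in the doc comment above looks stale for this `u64` version, which copies at most 8 bytes.) The packing trick, reduced to a runnable illustration:

    const std = @import("std");

    // Illustration of the removed representation: the first 8 bytes of the
    // method name become the integer value, zero-padded at the end.
    fn pack(s: []const u8) u64 {
        var x: u64 = 0;
        const len = @min(s.len, @sizeOf(u64));
        @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]);
        return x;
    }

    test "method names pack to distinct integers" {
        try std.testing.expect(pack("GET") != pack("PUT"));
        try std.testing.expectEqual(pack("GET"), pack("GET"));
    }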
- pub fn parse(s: []const u8) u64 { - var x: u64 = 0; - const len = @min(s.len, @sizeOf(@TypeOf(x))); - @memcpy(std.mem.asBytes(&x)[0..len], s[0..len]); - return x; - } - - pub fn format(self: Method, w: *std.io.Writer) std.io.Writer.Error!void { - const bytes: []const u8 = @ptrCast(&@intFromEnum(self)); - const str = std.mem.sliceTo(bytes, 0); - try w.writeAll(str); - } +pub const Method = enum { + GET, + HEAD, + POST, + PUT, + DELETE, + CONNECT, + OPTIONS, + TRACE, + PATCH, /// Returns true if a request of this method is allowed to have a body /// Actual behavior from servers may vary and should still be checked - pub fn requestHasBody(self: Method) bool { - return switch (self) { + pub fn requestHasBody(m: Method) bool { + return switch (m) { .POST, .PUT, .PATCH => true, .GET, .HEAD, .DELETE, .CONNECT, .OPTIONS, .TRACE => false, - else => true, }; } /// Returns true if a response to this method is allowed to have a body /// Actual behavior from clients may vary and should still be checked - pub fn responseHasBody(self: Method) bool { - return switch (self) { + pub fn responseHasBody(m: Method) bool { + return switch (m) { .GET, .POST, .DELETE, .CONNECT, .OPTIONS, .PATCH => true, .HEAD, .PUT, .TRACE => false, - else => true, }; } @@ -73,11 +54,10 @@ pub const Method = enum(u64) { /// https://developer.mozilla.org/en-US/docs/Glossary/Safe/HTTP /// /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.1 - pub fn safe(self: Method) bool { - return switch (self) { + pub fn safe(m: Method) bool { + return switch (m) { .GET, .HEAD, .OPTIONS, .TRACE => true, .POST, .PUT, .DELETE, .CONNECT, .PATCH => false, - else => false, }; } @@ -88,11 +68,10 @@ pub const Method = enum(u64) { /// https://developer.mozilla.org/en-US/docs/Glossary/Idempotent /// /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.2 - pub fn idempotent(self: Method) bool { - return switch (self) { + pub fn idempotent(m: Method) bool { + return switch (m) { .GET, .HEAD, .PUT, .DELETE, .OPTIONS, .TRACE => true, .CONNECT, .POST, .PATCH => false, - else => false, }; } @@ -102,11 +81,10 @@ pub const Method = enum(u64) { /// https://developer.mozilla.org/en-US/docs/Glossary/cacheable /// /// https://datatracker.ietf.org/doc/html/rfc7231#section-4.2.3 - pub fn cacheable(self: Method) bool { - return switch (self) { + pub fn cacheable(m: Method) bool { + return switch (m) { .GET, .HEAD => true, .POST, .PUT, .DELETE, .CONNECT, .OPTIONS, .TRACE, .PATCH => false, - else => false, }; } }; @@ -296,13 +274,24 @@ pub const TransferEncoding = enum { }; pub const ContentEncoding = enum { - identity, - compress, - @"x-compress", - deflate, - gzip, - @"x-gzip", zstd, + gzip, + deflate, + compress, + identity, + + pub fn fromString(s: []const u8) ?ContentEncoding { + const map = std.StaticStringMap(ContentEncoding).initComptime(.{ + .{ "zstd", .zstd }, + .{ "gzip", .gzip }, + .{ "x-gzip", .gzip }, + .{ "deflate", .deflate }, + .{ "compress", .compress }, + .{ "x-compress", .compress }, + .{ "identity", .identity }, + }); + return map.get(s); + } }; pub const Connection = enum { @@ -315,15 +304,790 @@ pub const Header = struct { value: []const u8, }; +pub const Reader = struct { + in: *std.Io.Reader, + /// This is preallocated memory that might be used by `bodyReader`. That + /// function might return a pointer to this field, or a different + /// `*std.Io.Reader`. Advisable to not access this field directly. 
+ interface: std.Io.Reader, + /// Keeps track of whether the stream is ready to accept a new request, + /// making invalid API usage cause assertion failures rather than HTTP + /// protocol violations. + state: State, + /// HTTP trailer bytes. These are at the end of a transfer-encoding: + /// chunked message. This data is available only after calling one of the + /// "end" functions and points to data inside the buffer of `in`, and is + /// therefore invalidated on the next call to `receiveHead`, or any other + /// read from `in`. + trailers: []const u8 = &.{}, + body_err: ?BodyError = null, + + pub const RemainingChunkLen = enum(u64) { + head = 0, + n = 1, + rn = 2, + _, + + pub fn init(integer: u64) RemainingChunkLen { + return @enumFromInt(integer); + } + + pub fn int(rcl: RemainingChunkLen) u64 { + return @intFromEnum(rcl); + } + }; + + pub const State = union(enum) { + /// The stream is available to be used for the first time, or reused. + ready, + received_head, + /// The stream goes until the connection is closed. + body_none, + body_remaining_content_length: u64, + body_remaining_chunk_len: RemainingChunkLen, + /// The stream would be eligible for another HTTP request, however the + /// client and server did not negotiate a persistent connection. + closing, + }; + + pub const BodyError = error{ + HttpChunkInvalid, + HttpChunkTruncated, + HttpHeadersOversize, + }; + + pub const HeadError = error{ + /// Too many bytes of HTTP headers. + /// + /// The HTTP specification suggests to respond with a 431 status code + /// before closing the connection. + HttpHeadersOversize, + /// Partial HTTP request was received but the connection was closed + /// before fully receiving the headers. + HttpRequestTruncated, + /// The client sent 0 bytes of headers before closing the stream. This + /// happens when a keep-alive connection is finally closed. + HttpConnectionClosing, + /// Transitive error occurred reading from `in`. + ReadFailed, + }; + + /// Buffers the entire head inside `in`. + /// + /// The resulting memory is invalidated by any subsequent consumption of + /// the input stream. + pub fn receiveHead(reader: *Reader) HeadError![]const u8 { + reader.trailers = &.{}; + const in = reader.in; + var hp: HeadParser = .{}; + var head_len: usize = 0; + while (true) { + if (in.buffer.len - head_len == 0) return error.HttpHeadersOversize; + const remaining = in.buffered()[head_len..]; + if (remaining.len == 0) { + in.fillMore() catch |err| switch (err) { + error.EndOfStream => switch (head_len) { + 0 => return error.HttpConnectionClosing, + else => return error.HttpRequestTruncated, + }, + error.ReadFailed => return error.ReadFailed, + }; + continue; + } + head_len += hp.feed(remaining); + if (hp.state == .finished) { + reader.state = .received_head; + const head_buffer = in.buffered()[0..head_len]; + in.toss(head_len); + return head_buffer; + } + } + } + + /// If compressed body has been negotiated this will return compressed bytes. + /// + /// Asserts only called once and after `receiveHead`. 
+ /// + /// See also: + /// * `interfaceDecompressing` + pub fn bodyReader( + reader: *Reader, + buffer: []u8, + transfer_encoding: TransferEncoding, + content_length: ?u64, + ) *std.Io.Reader { + assert(reader.state == .received_head); + switch (transfer_encoding) { + .chunked => { + reader.state = .{ .body_remaining_chunk_len = .head }; + reader.interface = .{ + .buffer = buffer, + .seek = 0, + .end = 0, + .vtable = &.{ + .stream = chunkedStream, + .discard = chunkedDiscard, + }, + }; + return &reader.interface; + }, + .none => { + if (content_length) |len| { + reader.state = .{ .body_remaining_content_length = len }; + reader.interface = .{ + .buffer = buffer, + .seek = 0, + .end = 0, + .vtable = &.{ + .stream = contentLengthStream, + .discard = contentLengthDiscard, + }, + }; + return &reader.interface; + } else { + reader.state = .body_none; + return reader.in; + } + }, + } + } + + /// If compressed body has been negotiated this will return decompressed bytes. + /// + /// Asserts only called once and after `receiveHead`. + /// + /// See also: + /// * `interface` + pub fn bodyReaderDecompressing( + reader: *Reader, + transfer_encoding: TransferEncoding, + content_length: ?u64, + content_encoding: ContentEncoding, + decompressor: *Decompressor, + decompression_buffer: []u8, + ) *std.Io.Reader { + if (transfer_encoding == .none and content_length == null) { + assert(reader.state == .received_head); + reader.state = .body_none; + switch (content_encoding) { + .identity => { + return reader.in; + }, + .deflate => { + decompressor.* = .{ .flate = .init(reader.in, .zlib, decompression_buffer) }; + return &decompressor.flate.reader; + }, + .gzip => { + decompressor.* = .{ .flate = .init(reader.in, .gzip, decompression_buffer) }; + return &decompressor.flate.reader; + }, + .zstd => { + decompressor.* = .{ .zstd = .init(reader.in, decompression_buffer, .{ .verify_checksum = false }) }; + return &decompressor.zstd.reader; + }, + .compress => unreachable, + } + } + const transfer_reader = bodyReader(reader, &.{}, transfer_encoding, content_length); + return decompressor.init(transfer_reader, decompression_buffer, content_encoding); + } + + fn contentLengthStream( + io_r: *std.Io.Reader, + w: *Writer, + limit: std.Io.Limit, + ) std.Io.Reader.StreamError!usize { + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); + const remaining_content_length = &reader.state.body_remaining_content_length; + const remaining = remaining_content_length.*; + if (remaining == 0) { + reader.state = .ready; + return error.EndOfStream; + } + const n = try reader.in.stream(w, limit.min(.limited64(remaining))); + remaining_content_length.* = remaining - n; + return n; + } + + fn contentLengthDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); + const remaining_content_length = &reader.state.body_remaining_content_length; + const remaining = remaining_content_length.*; + if (remaining == 0) { + reader.state = .ready; + return error.EndOfStream; + } + const n = try reader.in.discard(limit.min(.limited64(remaining))); + remaining_content_length.* = remaining - n; + return n; + } + + fn chunkedStream(io_r: *std.Io.Reader, w: *Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize { + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); + const chunk_len_ptr = switch (reader.state) { + .ready => return error.EndOfStream, + .body_remaining_chunk_len => |*x| x, + else => unreachable, 
+ }; + return chunkedReadEndless(reader, w, limit, chunk_len_ptr) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.WriteFailed => return error.WriteFailed, + error.EndOfStream => { + reader.body_err = error.HttpChunkTruncated; + return error.ReadFailed; + }, + else => |e| { + reader.body_err = e; + return error.ReadFailed; + }, + }; + } + + fn chunkedReadEndless( + reader: *Reader, + w: *Writer, + limit: std.Io.Limit, + chunk_len_ptr: *RemainingChunkLen, + ) (BodyError || std.Io.Reader.StreamError)!usize { + const in = reader.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: ChunkParser = .init; + while (true) { + const i = cp.feed(in.buffered()); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + in.toss(i); + break; + }, + else => { + in.toss(i); + try in.fillMore(); + continue; + }, + } + } + if (cp.chunk_len == 0) return parseTrailers(reader, 0); + const n = try in.stream(w, limit.min(.limited64(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid; + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid; + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.stream(w, limit.min(.limited64(@intFromEnum(remaining_chunk_len) - 2))); + chunk_len_ptr.* = .init(@intFromEnum(remaining_chunk_len) - n); + return n; + }, + } + } + + fn chunkedDiscard(io_r: *std.Io.Reader, limit: std.Io.Limit) std.Io.Reader.Error!usize { + const reader: *Reader = @alignCast(@fieldParentPtr("interface", io_r)); + const chunk_len_ptr = switch (reader.state) { + .ready => return error.EndOfStream, + .body_remaining_chunk_len => |*x| x, + else => unreachable, + }; + return chunkedDiscardEndless(reader, limit, chunk_len_ptr) catch |err| switch (err) { + error.ReadFailed => return error.ReadFailed, + error.EndOfStream => { + reader.body_err = error.HttpChunkTruncated; + return error.ReadFailed; + }, + else => |e| { + reader.body_err = e; + return error.ReadFailed; + }, + }; + } + + fn chunkedDiscardEndless( + reader: *Reader, + limit: std.Io.Limit, + chunk_len_ptr: *RemainingChunkLen, + ) (BodyError || std.Io.Reader.Error)!usize { + const in = reader.in; + len: switch (chunk_len_ptr.*) { + .head => { + var cp: ChunkParser = .init; + while (true) { + const i = cp.feed(in.buffered()); + switch (cp.state) { + .invalid => return error.HttpChunkInvalid, + .data => { + in.toss(i); + break; + }, + else => { + in.toss(i); + try in.fillMore(); + continue; + }, + } + } + if (cp.chunk_len == 0) return parseTrailers(reader, 0); + const n = try in.discard(limit.min(.limited64(cp.chunk_len))); + chunk_len_ptr.* = .init(cp.chunk_len + 2 - n); + return n; + }, + .n => { + if ((try in.peekByte()) != '\n') return error.HttpChunkInvalid; + in.toss(1); + continue :len .head; + }, + .rn => { + const rn = try in.peekArray(2); + if (rn[0] != '\r' or rn[1] != '\n') return error.HttpChunkInvalid; + in.toss(2); + continue :len .head; + }, + else => |remaining_chunk_len| { + const n = try in.discard(limit.min(.limited64(remaining_chunk_len.int() - 2))); + chunk_len_ptr.* = .init(remaining_chunk_len.int() - n); + return n; + }, + } + } + + /// Called when next bytes in the stream are trailers, or "\r\n" to indicate + /// end of chunked body. 
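Throughout `chunkedReadEndless` and `chunkedDiscardEndless` above, `RemainingChunkLen` overloads a single `u64`: 0 means a chunk header is expected next, 1 means a lone `\n` remains, 2 means the `\r\n` chunk terminator remains, and anything larger is the remaining payload plus 2 for that terminator, which is why a freshly parsed chunk is recorded as `cp.chunk_len + 2 - n`. The arithmetic, worked through:

    const std = @import("std");

    test "remaining chunk length bookkeeping" {
        const chunk_len: u64 = 5; // header announced 5 payload bytes
        // Streaming all 5 payload bytes leaves 5 + 2 - 5 == 2: the .rn state,
        // i.e. only the "\r\n" terminator is still owed.
        try std.testing.expectEqual(@as(u64, 2), chunk_len + 2 - 5);
        // A partial read of 3 bytes leaves 4: two payload bytes plus "\r\n".
        try std.testing.expectEqual(@as(u64, 4), chunk_len + 2 - 3);
    }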
+    fn parseTrailers(reader: *Reader, amt_read: usize) (BodyError || std.Io.Reader.Error)!usize {
+        const in = reader.in;
+        const rn = try in.peekArray(2);
+        if (rn[0] == '\r' and rn[1] == '\n') {
+            in.toss(2);
+            reader.state = .ready;
+            assert(reader.trailers.len == 0);
+            return amt_read;
+        }
+        var hp: HeadParser = .{ .state = .seen_rn };
+        var trailers_len: usize = 2;
+        while (true) {
+            if (in.buffer.len - trailers_len == 0) return error.HttpHeadersOversize;
+            const remaining = in.buffered()[trailers_len..];
+            if (remaining.len == 0) {
+                try in.fillMore();
+                continue;
+            }
+            trailers_len += hp.feed(remaining);
+            if (hp.state == .finished) {
+                reader.state = .ready;
+                reader.trailers = in.buffered()[0..trailers_len];
+                in.toss(trailers_len);
+                return amt_read;
+            }
+        }
+    }
+};
+
+pub const Decompressor = union(enum) {
+    flate: std.compress.flate.Decompress,
+    zstd: std.compress.zstd.Decompress,
+    none: *std.Io.Reader,
+
+    pub fn init(
+        decompressor: *Decompressor,
+        transfer_reader: *std.Io.Reader,
+        buffer: []u8,
+        content_encoding: ContentEncoding,
+    ) *std.Io.Reader {
+        switch (content_encoding) {
+            .identity => {
+                decompressor.* = .{ .none = transfer_reader };
+                return transfer_reader;
+            },
+            .deflate => {
+                decompressor.* = .{ .flate = .init(transfer_reader, .zlib, buffer) };
+                return &decompressor.flate.reader;
+            },
+            .gzip => {
+                decompressor.* = .{ .flate = .init(transfer_reader, .gzip, buffer) };
+                return &decompressor.flate.reader;
+            },
+            .zstd => {
+                decompressor.* = .{ .zstd = .init(transfer_reader, buffer, .{ .verify_checksum = false }) };
+                return &decompressor.zstd.reader;
+            },
+            .compress => unreachable,
+        }
+    }
+};
+
+/// Request or response body.
+pub const BodyWriter = struct {
+    /// Until the lifetime of `BodyWriter` ends, it is illegal to modify the
+    /// state of this other than via methods of `BodyWriter`.
+    http_protocol_output: *Writer,
+    state: State,
+    writer: Writer,
+
+    pub const Error = Writer.Error;
+
+    /// How many zeroes to reserve for hex-encoded chunk length.
+    const chunk_len_digits = 8;
+    const max_chunk_len: usize = std.math.pow(u64, 16, chunk_len_digits) - 1;
+    const chunk_header_template = ("0" ** chunk_len_digits) ++ "\r\n";
+
+    comptime {
+        assert(max_chunk_len == std.math.maxInt(u32));
+    }
+
+    pub const State = union(enum) {
+        /// End of connection signals the end of the stream.
+        none,
+        /// As a debugging utility, counts down to zero as bytes are written.
+        content_length: u64,
+        /// Each chunk is wrapped in a header and trailer.
+        chunked: Chunked,
+        /// Cleanly finished stream; connection can be reused.
+        end,
+
+        pub const Chunked = union(enum) {
+            /// Index to the start of the hex-encoded chunk length in the chunk
+            /// header within the buffer of `BodyWriter.http_protocol_output`.
+            /// Buffered chunk data starts here plus length of `chunk_header_template`.
+            offset: usize,
+            /// We are in the middle of a chunk and this is how many bytes are
+            /// left until the next header. This includes +2 for "\r\n", and
+            /// is zero for the beginning of the stream.
+            chunk_len: usize,
+
+            pub const init: Chunked = .{ .chunk_len = 0 };
+        };
+    };
+
+    pub fn isEliding(w: *const BodyWriter) bool {
+        return w.writer.vtable.drain == elidingDrain;
+    }
+
+    /// Sends all buffered data across `BodyWriter.http_protocol_output`.
+ pub fn flush(w: *BodyWriter) Error!void { + const out = w.http_protocol_output; + switch (w.state) { + .end, .none, .content_length => return out.flush(), + .chunked => |*chunked| switch (chunked.*) { + .offset => |offset| { + const chunk_len = out.end - offset - chunk_header_template.len; + if (chunk_len > 0) { + writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len); + chunked.* = .{ .chunk_len = 2 }; + } else { + out.end = offset; + chunked.* = .{ .chunk_len = 0 }; + } + try out.flush(); + }, + .chunk_len => return out.flush(), + }, + } + } + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header, then flushes. + /// + /// When using transfer-encoding: chunked, writes the end-of-stream message + /// with empty trailers, then flushes the stream to the system. Asserts any + /// started chunk has been completely finished. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endUnflushed` + /// * `endChunked` + pub fn end(w: *BodyWriter) Error!void { + try endUnflushed(w); + try w.http_protocol_output.flush(); + } + + /// When using content-length, asserts that the amount of data sent matches + /// the value sent in the header. + /// + /// Otherwise, transfer-encoding: chunked is being used, and it writes the + /// end-of-stream message with empty trailers. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `end` + /// * `endChunked` + pub fn endUnflushed(w: *BodyWriter) Error!void { + switch (w.state) { + .end => unreachable, + .content_length => |len| { + assert(len == 0); // Trips when end() called before all bytes written. + w.state = .end; + }, + .none => {}, + .chunked => return endChunkedUnflushed(w, .{}), + } + } + + pub const EndChunkedOptions = struct { + trailers: []const Header = &.{}, + }; + + /// Writes the end-of-stream message and any optional trailers, flushing + /// the underlying stream. + /// + /// Asserts that the BodyWriter is using transfer-encoding: chunked. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endChunkedUnflushed` + /// * `end` + pub fn endChunked(w: *BodyWriter, options: EndChunkedOptions) Error!void { + try endChunkedUnflushed(w, options); + try w.http_protocol_output.flush(); + } + + /// Writes the end-of-stream message and any optional trailers. + /// + /// Does not flush. + /// + /// Asserts that the BodyWriter is using transfer-encoding: chunked. + /// + /// Respects the value of `isEliding` to omit all data after the headers. + /// + /// See also: + /// * `endChunked` + /// * `endUnflushed` + /// * `end` + pub fn endChunkedUnflushed(w: *BodyWriter, options: EndChunkedOptions) Error!void { + const chunked = &w.state.chunked; + if (w.isEliding()) { + w.state = .end; + return; + } + const bw = w.http_protocol_output; + switch (chunked.*) { + .offset => |offset| { + const chunk_len = bw.end - offset - chunk_header_template.len; + writeHex(bw.buffer[offset..][0..chunk_len_digits], chunk_len); + try bw.writeAll("\r\n"); + }, + .chunk_len => |chunk_len| switch (chunk_len) { + 0 => {}, + 1 => try bw.writeByte('\n'), + 2 => try bw.writeAll("\r\n"), + else => unreachable, // An earlier write call indicated more data would follow. 
+ }, + } + try bw.writeAll("0\r\n"); + for (options.trailers) |trailer| { + try bw.writeAll(trailer.name); + try bw.writeAll(": "); + try bw.writeAll(trailer.value); + try bw.writeAll("\r\n"); + } + try bw.writeAll("\r\n"); + w.state = .end; + } + + pub fn contentLengthDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.writeSplatHeader(w.buffered(), data, splat); + bw.state.content_length -= n; + return w.consume(n); + } + + pub fn noneDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.writeSplatHeader(w.buffered(), data, splat); + return w.consume(n); + } + + pub fn elidingDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + const slice = data[0 .. data.len - 1]; + const pattern = data[slice.len]; + var written: usize = pattern.len * splat; + for (slice) |bytes| written += bytes.len; + switch (bw.state) { + .content_length => |*len| len.* -= written + w.end, + else => {}, + } + w.end = 0; + return written; + } + + pub fn elidingSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + if (File.Handle == void) return error.Unimplemented; + if (builtin.zig_backend == .stage2_aarch64) return error.Unimplemented; + switch (bw.state) { + .content_length => |*len| len.* -= w.end, + else => {}, + } + w.end = 0; + if (limit == .nothing) return 0; + if (file_reader.getSize()) |size| { + const n = limit.minInt64(size - file_reader.pos); + if (n == 0) return error.EndOfStream; + file_reader.seekBy(@intCast(n)) catch return error.Unimplemented; + switch (bw.state) { + .content_length => |*len| len.* -= n, + else => {}, + } + return n; + } else |_| { + // Error is observable on `file_reader` instance, and it is better to + // treat the file as a pipe. + return error.Unimplemented; + } + } + + /// Returns `null` if size cannot be computed without making any syscalls. + pub fn noneSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + return w.consume(n); + } + + pub fn contentLengthSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + bw.state.content_length -= n; + return w.consume(n); + } + + pub fn chunkedSendFile(w: *Writer, file_reader: *File.Reader, limit: std.Io.Limit) Writer.FileError!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + assert(!bw.isEliding()); + const data_len = Writer.countSendFileLowerBound(w.end, file_reader, limit) orelse { + // If the file size is unknown, we cannot lower to a `sendFile` since we would + // have to flush the chunk header before knowing the chunk length. 
+ return error.Unimplemented; + }; + const out = bw.http_protocol_output; + const chunked = &bw.state.chunked; + state: switch (chunked.*) { + .offset => |off| { + // TODO: is it better perf to read small files into the buffer? + const buffered_len = out.end - off - chunk_header_template.len; + const chunk_len = data_len + buffered_len; + writeHex(out.buffer[off..][0..chunk_len_digits], chunk_len); + const n = try out.sendFileHeader(w.buffered(), file_reader, limit); + chunked.* = .{ .chunk_len = data_len + 2 - n }; + return w.consume(n); + }, + .chunk_len => |chunk_len| l: switch (chunk_len) { + 0 => { + const off = out.end; + const header_buf = try out.writableArray(chunk_header_template.len); + @memcpy(header_buf, chunk_header_template); + chunked.* = .{ .offset = off }; + continue :state .{ .offset = off }; + }, + 1 => { + try out.writeByte('\n'); + chunked.chunk_len = 0; + continue :l 0; + }, + 2 => { + try out.writeByte('\r'); + chunked.chunk_len = 1; + continue :l 1; + }, + else => { + const new_limit = limit.min(.limited(chunk_len - 2)); + const n = try out.sendFileHeader(w.buffered(), file_reader, new_limit); + chunked.chunk_len = chunk_len - n; + return w.consume(n); + }, + }, + } + } + + pub fn chunkedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize { + const bw: *BodyWriter = @alignCast(@fieldParentPtr("writer", w)); + assert(!bw.isEliding()); + const out = bw.http_protocol_output; + const data_len = w.end + Writer.countSplat(data, splat); + const chunked = &bw.state.chunked; + state: switch (chunked.*) { + .offset => |offset| { + if (out.unusedCapacityLen() >= data_len) { + return w.consume(out.writeSplatHeader(w.buffered(), data, splat) catch unreachable); + } + const buffered_len = out.end - offset - chunk_header_template.len; + const chunk_len = data_len + buffered_len; + writeHex(out.buffer[offset..][0..chunk_len_digits], chunk_len); + const n = try out.writeSplatHeader(w.buffered(), data, splat); + chunked.* = .{ .chunk_len = data_len + 2 - n }; + return w.consume(n); + }, + .chunk_len => |chunk_len| l: switch (chunk_len) { + 0 => { + const offset = out.end; + const header_buf = try out.writableArray(chunk_header_template.len); + @memcpy(header_buf, chunk_header_template); + chunked.* = .{ .offset = offset }; + continue :state .{ .offset = offset }; + }, + 1 => { + try out.writeByte('\n'); + chunked.chunk_len = 0; + continue :l 0; + }, + 2 => { + try out.writeByte('\r'); + chunked.chunk_len = 1; + continue :l 1; + }, + else => { + const n = try out.writeSplatHeaderLimit(w.buffered(), data, splat, .limited(chunk_len - 2)); + chunked.chunk_len = chunk_len - n; + return w.consume(n); + }, + }, + } + } + + /// Writes an integer as base 16 to `buf`, right-aligned, assuming the + /// buffer has already been filled with zeroes. 
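`chunkedDrain` and `chunkedSendFile` above reserve `chunk_header_template`, eight ASCII zeroes plus CRLF, in the output buffer before a chunk's size is known, then patch the hex length in afterwards; `writeHex` below fills digits right to left into that zero run, so a chunk of 0x1234 bytes ends up framed as `00001234\r\n`. A standalone copy of the helper to show the effect:

    const std = @import("std");

    // Local copy of the helper, for illustration only: writes `x` in base 16,
    // right-aligned, into a buffer that is already all '0'.
    fn writeHex(buf: []u8, x: usize) void {
        std.debug.assert(std.mem.allEqual(u8, buf, '0'));
        const base = 16;
        var index: usize = buf.len;
        var a = x;
        while (a > 0) {
            const digit = a % base;
            index -= 1;
            buf[index] = std.fmt.digitToChar(@intCast(digit), .lower);
            a /= base;
        }
    }

    test "patching a reserved chunk header" {
        var header = "00000000\r\n".*;
        writeHex(header[0..8], 0x1234);
        try std.testing.expectEqualStrings("00001234\r\n", &header);
    }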
+ fn writeHex(buf: []u8, x: usize) void { + assert(std.mem.allEqual(u8, buf, '0')); + const base = 16; + var index: usize = buf.len; + var a = x; + while (a > 0) { + const digit = a % base; + index -= 1; + buf[index] = std.fmt.digitToChar(@intCast(digit), .lower); + a /= base; + } + } +}; + test { + _ = Server; + _ = Status; + _ = Method; + _ = ChunkParser; + _ = HeadParser; + if (builtin.os.tag != .wasi) { _ = Client; - _ = Method; - _ = Server; - _ = Status; - _ = HeadParser; - _ = ChunkParser; - _ = WebSocket; _ = @import("http/test.zig"); } } diff --git a/lib/std/http/ChunkParser.zig b/lib/std/http/ChunkParser.zig index adcdc74bc7be..7c628ec32777 100644 --- a/lib/std/http/ChunkParser.zig +++ b/lib/std/http/ChunkParser.zig @@ -1,5 +1,8 @@ //! Parser for transfer-encoding: chunked. +const ChunkParser = @This(); +const std = @import("std"); + state: State, chunk_len: u64, @@ -97,9 +100,6 @@ pub fn feed(p: *ChunkParser, bytes: []const u8) usize { return bytes.len; } -const ChunkParser = @This(); -const std = @import("std"); - test feed { const testing = std.testing; diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 20f6018e4534..37022b4d0b22 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -13,9 +13,10 @@ const net = std.net; const Uri = std.Uri; const Allocator = mem.Allocator; const assert = std.debug.assert; +const Writer = std.io.Writer; +const Reader = std.io.Reader; const Client = @This(); -const proto = @import("protocol.zig"); pub const disable_tls = std.options.http_disable_tls; @@ -24,6 +25,12 @@ allocator: Allocator, ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{}, ca_bundle_mutex: std.Thread.Mutex = .{}, +/// Used both for the reader and writer buffers. +tls_buffer_size: if (disable_tls) u0 else usize = if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, +/// If non-null, ssl secrets are logged to a stream. Creating such a stream +/// allows other processes with access to that stream to decrypt all +/// traffic over connections created with this `Client`. +ssl_key_log: ?*std.crypto.tls.Client.SslKeyLog = null, /// When this is `true`, the next time this client performs an HTTPS request, /// it will first rescan the system for root certificates. @@ -31,6 +38,13 @@ next_https_rescan_certs: bool = true, /// The pool of connections that can be reused (and currently in use). connection_pool: ConnectionPool = .{}, +/// Each `Connection` allocates this amount for the reader buffer. +/// +/// If the entire HTTP header cannot fit in this amount of bytes, +/// `error.HttpHeadersOversize` will be returned from `Request.wait`. +read_buffer_size: usize = 4096 + if (disable_tls) 0 else std.crypto.tls.Client.min_buffer_len, +/// Each `Connection` allocates this amount for the writer buffer. +write_buffer_size: usize = 1024, /// If populated, all http traffic travels through this third party. /// This field cannot be modified while the client has active connections. @@ -41,7 +55,7 @@ http_proxy: ?*Proxy = null, /// Pointer to externally-owned memory. https_proxy: ?*Proxy = null, -/// A set of linked lists of connections that can be reused. +/// A Least-Recently-Used cache of open connections to be reused. pub const ConnectionPool = struct { mutex: std.Thread.Mutex = .{}, /// Open connections that are currently in use. 
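The `ConnectionPool` hunks below switch the `@fieldParentPtr` call sites to go through `@alignCast`: connections sit on intrusive `std.DoublyLinkedList` nodes, and the generic node type no longer carries the parent's alignment. A reduced sketch of the recover-the-owner pattern:

    const std = @import("std");

    const Connection = struct {
        pool_node: std.DoublyLinkedList.Node,
        port: u16,
    };

    test "recovering the owner of an intrusive list node" {
        var conn: Connection = .{ .pool_node = .{}, .port = 443 };
        const node = &conn.pool_node;
        // Walk back from the embedded node to the enclosing struct.
        const recovered: *Connection = @alignCast(@fieldParentPtr("pool_node", node));
        try std.testing.expectEqual(@as(u16, 443), recovered.port);
    }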
@@ -55,23 +69,25 @@ pub const ConnectionPool = struct { pub const Criteria = struct { host: []const u8, port: u16, - protocol: Connection.Protocol, + protocol: Protocol, }; - /// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe. + /// Finds and acquires a connection from the connection pool matching the criteria. /// If no connection is found, null is returned. + /// + /// Threadsafe. pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection { pool.mutex.lock(); defer pool.mutex.unlock(); var next = pool.free.last; while (next) |node| : (next = node.prev) { - const connection: *Connection = @fieldParentPtr("pool_node", node); + const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node)); if (connection.protocol != criteria.protocol) continue; if (connection.port != criteria.port) continue; // Domain names are case-insensitive (RFC 5890, Section 2.3.2.4) - if (!std.ascii.eqlIgnoreCase(connection.host, criteria.host)) continue; + if (!std.ascii.eqlIgnoreCase(connection.host(), criteria.host)) continue; pool.acquireUnsafe(connection); return connection; @@ -96,28 +112,23 @@ pub const ConnectionPool = struct { return pool.acquireUnsafe(connection); } - /// Tries to release a connection back to the connection pool. This function is threadsafe. + /// Tries to release a connection back to the connection pool. /// If the connection is marked as closing, it will be closed instead. /// - /// The allocator must be the owner of all nodes in this pool. - /// The allocator must be the owner of all resources associated with the connection. - pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void { + /// Threadsafe. + pub fn release(pool: *ConnectionPool, connection: *Connection) void { pool.mutex.lock(); defer pool.mutex.unlock(); pool.used.remove(&connection.pool_node); - if (connection.closing or pool.free_size == 0) { - connection.close(allocator); - return allocator.destroy(connection); - } + if (connection.closing or pool.free_size == 0) return connection.destroy(); if (pool.free_len >= pool.free_size) { - const popped: *Connection = @fieldParentPtr("pool_node", pool.free.popFirst().?); + const popped: *Connection = @alignCast(@fieldParentPtr("pool_node", pool.free.popFirst().?)); pool.free_len -= 1; - popped.close(allocator); - allocator.destroy(popped); + popped.destroy(); } if (connection.proxied) { @@ -138,9 +149,11 @@ pub const ConnectionPool = struct { pool.used.append(&connection.pool_node); } - /// Resizes the connection pool. This function is threadsafe. + /// Resizes the connection pool. /// /// If the new size is smaller than the current size, then idle connections will be closed until the pool is the new size. + /// + /// Threadsafe. pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void { pool.mutex.lock(); defer pool.mutex.unlock(); @@ -158,538 +171,612 @@ pub const ConnectionPool = struct { pool.free_size = new_size; } - /// Frees the connection pool and closes all connections within. This function is threadsafe. + /// Frees the connection pool and closes all connections within. /// /// All future operations on the connection pool will deadlock. - pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void { + /// + /// Threadsafe. 
+ pub fn deinit(pool: *ConnectionPool) void { pool.mutex.lock(); var next = pool.free.first; while (next) |node| { - const connection: *Connection = @fieldParentPtr("pool_node", node); + const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node)); next = node.next; - connection.close(allocator); - allocator.destroy(connection); + connection.destroy(); } next = pool.used.first; while (next) |node| { - const connection: *Connection = @fieldParentPtr("pool_node", node); + const connection: *Connection = @alignCast(@fieldParentPtr("pool_node", node)); next = node.next; - connection.close(allocator); - allocator.destroy(node); + connection.destroy(); } pool.* = undefined; } }; -/// An interface to either a plain or TLS connection. -pub const Connection = struct { - stream: net.Stream, - /// undefined unless protocol is tls. - tls_client: if (!disable_tls) *std.crypto.tls.Client else void, - - /// Entry in `ConnectionPool.used` or `ConnectionPool.free`. - pool_node: std.DoublyLinkedList.Node, - - /// The protocol that this connection is using. - protocol: Protocol, - - /// The host that this connection is connected to. - host: []u8, +pub const Protocol = enum { + plain, + tls, - /// The port that this connection is connected to. - port: u16, - - /// Whether this connection is proxied and is not directly connected. - proxied: bool = false, - - /// Whether this connection is closing when we're done with it. - closing: bool = false, - - read_start: BufferSize = 0, - read_end: BufferSize = 0, - write_end: BufferSize = 0, - read_buf: [buffer_size]u8 = undefined, - write_buf: [buffer_size]u8 = undefined, - - pub const buffer_size = std.crypto.tls.max_ciphertext_record_len; - const BufferSize = std.math.IntFittingRange(0, buffer_size); - - pub const Protocol = enum { plain, tls }; - - pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - return conn.tls_client.readv(conn.stream, buffers) catch |err| { - // https://github.com/ziglang/zig/issues/2473 - if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert; - - switch (err) { - error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure, - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - } + fn port(protocol: Protocol) u16 { + return switch (protocol) { + .plain => 80, + .tls => 443, }; } - pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.readvDirectTls(buffers); - } - - return conn.stream.readv(buffers) catch |err| switch (err) { - error.ConnectionTimedOut => return error.ConnectionTimedOut, - error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer, - else => return error.UnexpectedReadFailure, - }; - } - - /// Refills the read buffer with data from the connection. 
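The new `Protocol` enum above replaces the old `Connection.Protocol` and centralizes default-port selection (its `fromScheme` helper continues just below). A small illustrative test of the implied behavior; this test is not part of the patch and assumes file-local access to `Protocol`:

test "Protocol scheme mapping (illustrative)" {
    try std.testing.expectEqual(.plain, Protocol.fromScheme("http").?);
    try std.testing.expectEqual(.tls, Protocol.fromScheme("wss").?);
    try std.testing.expectEqual(null, Protocol.fromScheme("gopher"));
    // Default ports apply when the URI does not carry an explicit port.
    try std.testing.expectEqual(80, Protocol.plain.port());
    try std.testing.expectEqual(443, Protocol.tls.port());
}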
- pub fn fill(conn: *Connection) ReadError!void { - if (conn.read_end != conn.read_start) return; - - var iovecs = [1]std.posix.iovec{ - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); - if (nread == 0) return error.EndOfStream; - conn.read_start = 0; - conn.read_end = @intCast(nread); + pub fn fromScheme(scheme: []const u8) ?Protocol { + const protocol_map = std.StaticStringMap(Protocol).initComptime(.{ + .{ "http", .plain }, + .{ "ws", .plain }, + .{ "https", .tls }, + .{ "wss", .tls }, + }); + return protocol_map.get(scheme); } - /// Returns the current slice of buffered data. - pub fn peek(conn: *Connection) []const u8 { - return conn.read_buf[conn.read_start..conn.read_end]; + pub fn fromUri(uri: Uri) ?Protocol { + return fromScheme(uri.scheme); } +}; - /// Discards the given number of bytes from the read buffer. - pub fn drop(conn: *Connection, num: BufferSize) void { - conn.read_start += num; - } +pub const Connection = struct { + client: *Client, + stream_writer: net.Stream.Writer, + stream_reader: net.Stream.Reader, + /// Entry in `ConnectionPool.used` or `ConnectionPool.free`. + pool_node: std.DoublyLinkedList.Node, + port: u16, + host_len: u8, + proxied: bool, + closing: bool, + protocol: Protocol, - /// Reads data from the connection into the given buffer. - pub fn read(conn: *Connection, buffer: []u8) ReadError!usize { - const available_read = conn.read_end - conn.read_start; - const available_buffer = buffer.len; + const Plain = struct { + connection: Connection, + + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{OutOfMemory}!*Plain { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Plain), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Plain)..][0..remote_host.len]; + const socket_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.read_buffer_size]; + const socket_write_buffer = socket_read_buffer.ptr[socket_read_buffer.len..][0..client.write_buffer_size]; + assert(base.ptr + alloc_len == socket_write_buffer.ptr + socket_write_buffer.len); + @memcpy(host_buffer, remote_host); + const plain: *Plain = @ptrCast(base); + plain.* = .{ + .connection = .{ + .client = client, + .stream_writer = stream.writer(socket_write_buffer), + .stream_reader = stream.reader(socket_read_buffer), + .pool_node = .{}, + .port = port, + .host_len = @intCast(remote_host.len), + .proxied = false, + .closing = false, + .protocol = .plain, + }, + }; + return plain; + } - if (available_read > available_buffer) { // partially read buffered data - @memcpy(buffer[0..available_buffer], conn.read_buf[conn.read_start..conn.read_end][0..available_buffer]); - conn.read_start += @intCast(available_buffer); + fn destroy(plain: *Plain) void { + const c = &plain.connection; + const gpa = c.client.allocator; + const base: [*]align(@alignOf(Plain)) u8 = @ptrCast(plain); + gpa.free(base[0..allocLen(c.client, c.host_len)]); + } - return available_buffer; - } else if (available_read > 0) { // fully read buffered data - @memcpy(buffer[0..available_read], conn.read_buf[conn.read_start..conn.read_end]); - conn.read_start += available_read; + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Plain) + host_len + client.read_buffer_size + client.write_buffer_size; + } - return available_read; + fn host(plain: *Plain) []u8 { + const base: [*]u8 = @ptrCast(plain); + return 
base[@sizeOf(Plain)..][0..plain.connection.host_len]; } + }; - var iovecs = [2]std.posix.iovec{ - .{ .base = buffer.ptr, .len = buffer.len }, - .{ .base = &conn.read_buf, .len = conn.read_buf.len }, - }; - const nread = try conn.readvDirect(&iovecs); + const Tls = struct { + client: std.crypto.tls.Client, + connection: Connection, + + fn create( + client: *Client, + remote_host: []const u8, + port: u16, + stream: net.Stream, + ) error{ OutOfMemory, TlsInitializationFailed }!*Tls { + const gpa = client.allocator; + const alloc_len = allocLen(client, remote_host.len); + const base = try gpa.alignedAlloc(u8, .of(Tls), alloc_len); + errdefer gpa.free(base); + const host_buffer = base[@sizeOf(Tls)..][0..remote_host.len]; + const tls_read_buffer = host_buffer.ptr[host_buffer.len..][0..client.tls_buffer_size]; + const tls_write_buffer = tls_read_buffer.ptr[tls_read_buffer.len..][0..client.tls_buffer_size]; + const write_buffer = tls_write_buffer.ptr[tls_write_buffer.len..][0..client.write_buffer_size]; + const read_buffer = write_buffer.ptr[write_buffer.len..][0..client.read_buffer_size]; + assert(base.ptr + alloc_len == read_buffer.ptr + read_buffer.len); + @memcpy(host_buffer, remote_host); + const tls: *Tls = @ptrCast(base); + tls.* = .{ + .connection = .{ + .client = client, + .stream_writer = stream.writer(tls_write_buffer), + .stream_reader = stream.reader(tls_read_buffer), + .pool_node = .{}, + .port = port, + .host_len = @intCast(remote_host.len), + .proxied = false, + .closing = false, + .protocol = .tls, + }, + // TODO data race here on ca_bundle if the user sets next_https_rescan_certs to true + .client = std.crypto.tls.Client.init( + tls.connection.stream_reader.interface(), + &tls.connection.stream_writer.interface, + .{ + .host = .{ .explicit = remote_host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log = client.ssl_key_log, + .read_buffer = read_buffer, + .write_buffer = write_buffer, + // This is appropriate for HTTPS because the HTTP headers contain + // the content length which is used to detect truncation attacks. 
+ .allow_truncation_attacks = true, + }, + ) catch return error.TlsInitializationFailed, + }; + return tls; + } - if (nread > buffer.len) { - conn.read_start = 0; - conn.read_end = @intCast(nread - buffer.len); - return buffer.len; + fn destroy(tls: *Tls) void { + const c = &tls.connection; + const gpa = c.client.allocator; + const base: [*]align(@alignOf(Tls)) u8 = @ptrCast(tls); + gpa.free(base[0..allocLen(c.client, c.host_len)]); } - return nread; - } + fn allocLen(client: *Client, host_len: usize) usize { + return @sizeOf(Tls) + host_len + client.tls_buffer_size + client.tls_buffer_size + + client.write_buffer_size + client.read_buffer_size; + } - pub const ReadError = error{ - TlsFailure, - TlsAlert, - ConnectionTimedOut, - ConnectionResetByPeer, - UnexpectedReadFailure, - EndOfStream, + fn host(tls: *Tls) []u8 { + const base: [*]u8 = @ptrCast(tls); + return base[@sizeOf(Tls)..][0..tls.connection.host_len]; + } }; - pub const Reader = std.io.GenericReader(*Connection, ReadError, read); - - pub fn reader(conn: *Connection) Reader { - return Reader{ .context = conn }; - } + pub const ReadError = std.crypto.tls.Client.ReadError || std.net.Stream.ReadError; - pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void { - return conn.tls_client.writeAll(conn.stream, buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, + pub fn getReadError(c: *const Connection) ?ReadError { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *const Tls = @alignCast(@fieldParentPtr("connection", c)); + return tls.client.read_err orelse c.stream_reader.getError(); + }, + .plain => { + return c.stream_reader.getError(); + }, }; } - pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void { - if (conn.protocol == .tls) { - if (disable_tls) unreachable; - - return conn.writeAllDirectTls(buffer); - } + fn getStream(c: *Connection) net.Stream { + return c.stream_reader.getStream(); + } - return conn.stream.writeAll(buffer) catch |err| switch (err) { - error.BrokenPipe, error.ConnectionResetByPeer => return error.ConnectionResetByPeer, - else => return error.UnexpectedWriteFailure, + fn host(c: *Connection) []u8 { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); + return tls.host(); + }, + .plain => { + const plain: *Plain = @alignCast(@fieldParentPtr("connection", c)); + return plain.host(); + }, }; } - /// Writes the given buffer to the connection. - pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize { - if (conn.write_buf.len - conn.write_end < buffer.len) { - try conn.flush(); - - if (buffer.len > conn.write_buf.len) { - try conn.writeAllDirect(buffer); - return buffer.len; - } + /// If this is called without calling `flush` or `end`, data will be + /// dropped unsent. + pub fn destroy(c: *Connection) void { + c.getStream().close(); + switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); + tls.destroy(); + }, + .plain => { + const plain: *Plain = @alignCast(@fieldParentPtr("connection", c)); + plain.destroy(); + }, } - - @memcpy(conn.write_buf[conn.write_end..][0..buffer.len], buffer); - conn.write_end += @intCast(buffer.len); - - return buffer.len; } - /// Returns a buffer to be filled with exactly len bytes to write to the connection. 
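Both `Plain.create` above and `Tls.create` pack everything a connection needs into one allocation: the wrapper object itself, the copied host name, and every reader/writer buffer, so `destroy` frees a single block. A stripped-down sketch of the same pattern, using a hypothetical `Conn` type rather than the patch's `Plain`/`Tls`:

const std = @import("std");

const Conn = struct {
    host_len: u8,

    // One allocation laid out as [Conn][host bytes...], mirroring how
    // Plain/Tls append the host name and I/O buffers after the object.
    fn create(gpa: std.mem.Allocator, remote_host: []const u8) !*Conn {
        const base = try gpa.alignedAlloc(u8, .of(Conn), @sizeOf(Conn) + remote_host.len);
        const conn: *Conn = @ptrCast(base.ptr);
        conn.* = .{ .host_len = @intCast(remote_host.len) };
        @memcpy(base[@sizeOf(Conn)..][0..remote_host.len], remote_host);
        return conn;
    }

    // Recover the trailing host bytes from the object pointer.
    fn host(conn: *Conn) []u8 {
        const base: [*]u8 = @ptrCast(conn);
        return base[@sizeOf(Conn)..][0..conn.host_len];
    }

    // Free the whole block at once, as Plain.destroy/Tls.destroy do.
    fn destroy(conn: *Conn, gpa: std.mem.Allocator) void {
        const base: [*]align(@alignOf(Conn)) u8 = @ptrCast(conn);
        gpa.free(base[0 .. @sizeOf(Conn) + conn.host_len]);
    }
};

This layout is also why the methods below dispatch on `protocol` and downcast with `@alignCast(@fieldParentPtr("connection", c))`: the concrete `Plain`/`Tls` wrapper, not the embedded `Connection`, sits at the start of the allocation.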
- pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 { - if (conn.write_buf.len - conn.write_end < len) try conn.flush(); - defer conn.write_end += len; - return conn.write_buf[conn.write_end..][0..len]; + /// HTTP protocol from client to server. + /// This either goes directly to `stream_writer`, or to a TLS client. + pub fn writer(c: *Connection) *Writer { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); + return &tls.client.writer; + }, + .plain => &c.stream_writer.interface, + }; } - /// Flushes the write buffer to the connection. - pub fn flush(conn: *Connection) WriteError!void { - if (conn.write_end == 0) return; - - try conn.writeAllDirect(conn.write_buf[0..conn.write_end]); - conn.write_end = 0; + /// HTTP protocol from server to client. + /// This either comes directly from `stream_reader`, or from a TLS client. + pub fn reader(c: *Connection) *Reader { + return switch (c.protocol) { + .tls => { + if (disable_tls) unreachable; + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); + return &tls.client.reader; + }, + .plain => c.stream_reader.interface(), + }; } - pub const WriteError = error{ - ConnectionResetByPeer, - UnexpectedWriteFailure, - }; - - pub const Writer = std.io.GenericWriter(*Connection, WriteError, write); - - pub fn writer(conn: *Connection) Writer { - return Writer{ .context = conn }; + pub fn flush(c: *Connection) Writer.Error!void { + if (c.protocol == .tls) { + if (disable_tls) unreachable; + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); + try tls.client.writer.flush(); + } + try c.stream_writer.interface.flush(); } - /// Closes the connection. - pub fn close(conn: *Connection, allocator: Allocator) void { - if (conn.protocol == .tls) { + /// If the connection is a TLS connection, sends the close_notify alert. + /// + /// Flushes all buffers. + pub fn end(c: *Connection) Writer.Error!void { + if (c.protocol == .tls) { if (disable_tls) unreachable; - - // try to cleanly close the TLS connection, for any server that cares. - _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; - if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); - allocator.destroy(conn.tls_client); + const tls: *Tls = @alignCast(@fieldParentPtr("connection", c)); + try tls.client.end(); } - - conn.stream.close(); - allocator.free(conn.host); + try c.stream_writer.interface.flush(); } }; -/// The mode of transport for requests. -pub const RequestTransfer = union(enum) { - content_length: u64, - chunked: void, - none: void, -}; - -/// The decompressor for response messages. -pub const Compression = union(enum) { - //deflate: std.compress.flate.Decompress, - //gzip: std.compress.flate.Decompress, - // https://github.com/ziglang/zig/issues/18937 - //zstd: ZstdDecompressor, - none: void, -}; - -/// A HTTP response originating from a server. pub const Response = struct { - version: http.Version, - status: http.Status, - reason: []const u8, + request: *Request, + /// Pointers in this struct are invalidated when the response body stream + /// is initialized. + head: Head, + + pub const Head = struct { + bytes: []const u8, + version: http.Version, + status: http.Status, + reason: []const u8, + location: ?[]const u8 = null, + content_type: ?[]const u8 = null, + content_disposition: ?[]const u8 = null, + + keep_alive: bool, + + /// If present, the number of bytes in the response body. 
+        content_length: ?u64 = null,
+
+        transfer_encoding: http.TransferEncoding = .none,
+        content_encoding: http.ContentEncoding = .identity,
+
+        pub const ParseError = error{
+            HttpConnectionHeaderUnsupported,
+            HttpContentEncodingUnsupported,
+            HttpHeaderContinuationsUnsupported,
+            HttpHeadersInvalid,
+            HttpTransferEncodingUnsupported,
+            InvalidContentLength,
+        };

-    /// Points into the user-provided `server_header_buffer`.
-    location: ?[]const u8 = null,
-    /// Points into the user-provided `server_header_buffer`.
-    content_type: ?[]const u8 = null,
-    /// Points into the user-provided `server_header_buffer`.
-    content_disposition: ?[]const u8 = null,
+        pub fn parse(bytes: []const u8) ParseError!Head {
+            var res: Head = .{
+                .bytes = bytes,
+                .status = undefined,
+                .reason = undefined,
+                .version = undefined,
+                .keep_alive = false,
+            };
+            var it = mem.splitSequence(u8, bytes, "\r\n");

-    keep_alive: bool,
+            const first_line = it.first();
+            if (first_line.len < 12) return error.HttpHeadersInvalid;

-    /// If present, the number of bytes in the response body.
-    content_length: ?u64 = null,
+            const version: http.Version = switch (int64(first_line[0..8])) {
+                int64("HTTP/1.0") => .@"HTTP/1.0",
+                int64("HTTP/1.1") => .@"HTTP/1.1",
+                else => return error.HttpHeadersInvalid,
+            };
+            if (first_line[8] != ' ') return error.HttpHeadersInvalid;
+            const status: http.Status = @enumFromInt(parseInt3(first_line[9..12]));
+            const reason = mem.trimStart(u8, first_line[12..], " ");
+
+            res.version = version;
+            res.status = status;
+            res.reason = reason;
+            res.keep_alive = switch (version) {
+                .@"HTTP/1.0" => false,
+                .@"HTTP/1.1" => true,
+            };

-    /// If present, the transfer encoding of the response body, otherwise none.
-    transfer_encoding: http.TransferEncoding = .none,
+            while (it.next()) |line| {
+                if (line.len == 0) return res;
+                switch (line[0]) {
+                    ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
+                    else => {},
+                }

-    /// If present, the compression of the response body, otherwise identity (no compression).
-    transfer_compression: http.ContentEncoding = .identity,
+                var line_it = mem.splitScalar(u8, line, ':');
+                const header_name = line_it.next().?;
+                const header_value = mem.trim(u8, line_it.rest(), " \t");
+                if (header_name.len == 0) return error.HttpHeadersInvalid;
+
+                if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
+                    res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
+                } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
+                    res.content_type = header_value;
+                } else if (std.ascii.eqlIgnoreCase(header_name, "location")) {
+                    res.location = header_value;
+                } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) {
+                    res.content_disposition = header_value;
+                } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
+                    // Transfer-Encoding: second, first
+                    // Transfer-Encoding: deflate, chunked
+                    var iter = mem.splitBackwardsScalar(u8, header_value, ',');
+
+                    const first = iter.first();
+                    const trimmed_first = mem.trim(u8, first, " ");
+
+                    var next: ?[]const u8 = first;
+                    if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
+                        if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding
+                        res.transfer_encoding = transfer;
+
+                        next = iter.next();
+                    }

-    parser: proto.HeadersParser,
-    compression: Compression = .none,
+                    if (next) |second| {
+                        const trimmed_second = mem.trim(u8, second, " ");

-    /// Whether the response body should be skipped.
Any data read from the - /// response body will be discarded. - skip: bool = false, + if (http.ContentEncoding.fromString(trimmed_second)) |transfer| { + if (res.content_encoding != .identity) return error.HttpHeadersInvalid; // double compression is not supported + res.content_encoding = transfer; + } else { + return error.HttpTransferEncodingUnsupported; + } + } - pub const ParseError = error{ - HttpHeadersInvalid, - HttpHeaderContinuationsUnsupported, - HttpTransferEncodingUnsupported, - HttpConnectionHeaderUnsupported, - InvalidContentLength, - CompressionUnsupported, - }; + if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { + const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - pub fn parse(res: *Response, bytes: []const u8) ParseError!void { - var it = mem.splitSequence(u8, bytes, "\r\n"); + if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - const first_line = it.next().?; - if (first_line.len < 12) { - return error.HttpHeadersInvalid; - } + res.content_length = content_length; + } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { + if (res.content_encoding != .identity) return error.HttpHeadersInvalid; - const version: http.Version = switch (int64(first_line[0..8])) { - int64("HTTP/1.0") => .@"HTTP/1.0", - int64("HTTP/1.1") => .@"HTTP/1.1", - else => return error.HttpHeadersInvalid, - }; - if (first_line[8] != ' ') return error.HttpHeadersInvalid; - const status: http.Status = @enumFromInt(parseInt3(first_line[9..12])); - const reason = mem.trimStart(u8, first_line[12..], " "); - - res.version = version; - res.status = status; - res.reason = reason; - res.keep_alive = switch (version) { - .@"HTTP/1.0" => false, - .@"HTTP/1.1" => true, - }; + const trimmed = mem.trim(u8, header_value, " "); - while (it.next()) |line| { - if (line.len == 0) return; - switch (line[0]) { - ' ', '\t' => return error.HttpHeaderContinuationsUnsupported, - else => {}, - } - - var line_it = mem.splitScalar(u8, line, ':'); - const header_name = line_it.next().?; - const header_value = mem.trim(u8, line_it.rest(), " \t"); - if (header_name.len == 0) return error.HttpHeadersInvalid; - - if (std.ascii.eqlIgnoreCase(header_name, "connection")) { - res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close"); - } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) { - res.content_type = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "location")) { - res.location = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) { - res.content_disposition = header_value; - } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) { - // Transfer-Encoding: second, first - // Transfer-Encoding: deflate, chunked - var iter = mem.splitBackwardsScalar(u8, header_value, ','); - - const first = iter.first(); - const trimmed_first = mem.trim(u8, first, " "); - - var next: ?[]const u8 = first; - if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| { - if (res.transfer_encoding != .none) return error.HttpHeadersInvalid; // we already have a transfer encoding - res.transfer_encoding = transfer; - - next = iter.next(); - } - - if (next) |second| { - const trimmed_second = mem.trim(u8, second, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { - if (res.transfer_compression != .identity) 
return error.HttpHeadersInvalid; // double compression is not supported - res.transfer_compression = transfer; + if (http.ContentEncoding.fromString(trimmed)) |ce| { + res.content_encoding = ce; } else { - return error.HttpTransferEncodingUnsupported; + return error.HttpContentEncodingUnsupported; } } - - if (iter.next()) |_| return error.HttpTransferEncodingUnsupported; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) { - const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength; - - if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid; - - res.content_length = content_length; - } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) { - if (res.transfer_compression != .identity) return error.HttpHeadersInvalid; - - const trimmed = mem.trim(u8, header_value, " "); - - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { - res.transfer_compression = ce; - } else { - return error.HttpTransferEncodingUnsupported; - } } + return error.HttpHeadersInvalid; // missing empty line } - return error.HttpHeadersInvalid; // missing empty line - } - test parse { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = .init(&header_buffer), - }; + test parse { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const head = try Head.parse(response_bytes); + + try testing.expectEqual(.@"HTTP/1.1", head.version); + try testing.expectEqualStrings("OK", head.reason); + try testing.expectEqual(.ok, head.status); + + try testing.expectEqualStrings("url", head.location.?); + try testing.expectEqualStrings("text/plain", head.content_type.?); + try testing.expectEqualStrings("attachment; filename=example.txt", head.content_disposition.?); + + try testing.expectEqual(true, head.keep_alive); + try testing.expectEqual(10, head.content_length.?); + try testing.expectEqual(.chunked, head.transfer_encoding); + try testing.expectEqual(.deflate, head.content_encoding); + } - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; + pub fn iterateHeaders(h: Head) http.HeaderIterator { + return .init(h.bytes); + } - try res.parse(response_bytes); + test iterateHeaders { + const response_bytes = "HTTP/1.1 200 OK\r\n" ++ + "LOcation:url\r\n" ++ + "content-tYpe: text/plain\r\n" ++ + "content-disposition:attachment; filename=example.txt \r\n" ++ + "content-Length:10\r\n" ++ + "TRansfer-encoding:\tdeflate, chunked \r\n" ++ + "connectioN:\t keep-alive \r\n\r\n"; + + const head = try Head.parse(response_bytes); + var it = head.iterateHeaders(); + { + const header = it.next().?; + try testing.expectEqualStrings("LOcation", header.name); + try testing.expectEqualStrings("url", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try 
testing.expectEqualStrings("content-tYpe", header.name); + try testing.expectEqualStrings("text/plain", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-disposition", header.name); + try testing.expectEqualStrings("attachment; filename=example.txt", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("content-Length", header.name); + try testing.expectEqualStrings("10", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("TRansfer-encoding", header.name); + try testing.expectEqualStrings("deflate, chunked", header.value); + try testing.expect(!it.is_trailer); + } + { + const header = it.next().?; + try testing.expectEqualStrings("connectioN", header.name); + try testing.expectEqualStrings("keep-alive", header.value); + try testing.expect(!it.is_trailer); + } + try testing.expectEqual(null, it.next()); + } - try testing.expectEqual(.@"HTTP/1.1", res.version); - try testing.expectEqualStrings("OK", res.reason); - try testing.expectEqual(.ok, res.status); + inline fn int64(array: *const [8]u8) u64 { + return @bitCast(array.*); + } - try testing.expectEqualStrings("url", res.location.?); - try testing.expectEqualStrings("text/plain", res.content_type.?); - try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?); + fn parseInt3(text: *const [3]u8) u10 { + const nnn: @Vector(3, u8) = text.*; + const zero: @Vector(3, u8) = .{ '0', '0', '0' }; + const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; + return @reduce(.Add, (nnn -% zero) *% mmm); + } - try testing.expectEqual(true, res.keep_alive); - try testing.expectEqual(10, res.content_length.?); - try testing.expectEqual(.chunked, res.transfer_encoding); - try testing.expectEqual(.deflate, res.transfer_compression); - } + test parseInt3 { + const expectEqual = testing.expectEqual; + try expectEqual(@as(u10, 0), parseInt3("000")); + try expectEqual(@as(u10, 418), parseInt3("418")); + try expectEqual(@as(u10, 999), parseInt3("999")); + } - inline fn int64(array: *const [8]u8) u64 { - return @bitCast(array.*); - } + /// Help the programmer avoid bugs by calling this when the string + /// memory of `Head` becomes invalidated. + fn invalidateStrings(h: *Head) void { + h.bytes = undefined; + h.reason = undefined; + if (h.location) |*s| s.* = undefined; + if (h.content_type) |*s| s.* = undefined; + if (h.content_disposition) |*s| s.* = undefined; + } + }; - fn parseInt3(text: *const [3]u8) u10 { - const nnn: @Vector(3, u8) = text.*; - const zero: @Vector(3, u8) = .{ '0', '0', '0' }; - const mmm: @Vector(3, u10) = .{ 100, 10, 1 }; - return @reduce(.Add, (nnn -% zero) *% mmm); + /// If compressed body has been negotiated this will return compressed bytes. + /// + /// If the returned `Reader` returns `error.ReadFailed` the error is + /// available via `bodyErr`. + /// + /// Asserts that this function is only called once. 
+ /// + /// See also: + /// * `readerDecompressing` + pub fn reader(response: *Response, buffer: []u8) *Reader { + response.head.invalidateStrings(); + const req = response.request; + if (!req.method.responseHasBody()) return .ending; + const head = &response.head; + return req.reader.bodyReader(buffer, head.transfer_encoding, head.content_length); } - test parseInt3 { - const expectEqual = testing.expectEqual; - try expectEqual(@as(u10, 0), parseInt3("000")); - try expectEqual(@as(u10, 418), parseInt3("418")); - try expectEqual(@as(u10, 999), parseInt3("999")); + /// If compressed body has been negotiated this will return decompressed bytes. + /// + /// If the returned `Reader` returns `error.ReadFailed` the error is + /// available via `bodyErr`. + /// + /// Asserts that this function is only called once. + /// + /// See also: + /// * `reader` + pub fn readerDecompressing( + response: *Response, + decompressor: *http.Decompressor, + decompression_buffer: []u8, + ) *Reader { + response.head.invalidateStrings(); + const head = &response.head; + return response.request.reader.bodyReaderDecompressing( + head.transfer_encoding, + head.content_length, + head.content_encoding, + decompressor, + decompression_buffer, + ); } - pub fn iterateHeaders(r: Response) http.HeaderIterator { - return .init(r.parser.get()); + /// After receiving `error.ReadFailed` from the `Reader` returned by + /// `reader` or `readerDecompressing`, this function accesses the + /// more specific error code. + pub fn bodyErr(response: *const Response) ?http.Reader.BodyError { + return response.request.reader.body_err; } - test iterateHeaders { - const response_bytes = "HTTP/1.1 200 OK\r\n" ++ - "LOcation:url\r\n" ++ - "content-tYpe: text/plain\r\n" ++ - "content-disposition:attachment; filename=example.txt \r\n" ++ - "content-Length:10\r\n" ++ - "TRansfer-encoding:\tdeflate, chunked \r\n" ++ - "connectioN:\t keep-alive \r\n\r\n"; - - var header_buffer: [1024]u8 = undefined; - var res = Response{ - .status = undefined, - .reason = undefined, - .version = undefined, - .keep_alive = false, - .parser = .init(&header_buffer), + pub fn iterateTrailers(response: *const Response) http.HeaderIterator { + const r = &response.request.reader; + assert(r.state == .ready); + return .{ + .bytes = r.trailers, + .index = 0, + .is_trailer = true, }; - - @memcpy(header_buffer[0..response_bytes.len], response_bytes); - res.parser.header_bytes_len = response_bytes.len; - - var it = res.iterateHeaders(); - { - const header = it.next().?; - try testing.expectEqualStrings("LOcation", header.name); - try testing.expectEqualStrings("url", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-tYpe", header.name); - try testing.expectEqualStrings("text/plain", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-disposition", header.name); - try testing.expectEqualStrings("attachment; filename=example.txt", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("content-Length", header.name); - try testing.expectEqualStrings("10", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = it.next().?; - try testing.expectEqualStrings("TRansfer-encoding", header.name); - try testing.expectEqualStrings("deflate, chunked", header.value); - try testing.expect(!it.is_trailer); - } - { - const header = 
it.next().?; - try testing.expectEqualStrings("connectioN", header.name); - try testing.expectEqualStrings("keep-alive", header.value); - try testing.expect(!it.is_trailer); - } - try testing.expectEqual(null, it.next()); } }; -/// A HTTP request that has been sent. -/// -/// Order of operations: open -> send[ -> write -> finish] -> wait -> read pub const Request = struct { + /// This field is provided so that clients can observe redirected URIs. + /// + /// Its backing memory is externally provided by API users when creating a + /// request, and then again provided externally via `redirect_buffer` to + /// `receiveHead`. uri: Uri, client: *Client, /// This is null when the connection is released. connection: ?*Connection, + reader: http.Reader, keep_alive: bool, method: http.Method, version: http.Version = .@"HTTP/1.1", - transfer_encoding: RequestTransfer, + transfer_encoding: TransferEncoding, redirect_behavior: RedirectBehavior, + accept_encoding: @TypeOf(default_accept_encoding) = default_accept_encoding, /// Whether the request should handle a 100-continue response before sending the request body. handle_continue: bool, - /// The response associated with this request. - /// - /// This field is undefined until `wait` is called. - response: Response, - /// Standard headers that have default, but overridable, behavior. headers: Headers, @@ -703,6 +790,20 @@ pub const Request = struct { /// Externally-owned; must outlive the Request. privileged_headers: []const http.Header, + pub const default_accept_encoding: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = b: { + var result: [@typeInfo(http.ContentEncoding).@"enum".fields.len]bool = @splat(false); + result[@intFromEnum(http.ContentEncoding.gzip)] = true; + result[@intFromEnum(http.ContentEncoding.deflate)] = true; + result[@intFromEnum(http.ContentEncoding.identity)] = true; + break :b result; + }; + + pub const TransferEncoding = union(enum) { + content_length: u64, + chunked: void, + none: void, + }; + pub const Headers = struct { host: Value = .default, authorization: Value = .default, @@ -728,6 +829,11 @@ pub const Request = struct { unhandled = std.math.maxInt(u16), _, + pub fn init(n: u16) RedirectBehavior { + assert(n != std.math.maxInt(u16)); + return @enumFromInt(n); + } + pub fn subtractOne(rb: *RedirectBehavior) void { switch (rb.*) { .not_allowed => unreachable, @@ -742,98 +848,110 @@ pub const Request = struct { } }; - /// Frees all resources associated with the request. - pub fn deinit(req: *Request) void { - if (req.connection) |connection| { - if (!req.response.parser.done) { - // If the response wasn't fully read, then we need to close the connection. - connection.closing = true; - } - req.client.connection_pool.release(req.client.allocator, connection); + /// Returns the request's `Connection` back to the pool of the `Client`. + pub fn deinit(r: *Request) void { + if (r.connection) |connection| { + connection.closing = connection.closing or switch (r.reader.state) { + .ready => false, + .received_head => r.method.requestHasBody(), + else => true, + }; + r.client.connection_pool.release(connection); } - req.* = undefined; + r.* = undefined; } - // This function must deallocate all resources associated with the request, - // or keep those which will be used. - // This needs to be kept in sync with deinit and request. 
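With `deinit` now reduced to returning the connection to the pool, a full round trip under the rewritten API looks roughly like the sketch below. The `client.request(.GET, uri, .{})` constructor and its options are assumed from the surrounding rewrite and do not appear in this hunk; `sendBodiless`, `receiveHead`, `Response.reader`, and `bodyErr` are the functions this patch introduces:

fn fetchAndDiscard(gpa: std.mem.Allocator, url: []const u8) !void {
    var client: std.http.Client = .{ .allocator = gpa };
    defer client.deinit();

    // Assumed constructor, analogous to the old open().
    var req = try client.request(.GET, try std.Uri.parse(url), .{});
    defer req.deinit();

    try req.sendBodiless(); // writes and flushes the request head

    var redirect_buffer: [1024]u8 = undefined;
    var response = try req.receiveHead(&redirect_buffer);

    // Streaming body access; this invalidates pointers in response.head.
    var transfer_buffer: [4096]u8 = undefined;
    const body_reader = response.reader(&transfer_buffer);
    _ = body_reader.discardRemaining() catch |err| switch (err) {
        error.ReadFailed => return response.bodyErr().?,
    };
}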
- fn redirect(req: *Request, uri: Uri) !void { - assert(req.response.parser.done); - - req.client.connection_pool.release(req.client.allocator, req.connection.?); - req.connection = null; - - var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer); - defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); - - const new_host = valid_uri.host.?.raw; - const prev_host = req.uri.host.?.raw; - const keep_privileged_headers = - std.ascii.eqlIgnoreCase(valid_uri.scheme, req.uri.scheme) and - std.ascii.endsWithIgnoreCase(new_host, prev_host) and - (new_host.len == prev_host.len or new_host[new_host.len - prev_host.len - 1] == '.'); - if (!keep_privileged_headers) { - // When redirecting to a different domain, strip privileged headers. - req.privileged_headers = &.{}; - } - - if (switch (req.response.status) { - .see_other => true, - .moved_permanently, .found => req.method == .POST, - else => false, - }) { - // A redirect to a GET must change the method and remove the body. - req.method = .GET; - req.transfer_encoding = .none; - req.headers.content_type = .omit; - } - - if (req.transfer_encoding != .none) { - // The request body has already been sent. The request is - // still in a valid state, but the redirect must be handled - // manually. - return error.RedirectRequiresResend; - } + /// Sends and flushes a complete request as only HTTP head, no body. + pub fn sendBodiless(r: *Request) Writer.Error!void { + try sendBodilessUnflushed(r); + try r.connection.?.flush(); + } - req.uri = valid_uri; - req.connection = try req.client.connect(new_host, uriPort(valid_uri, protocol), protocol); - req.redirect_behavior.subtractOne(); - req.response.parser.reset(); - - req.response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = req.response.parser, - }; + /// Sends but does not flush a complete request as only HTTP head, no body. + pub fn sendBodilessUnflushed(r: *Request) Writer.Error!void { + assert(r.transfer_encoding == .none); + assert(!r.method.requestHasBody()); + try sendHead(r); } - pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding }; + /// Transfers the HTTP head over the connection and flushes. + /// + /// See also: + /// * `sendBodyUnflushed` + pub fn sendBody(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + const result = try sendBodyUnflushed(r, buffer); + try r.connection.?.flush(); + return result; + } - /// Send the HTTP request headers to the server. - pub fn send(req: *Request) SendError!void { - if (!req.method.requestHasBody() and req.transfer_encoding != .none) - return error.UnsupportedTransferEncoding; + /// Transfers the HTTP head and body over the connection and flushes. 
+ pub fn sendBodyComplete(r: *Request, body: []u8) Writer.Error!void { + r.transfer_encoding = .{ .content_length = body.len }; + var bw = try sendBodyUnflushed(r, body); + bw.writer.end = body.len; + try bw.end(); + try r.connection.?.flush(); + } - const connection = req.connection.?; - var connection_writer_adapter = connection.writer().adaptToNewApi(); - const w = &connection_writer_adapter.new_interface; - sendAdapted(req, connection, w) catch |err| switch (err) { - error.WriteFailed => return connection_writer_adapter.err.?, - else => |e| return e, + /// Transfers the HTTP head over the connection, which is not flushed until + /// `BodyWriter.flush` or `BodyWriter.end` is called. + /// + /// See also: + /// * `sendBody` + pub fn sendBodyUnflushed(r: *Request, buffer: []u8) Writer.Error!http.BodyWriter { + assert(r.method.requestHasBody()); + try sendHead(r); + const http_protocol_output = r.connection.?.writer(); + return switch (r.transfer_encoding) { + .chunked => .{ + .http_protocol_output = http_protocol_output, + .state = .{ .chunked = .init }, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + }, + }, + .content_length => |len| .{ + .http_protocol_output = http_protocol_output, + .state = .{ .content_length = len }, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + }, + }, + .none => .{ + .http_protocol_output = http_protocol_output, + .state = .none, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + }, + }, }; } - fn sendAdapted(req: *Request, connection: *Connection, w: *std.io.Writer) !void { - try req.method.format(w); + /// Sends HTTP headers without flushing. 
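The `BodyWriter` wiring above selects a drain strategy per transfer encoding. A hedged usage sketch for the common content-length case; the helper name is hypothetical, while `sendBody`, `BodyWriter.end`, and the final flush follow the pattern this patch itself uses:

fn postFixedLength(r: *Request, payload: []const u8) !void {
    r.transfer_encoding = .{ .content_length = payload.len };

    var buffer: [256]u8 = undefined;
    var body = try r.sendBody(&buffer); // sends the head, returns a BodyWriter

    try body.writer.writeAll(payload);
    try body.end(); // finishes the message body
    try r.connection.?.flush();
}

For the one-shot case the patch also provides `sendBodyComplete`, which takes the payload as `[]u8` precisely so it can reuse that slice as the writer buffer.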
+ fn sendHead(r: *Request) Writer.Error!void { + const uri = r.uri; + const connection = r.connection.?; + const w = connection.writer(); + + try w.writeAll(@tagName(r.method)); try w.writeByte(' '); - if (req.method == .CONNECT) { - try req.uri.writeToStream(w, .{ .authority = true }); + if (r.method == .CONNECT) { + try uri.writeToStream(w, .{ .authority = true }); } else { - try req.uri.writeToStream(w, .{ + try uri.writeToStream(w, .{ .scheme = connection.proxied, .authentication = connection.proxied, .authority = connection.proxied, @@ -842,58 +960,64 @@ pub const Request = struct { }); } try w.writeByte(' '); - try w.writeAll(@tagName(req.version)); + try w.writeAll(@tagName(r.version)); try w.writeAll("\r\n"); - if (try emitOverridableHeader("host: ", req.headers.host, w)) { + if (try emitOverridableHeader("host: ", r.headers.host, w)) { try w.writeAll("host: "); - try req.uri.writeToStream(w, .{ .authority = true }); + try uri.writeToStream(w, .{ .authority = true }); try w.writeAll("\r\n"); } - if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) { - if (req.uri.user != null or req.uri.password != null) { + if (try emitOverridableHeader("authorization: ", r.headers.authorization, w)) { + if (uri.user != null or uri.password != null) { try w.writeAll("authorization: "); - const authorization = try connection.allocWriteBuffer( - @intCast(basic_authorization.valueLengthFromUri(req.uri)), - ); - assert(basic_authorization.value(req.uri, authorization).len == authorization.len); + try basic_authorization.write(uri, w); try w.writeAll("\r\n"); } } - if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) { + if (try emitOverridableHeader("user-agent: ", r.headers.user_agent, w)) { try w.writeAll("user-agent: zig/"); try w.writeAll(builtin.zig_version_string); try w.writeAll(" (std.http)\r\n"); } - if (try emitOverridableHeader("connection: ", req.headers.connection, w)) { - if (req.keep_alive) { + if (try emitOverridableHeader("connection: ", r.headers.connection, w)) { + if (r.keep_alive) { try w.writeAll("connection: keep-alive\r\n"); } else { try w.writeAll("connection: close\r\n"); } } - if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) { - // https://github.com/ziglang/zig/issues/18937 - //try w.writeAll("accept-encoding: gzip, deflate, zstd\r\n"); - try w.writeAll("accept-encoding: gzip, deflate\r\n"); + if (try emitOverridableHeader("accept-encoding: ", r.headers.accept_encoding, w)) { + try w.writeAll("accept-encoding: "); + for (r.accept_encoding, 0..) |enabled, i| { + if (!enabled) continue; + const tag: http.ContentEncoding = @enumFromInt(i); + if (tag == .identity) continue; + const tag_name = @tagName(tag); + try w.ensureUnusedCapacity(tag_name.len + 2); + try w.writeAll(tag_name); + try w.writeAll(", "); + } + w.undo(2); + try w.writeAll("\r\n"); } - switch (req.transfer_encoding) { + switch (r.transfer_encoding) { .chunked => try w.writeAll("transfer-encoding: chunked\r\n"), .content_length => |len| try w.print("content-length: {d}\r\n", .{len}), .none => {}, } - if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) { + if (try emitOverridableHeader("content-type: ", r.headers.content_type, w)) { // The default is to omit content-type if not provided because // "application/octet-stream" is redundant. 
        }

-        for (req.extra_headers) |header| {
+        for (r.extra_headers) |header| {
             assert(header.name.len != 0);

             try w.writeAll(header.name);
@@ -904,8 +1028,8 @@ pub const Request = struct {

         if (connection.proxied) proxy: {
             const proxy = switch (connection.protocol) {
-                .plain => req.client.http_proxy,
-                .tls => req.client.https_proxy,
+                .plain => r.client.http_proxy,
+                .tls => r.client.https_proxy,
             } orelse break :proxy;

             const authorization = proxy.authorization orelse break :proxy;
@@ -915,282 +1039,200 @@ pub const Request = struct {
         }

         try w.writeAll("\r\n");
-
-        try connection.flush();
-    }
-
-    /// Returns true if the default behavior is required, otherwise handles
-    /// writing (or not writing) the header.
-    fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool {
-        switch (v) {
-            .default => return true,
-            .omit => return false,
-            .override => |x| {
-                try w.writeAll(prefix);
-                try w.writeAll(x);
-                try w.writeAll("\r\n");
-                return false;
-            },
-        }
     }

-    const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;
-
-    const TransferReader = std.io.GenericReader(*Request, TransferReadError, transferRead);
-
-    fn transferReader(req: *Request) TransferReader {
-        return .{ .context = req };
-    }
-
-    fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
-        if (req.response.parser.done) return 0;
-
-        var index: usize = 0;
-        while (index == 0) {
-            const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip);
-            if (amt == 0 and req.response.parser.done) break;
-            index += amt;
-        }
-
-        return index;
-    }
+    pub const ReceiveHeadError = http.Reader.HeadError || ConnectError || error{
+        /// Server sent headers that did not conform to the HTTP protocol.
+        ///
+        /// For more detailed diagnostics, `http.Reader.head_buffer` can be
+        /// passed directly to `Response.Head.parse`.
+        HttpHeadersInvalid,
+        TooManyHttpRedirects,
+        /// This can be avoided by calling `receiveHead` before sending the
+        /// request body.
+        RedirectRequiresResend,
+        HttpRedirectLocationMissing,
+        HttpRedirectLocationOversize,
+        HttpRedirectLocationInvalid,
+        HttpContentEncodingUnsupported,
+        HttpChunkInvalid,
+        HttpChunkTruncated,
+        HttpHeadersOversize,
+        UnsupportedUriScheme,

-    pub const WaitError = RequestError || SendError || TransferReadError ||
-        proto.HeadersParser.CheckCompleteHeadError || Response.ParseError ||
-        error{
-        TooManyHttpRedirects,
-        RedirectRequiresResend,
-        HttpRedirectLocationMissing,
-        HttpRedirectLocationInvalid,
-        CompressionInitializationFailed,
-        CompressionUnsupported,
-    };
+        /// Sending the request failed. Error code can be found on the
+        /// `Connection` object.
+        WriteFailed,
+    };

-    /// Waits for a response from the server and parses any headers that are sent.
-    /// This function will block until the final response is received.
-    ///
     /// If handling redirects and the request has no payload, then this
-    /// function will automatically follow redirects. If a request payload is
-    /// present, then this function will error with
-    /// error.RedirectRequiresResend.
+    /// function will automatically follow redirects.
+    ///
+    /// If a request payload is present, then this function will error with
+    /// `error.RedirectRequiresResend`.
+    ///
+    /// This function takes an auxiliary buffer to store the arbitrarily large
+    /// URI, which may need to be merged with the previous URI, and that data
+    /// needs to survive across different connections, which is where the input
+    /// buffer lives.
/// - /// Must be called after `send` and, if any data was written to the request - /// body, then also after `finish`. - pub fn wait(req: *Request) WaitError!void { + /// `redirect_buffer` must outlive accesses to `Request.uri`. If this + /// buffer capacity would be exceeded, `error.HttpRedirectLocationOversize` + /// is returned instead. This buffer may be empty if no redirects are to be + /// handled. + /// + /// If this fails with `error.ReadFailed` then the `Connection.getReadError` + /// method of `r.connection` can be used to get more detailed information. + pub fn receiveHead(r: *Request, redirect_buffer: []u8) ReceiveHeadError!Response { + var aux_buf = redirect_buffer; while (true) { - // This while loop is for handling redirects, which means the request's - // connection may be different than the previous iteration. However, it - // is still guaranteed to be non-null with each iteration of this loop. - const connection = req.connection.?; - - while (true) { // read headers - try connection.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(connection.peek()); - connection.drop(@intCast(nchecked)); + const head_buffer = try r.reader.receiveHead(); + const response: Response = .{ + .request = r, + .head = Response.Head.parse(head_buffer) catch return error.HttpHeadersInvalid, + }; + const head = &response.head; - if (req.response.parser.state.isContent()) break; + if (head.status == .@"continue") { + if (r.handle_continue) continue; + return response; // we're not handling the 100-continue } - try req.response.parse(req.response.parser.get()); - - if (req.response.status == .@"continue") { - // We're done parsing the continue response; reset to prepare - // for the real response. - req.response.parser.done = true; - req.response.parser.reset(); - - if (req.handle_continue) - continue; - - return; // we're not handling the 100-continue - } + // This while loop is for handling redirects, which means the request's + // connection may be different than the previous iteration. However, it + // is still guaranteed to be non-null with each iteration of this loop. + const connection = r.connection.?; - // we're switching protocols, so this connection is no longer doing http - if (req.method == .CONNECT and req.response.status.class() == .success) { + if (r.method == .CONNECT and head.status.class() == .success) { + // This connection is no longer doing HTTP. connection.closing = false; - req.response.parser.done = true; - return; // the connection is not HTTP past this point + return response; } - connection.closing = !req.response.keep_alive or !req.keep_alive; + connection.closing = !head.keep_alive or !r.keep_alive; // Any response to a HEAD request and any response with a 1xx // (Informational), 204 (No Content), or 304 (Not Modified) status // code is always terminated by the first empty line after the // header fields, regardless of the header fields present in the // message. - if (req.method == .HEAD or req.response.status.class() == .informational or - req.response.status == .no_content or req.response.status == .not_modified) + if (r.method == .HEAD or head.status.class() == .informational or + head.status == .no_content or head.status == .not_modified) { - req.response.parser.done = true; - return; // The response is empty; no further setup or redirection is necessary. 
- } - - switch (req.response.transfer_encoding) { - .none => { - if (req.response.content_length) |cl| { - req.response.parser.next_chunk_length = cl; - - if (cl == 0) req.response.parser.done = true; - } else { - // read until the connection is closed - req.response.parser.next_chunk_length = std.math.maxInt(u64); - } - }, - .chunked => { - req.response.parser.next_chunk_length = 0; - req.response.parser.state = .chunk_head_size; - }, + return response; } - if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) { - // skip the body of the redirect response, this will at least - // leave the connection in a known good state. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // we're skipping, no buffer is necessary - - if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects; - - const location = req.response.location orelse - return error.HttpRedirectLocationMissing; - - // This mutates the beginning of header_bytes_buffer and uses that - // for the backing memory of the returned Uri. - try req.redirect(req.uri.resolve_inplace( - location, - &req.response.parser.header_bytes_buffer, - ) catch |err| switch (err) { - error.UnexpectedCharacter, - error.InvalidFormat, - error.InvalidPort, - => return error.HttpRedirectLocationInvalid, - error.NoSpaceLeft => return error.HttpHeadersOversize, - }); - try req.send(); - } else { - req.response.skip = false; - if (!req.response.parser.done) { - switch (req.response.transfer_compression) { - .identity => req.response.compression = .none, - .compress, .@"x-compress" => return error.CompressionUnsupported, - // I'm about to upstream my http.Client rewrite - .deflate => return error.CompressionUnsupported, - // I'm about to upstream my http.Client rewrite - .gzip, .@"x-gzip" => return error.CompressionUnsupported, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => req.response.compression = .{ - // .zstd = std.compress.zstd.decompressStream(req.client.allocator, req.transferReader()), - //}, - .zstd => return error.CompressionUnsupported, - } + if (head.status.class() == .redirect and r.redirect_behavior != .unhandled) { + if (r.redirect_behavior == .not_allowed) { + // Connection can still be reused by skipping the body. + const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => connection.closing = true, + }; + return error.TooManyHttpRedirects; } - - break; + try r.redirect(head, &aux_buf); + try r.sendBodiless(); + continue; } - } - } - - pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError || - error{ DecompressionFailure, InvalidTrailers }; - pub const Reader = std.io.GenericReader(*Request, ReadError, read); + if (!r.accept_encoding[@intFromEnum(head.content_encoding)]) + return error.HttpContentEncodingUnsupported; - pub fn reader(req: *Request) Reader { - return .{ .context = req }; - } - - /// Reads data from the response body. Must be called after `wait`. 
- pub fn read(req: *Request, buffer: []u8) ReadError!usize { - const out_index = switch (req.response.compression) { - // I'm about to upstream my http client rewrite - //.deflate => |*deflate| deflate.readSlice(buffer) catch return error.DecompressionFailure, - //.gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure, - // https://github.com/ziglang/zig/issues/18937 - //.zstd => |*zstd| zstd.read(buffer) catch return error.DecompressionFailure, - else => try req.transferRead(buffer), - }; - if (out_index > 0) return out_index; - - while (!req.response.parser.state.isContent()) { // read trailing headers - try req.connection.?.fill(); - - const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek()); - req.connection.?.drop(@intCast(nchecked)); + return response; } - - return 0; } - /// Reads data from the response body. Must be called after `wait`. - pub fn readAll(req: *Request, buffer: []u8) !usize { - var index: usize = 0; - while (index < buffer.len) { - const amt = try read(req, buffer[index..]); - if (amt == 0) break; - index += amt; + /// This function takes an auxiliary buffer to store the arbitrarily large + /// URI which may need to be merged with the previous URI, and that data + /// needs to survive across different connections, which is where the input + /// buffer lives. + /// + /// `aux_buf` must outlive accesses to `Request.uri`. + fn redirect(r: *Request, head: *const Response.Head, aux_buf: *[]u8) !void { + const new_location = head.location orelse return error.HttpRedirectLocationMissing; + if (new_location.len > aux_buf.*.len) return error.HttpRedirectLocationOversize; + const location = aux_buf.*[0..new_location.len]; + @memcpy(location, new_location); + { + // Skip the body of the redirect response to leave the connection in + // the correct state. This causes `new_location` to be invalidated. + const reader = r.reader.bodyReader(&.{}, head.transfer_encoding, head.content_length); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => return r.reader.body_err.?, + }; } - return index; - } - - pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong }; - - pub const Writer = std.io.GenericWriter(*Request, WriteError, write); + const new_uri = r.uri.resolveInPlace(location.len, aux_buf) catch |err| switch (err) { + error.UnexpectedCharacter => return error.HttpRedirectLocationInvalid, + error.InvalidFormat => return error.HttpRedirectLocationInvalid, + error.InvalidPort => return error.HttpRedirectLocationInvalid, + error.NoSpaceLeft => return error.HttpRedirectLocationOversize, + }; - pub fn writer(req: *Request) Writer { - return .{ .context = req }; - } + const protocol = Protocol.fromUri(new_uri) orelse return error.UnsupportedUriScheme; + const old_connection = r.connection.?; + const old_host = old_connection.host(); + var new_host_name_buffer: [Uri.host_name_max]u8 = undefined; + const new_host = try new_uri.getHost(&new_host_name_buffer); + const keep_privileged_headers = + std.ascii.eqlIgnoreCase(r.uri.scheme, new_uri.scheme) and + sameParentDomain(old_host, new_host); - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. 
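The removed chunked `write` path below, and `http.BodyWriter.chunkedDrain` which replaces it, both frame each chunk as a hex length line, the payload, and a trailing CRLF, with a zero-length chunk terminating the body. A small illustrative check of that framing:

test "chunk framing (illustrative)" {
    var buf: [32]u8 = undefined;
    const chunk = try std.fmt.bufPrint(&buf, "{x}\r\n{s}\r\n", .{ @as(usize, 5), "hello" });
    try std.testing.expectEqualStrings("5\r\nhello\r\n", chunk);
    // A chunked body terminates with the zero-length chunk: "0\r\n\r\n".
}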
- pub fn write(req: *Request, bytes: []const u8) WriteError!usize { - switch (req.transfer_encoding) { - .chunked => { - if (bytes.len > 0) { - try req.connection.?.writer().print("{x}\r\n", .{bytes.len}); - try req.connection.?.writer().writeAll(bytes); - try req.connection.?.writer().writeAll("\r\n"); - } + r.client.connection_pool.release(old_connection); + r.connection = null; - return bytes.len; - }, - .content_length => |*len| { - if (len.* < bytes.len) return error.MessageTooLong; + if (!keep_privileged_headers) { + // When redirecting to a different domain, strip privileged headers. + r.privileged_headers = &.{}; + } - const amt = try req.connection.?.write(bytes); - len.* -= amt; - return amt; - }, - .none => return error.NotWriteable, + if (switch (head.status) { + .see_other => true, + .moved_permanently, .found => r.method == .POST, + else => false, + }) { + // A redirect to a GET must change the method and remove the body. + r.method = .GET; + r.transfer_encoding = .none; + r.headers.content_type = .omit; } - } - /// Write `bytes` to the server. The `transfer_encoding` field determines how data will be sent. - /// Must be called after `send` and before `finish`. - pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(req, bytes[index..]); + if (r.transfer_encoding != .none) { + // The request body has already been sent. The request is + // still in a valid state, but the redirect must be handled + // manually. + return error.RedirectRequiresResend; } - } - pub const FinishError = WriteError || error{MessageNotCompleted}; + const new_connection = try r.client.connect(new_host, uriPort(new_uri, protocol), protocol); + r.uri = new_uri; + r.connection = new_connection; + r.reader = .{ + .in = new_connection.reader(), + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }; + r.redirect_behavior.subtractOne(); + } - /// Finish the body of a request. This notifies the server that you have no more data to send. - /// Must be called after `send`. - pub fn finish(req: *Request) FinishError!void { - switch (req.transfer_encoding) { - .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"), - .content_length => |len| if (len != 0) return error.MessageNotCompleted, - .none => {}, + /// Returns true if the default behavior is required, otherwise handles + /// writing (or not writing) the header. + fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, bw: *Writer) Writer.Error!bool { + switch (v) { + .default => return true, + .omit => return false, + .override => |x| { + var vecs: [3][]const u8 = .{ prefix, x, "\r\n" }; + try bw.writeVecAll(&vecs); + return false; + }, } - - try req.connection.?.flush(); } }; pub const Proxy = struct { - protocol: Connection.Protocol, + protocol: Protocol, host: []const u8, authorization: ?[]const u8, port: u16, @@ -1204,10 +1246,8 @@ pub const Proxy = struct { pub fn deinit(client: *Client) void { assert(client.connection_pool.used.first == null); // There are still active requests. - client.connection_pool.deinit(client.allocator); - - if (!disable_tls) - client.ca_bundle.deinit(client.allocator); + client.connection_pool.deinit(); + if (!disable_tls) client.ca_bundle.deinit(client.allocator); client.* = undefined; } @@ -1249,24 +1289,21 @@ fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !? 
} else return null; const uri = Uri.parse(content) catch try Uri.parseAfterScheme("http", content); - const protocol, const valid_uri = validateUri(uri, arena) catch |err| switch (err) { - error.UnsupportedUriScheme => return null, - error.UriMissingHost => return error.HttpProxyMissingHost, - error.OutOfMemory => |e| return e, - }; + const protocol = Protocol.fromUri(uri) orelse return null; + const raw_host = try uri.getHostAlloc(arena); - const authorization: ?[]const u8 = if (valid_uri.user != null or valid_uri.password != null) a: { - const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(valid_uri)); - assert(basic_authorization.value(valid_uri, authorization).len == authorization.len); + const authorization: ?[]const u8 = if (uri.user != null or uri.password != null) a: { + const authorization = try arena.alloc(u8, basic_authorization.valueLengthFromUri(uri)); + assert(basic_authorization.value(uri, authorization).len == authorization.len); break :a authorization; } else null; const proxy = try arena.create(Proxy); proxy.* = .{ .protocol = protocol, - .host = valid_uri.host.?.raw, + .host = raw_host, .authorization = authorization, - .port = uriPort(valid_uri, protocol), + .port = uriPort(uri, protocol), .supports_connect = true, }; return proxy; @@ -1277,10 +1314,8 @@ pub const basic_authorization = struct { pub const max_password_len = 255; pub const max_value_len = valueLength(max_user_len, max_password_len); - const prefix = "Basic "; - pub fn valueLength(user_len: usize, password_len: usize) usize { - return prefix.len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); + return "Basic ".len + std.base64.standard.Encoder.calcSize(user_len + 1 + password_len); } pub fn valueLengthFromUri(uri: Uri) usize { @@ -1300,37 +1335,70 @@ pub const basic_authorization = struct { } pub fn value(uri: Uri, out: []u8) []u8 { - const user: Uri.Component = uri.user orelse .empty; - const password: Uri.Component = uri.password orelse .empty; - - var buf: [max_user_len + ":".len + max_password_len]u8 = undefined; - var w: std.io.Writer = .fixed(&buf); - user.formatUser(&w) catch unreachable; // fixed - password.formatPassword(&w) catch unreachable; // fixed + var bw: Writer = .fixed(out); + write(uri, &bw) catch unreachable; + return bw.buffered(); + } - @memcpy(out[0..prefix.len], prefix); - const base64 = std.base64.standard.Encoder.encode(out[prefix.len..], w.buffered()); - return out[0 .. 
prefix.len + base64.len];
+    pub fn write(uri: Uri, out: *Writer) Writer.Error!void {
+        var buf: [max_user_len + 1 + max_password_len]u8 = undefined;
+        var w: Writer = .fixed(&buf);
+        const user: Uri.Component = uri.user orelse .empty;
+        const password: Uri.Component = uri.password orelse .empty;
+        user.formatUser(&w) catch unreachable;
+        w.writeByte(':') catch unreachable;
+        password.formatPassword(&w) catch unreachable;
+        try out.print("Basic {b64}", .{w.buffered()});
+    }
 };
 
-pub const ConnectTcpError = Allocator.Error || error{ ConnectionRefused, NetworkUnreachable, ConnectionTimedOut, ConnectionResetByPeer, TemporaryNameServerFailure, NameServerFailure, UnknownHostName, HostLacksNetworkAddresses, UnexpectedConnectFailure, TlsInitializationFailed };
+pub const ConnectTcpError = Allocator.Error || error{
+    ConnectionRefused,
+    NetworkUnreachable,
+    ConnectionTimedOut,
+    ConnectionResetByPeer,
+    TemporaryNameServerFailure,
+    NameServerFailure,
+    UnknownHostName,
+    HostLacksNetworkAddresses,
+    UnexpectedConnectFailure,
+    TlsInitializationFailed,
+};
 
-/// Connect to `host:port` using the specified protocol. This will reuse a connection if one is already open.
+/// Reuses a `Connection` if one matching `host` and `port` is already open.
 ///
-/// This function is threadsafe.
-pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
-    if (client.connection_pool.findConnection(.{
-        .host = host,
-        .port = port,
-        .protocol = protocol,
-    })) |node| return node;
+/// Threadsafe.
+pub fn connectTcp(
+    client: *Client,
+    host: []const u8,
+    port: u16,
+    protocol: Protocol,
+) ConnectTcpError!*Connection {
+    return connectTcpOptions(client, .{ .host = host, .port = port, .protocol = protocol });
+}
+
+pub const ConnectTcpOptions = struct {
+    host: []const u8,
+    port: u16,
+    protocol: Protocol,
 
-    if (disable_tls and protocol == .tls)
-        return error.TlsInitializationFailed;
+    proxied_host: ?[]const u8 = null,
+    proxied_port: ?u16 = null,
+};
 
-    const conn = try client.allocator.create(Connection);
-    errdefer client.allocator.destroy(conn);
+pub fn connectTcpOptions(client: *Client, options: ConnectTcpOptions) ConnectTcpError!*Connection {
+    const host = options.host;
+    const port = options.port;
+    const protocol = options.protocol;
+
+    const proxied_host = options.proxied_host orelse host;
+    const proxied_port = options.proxied_port orelse port;
+
+    if (client.connection_pool.findConnection(.{
+        .host = proxied_host,
+        .port = proxied_port,
+        .protocol = protocol,
+    })) |conn| return conn;
 
     const stream = net.tcpConnectToHost(client.allocator, host, port) catch |err| switch (err) {
         error.ConnectionRefused => return error.ConnectionRefused,
@@ -1345,53 +1413,19 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec
     };
     errdefer stream.close();
 
-    conn.* = .{
-        .stream = stream,
-        .tls_client = undefined,
-
-        .protocol = protocol,
-        .host = try client.allocator.dupe(u8, host),
-        .port = port,
-
-        .pool_node = .{},
-    };
-    errdefer client.allocator.free(conn.host);
-
-    if (protocol == .tls) {
-        if (disable_tls) unreachable;
-
-        conn.tls_client = try client.allocator.create(std.crypto.tls.Client);
-        errdefer client.allocator.destroy(conn.tls_client);
-
-        const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: {
-            const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) {
-                error.EnvironmentVariableNotFound,
error.InvalidWtf8 => break :ssl_key_log_file null, - error.OutOfMemory => return error.OutOfMemory, - }; - defer client.allocator.free(ssl_key_log_path); - break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ - .truncate = false, - .mode = switch (builtin.os.tag) { - .windows, .wasi => 0, - else => 0o600, - }, - }) catch null; - } else null; - errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); - - conn.tls_client.* = std.crypto.tls.Client.init(stream, .{ - .host = .{ .explicit = host }, - .ca = .{ .bundle = client.ca_bundle }, - .ssl_key_log_file = ssl_key_log_file, - }) catch return error.TlsInitializationFailed; - // This is appropriate for HTTPS because the HTTP headers contain - // the content length which is used to detect truncation attacks. - conn.tls_client.allow_truncation_attacks = true; + switch (protocol) { + .tls => { + if (disable_tls) return error.TlsInitializationFailed; + const tc = try Connection.Tls.create(client, proxied_host, proxied_port, stream); + client.connection_pool.addUsed(&tc.connection); + return &tc.connection; + }, + .plain => { + const pc = try Connection.Plain.create(client, proxied_host, proxied_port, stream); + client.connection_pool.addUsed(&pc.connection); + return &pc.connection; + }, } - - client.connection_pool.addUsed(conn); - - return conn; } pub const ConnectUnixError = Allocator.Error || std.posix.SocketError || error{NameTooLong} || std.posix.ConnectError; @@ -1429,69 +1463,67 @@ pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connecti return &conn.data; } -/// Connect to `tunnel_host:tunnel_port` using the specified proxy with HTTP +/// Connect to `proxied_host:proxied_port` using the specified proxy with HTTP /// CONNECT. This will reuse a connection if one is already open. /// /// This function is threadsafe. 
-pub fn connectTunnel( +pub fn connectProxied( client: *Client, proxy: *Proxy, - tunnel_host: []const u8, - tunnel_port: u16, + proxied_host: []const u8, + proxied_port: u16, ) !*Connection { if (!proxy.supports_connect) return error.TunnelNotSupported; if (client.connection_pool.findConnection(.{ - .host = tunnel_host, - .port = tunnel_port, + .host = proxied_host, + .port = proxied_port, .protocol = proxy.protocol, - })) |node| - return node; + })) |node| return node; var maybe_valid = false; (tunnel: { - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + const connection = try client.connectTcpOptions(.{ + .host = proxy.host, + .port = proxy.port, + .protocol = proxy.protocol, + .proxied_host = proxied_host, + .proxied_port = proxied_port, + }); errdefer { - conn.closing = true; - client.connection_pool.release(client.allocator, conn); + connection.closing = true; + client.connection_pool.release(connection); } - var buffer: [8096]u8 = undefined; - var req = client.open(.CONNECT, .{ + var req = client.request(.CONNECT, .{ .scheme = "http", - .host = .{ .raw = tunnel_host }, - .port = tunnel_port, + .host = .{ .raw = proxied_host }, + .port = proxied_port, }, .{ .redirect_behavior = .unhandled, - .connection = conn, - .server_header_buffer = &buffer, + .connection = connection, }) catch |err| { - std.log.debug("err {}", .{err}); break :tunnel err; }; defer req.deinit(); - req.send() catch |err| break :tunnel err; - req.wait() catch |err| break :tunnel err; + req.sendBodiless() catch |err| break :tunnel err; + const response = req.receiveHead(&.{}) catch |err| break :tunnel err; - if (req.response.status.class() == .server_error) { + if (response.head.status.class() == .server_error) { maybe_valid = true; break :tunnel error.ServerError; } - if (req.response.status != .ok) break :tunnel error.ConnectionRefused; + if (response.head.status != .ok) break :tunnel error.ConnectionRefused; - // this connection is now a tunnel, so we can't use it for anything else, it will only be released when the client is de-initialized. + // this connection is now a tunnel, so we can't use it for anything + // else, it will only be released when the client is de-initialized. req.connection = null; - client.allocator.free(conn.host); - conn.host = try client.allocator.dupe(u8, tunnel_host); - errdefer client.allocator.free(conn.host); + connection.closing = false; - conn.port = tunnel_port; - conn.closing = false; - - return conn; + return connection; }) catch { // something went wrong with the tunnel proxy.supports_connect = maybe_valid; @@ -1499,12 +1531,11 @@ pub fn connectTunnel( }; } -// Prevents a dependency loop in open() -const ConnectErrorPartial = ConnectTcpError || error{ UnsupportedUriScheme, ConnectionRefused }; -pub const ConnectError = ConnectErrorPartial || RequestError; +pub const ConnectError = ConnectTcpError || RequestError; /// Connect to `host:port` using the specified protocol. This will reuse a /// connection if one is already open. +/// /// If a proxy is configured for the client, then the proxy will be used to /// connect to the host. 
/// @@ -1513,7 +1544,7 @@ pub fn connect( client: *Client, host: []const u8, port: u16, - protocol: Connection.Protocol, + protocol: Protocol, ) ConnectError!*Connection { const proxy = switch (protocol) { .plain => client.http_proxy, @@ -1528,32 +1559,24 @@ pub fn connect( } if (proxy.supports_connect) tunnel: { - return connectTunnel(client, proxy, host, port) catch |err| switch (err) { + return connectProxied(client, proxy, host, port) catch |err| switch (err) { error.TunnelNotSupported => break :tunnel, else => |e| return e, }; } // fall back to using the proxy as a normal http proxy - const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); - errdefer { - conn.closing = true; - client.connection_pool.release(conn); - } - - conn.proxied = true; - return conn; + const connection = try client.connectTcp(proxy.host, proxy.port, proxy.protocol); + connection.proxied = true; + return connection; } -pub const RequestError = ConnectTcpError || ConnectErrorPartial || Request.SendError || - std.fmt.ParseIntError || Connection.WriteError || - error{ - UnsupportedUriScheme, - UriMissingHost, - - CertificateBundleLoadFailure, - UnsupportedTransferEncoding, - }; +pub const RequestError = ConnectTcpError || error{ + UnsupportedUriScheme, + UriMissingHost, + UriHostTooLong, + CertificateBundleLoadFailure, +}; pub const RequestOptions = struct { version: http.Version = .@"HTTP/1.1", @@ -1578,11 +1601,6 @@ pub const RequestOptions = struct { /// payload or the server has acknowledged the payload). redirect_behavior: Request.RedirectBehavior = @enumFromInt(3), - /// Externally-owned memory used to store the server's entire HTTP header. - /// `error.HttpHeadersOversize` is returned from read() when a - /// client sends too many bytes of HTTP headers. - server_header_buffer: []u8, - /// Must be an already acquired connection. connection: ?*Connection = null, @@ -1598,38 +1616,17 @@ pub const RequestOptions = struct { privileged_headers: []const http.Header = &.{}, }; -fn validateUri(uri: Uri, arena: Allocator) !struct { Connection.Protocol, Uri } { - const protocol_map = std.StaticStringMap(Connection.Protocol).initComptime(.{ - .{ "http", .plain }, - .{ "ws", .plain }, - .{ "https", .tls }, - .{ "wss", .tls }, - }); - const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUriScheme; - var valid_uri = uri; - // The host is always going to be needed as a raw string for hostname resolution anyway. - valid_uri.host = .{ - .raw = try (uri.host orelse return error.UriMissingHost).toRawMaybeAlloc(arena), - }; - return .{ protocol, valid_uri }; -} - -fn uriPort(uri: Uri, protocol: Connection.Protocol) u16 { - return uri.port orelse switch (protocol) { - .plain => 80, - .tls => 443, - }; +fn uriPort(uri: Uri, protocol: Protocol) u16 { + return uri.port orelse protocol.port(); } /// Open a connection to the host specified by `uri` and prepare to send a HTTP request. /// -/// `uri` must remain alive during the entire request. -/// /// The caller is responsible for calling `deinit()` on the `Request`. /// This function is threadsafe. /// /// Asserts that "\r\n" does not occur in any header name or value. 
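Reviewer note: for orientation, here is roughly how the reworked entry point below is driven end to end. A minimal sketch only; the wrapper function, buffer sizes, and size limit are illustrative assumptions, and the reader calls come from the new `std.Io.Reader` interface:

```zig
const std = @import("std");

/// Sketch: one GET request with the rewritten client API.
fn getBody(client: *std.http.Client, uri: std.Uri, gpa: std.mem.Allocator) ![]u8 {
    var req = try client.request(.GET, uri, .{});
    defer req.deinit();
    try req.sendBodiless();

    // Backs URI merging across redirects; may be empty when
    // redirect_behavior is .unhandled.
    var redirect_buffer: [2048]u8 = undefined;
    var response = try req.receiveHead(&redirect_buffer);

    // Consume the body through the new std.Io.Reader interface.
    return response.reader(&.{}).allocRemaining(gpa, .limited(64 * 1024));
}
```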
-pub fn open( +pub fn request( client: *Client, method: http.Method, uri: Uri, @@ -1649,59 +1646,58 @@ pub fn open( } } - var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer); - const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); + const protocol = Protocol.fromUri(uri) orelse return error.UnsupportedUriScheme; - if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + if (protocol == .tls) { if (disable_tls) unreachable; - - client.ca_bundle_mutex.lock(); - defer client.ca_bundle_mutex.unlock(); - - if (client.next_https_rescan_certs) { - client.ca_bundle.rescan(client.allocator) catch - return error.CertificateBundleLoadFailure; - @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + if (@atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { + client.ca_bundle_mutex.lock(); + defer client.ca_bundle_mutex.unlock(); + + if (client.next_https_rescan_certs) { + client.ca_bundle.rescan(client.allocator) catch + return error.CertificateBundleLoadFailure; + @atomicStore(bool, &client.next_https_rescan_certs, false, .release); + } } } - const conn = options.connection orelse - try client.connect(valid_uri.host.?.raw, uriPort(valid_uri, protocol), protocol); + const connection = options.connection orelse c: { + var host_name_buffer: [Uri.host_name_max]u8 = undefined; + const host_name = try uri.getHost(&host_name_buffer); + break :c try client.connect(host_name, uriPort(uri, protocol), protocol); + }; - var req: Request = .{ - .uri = valid_uri, + return .{ + .uri = uri, .client = client, - .connection = conn, + .connection = connection, + .reader = .{ + .in = connection.reader(), + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }, .keep_alive = options.keep_alive, .method = method, .version = options.version, .transfer_encoding = .none, .redirect_behavior = options.redirect_behavior, .handle_continue = options.handle_continue, - .response = .{ - .version = undefined, - .status = undefined, - .reason = undefined, - .keep_alive = undefined, - .parser = .init(server_header.buffer[server_header.end_index..]), - }, .headers = options.headers, .extra_headers = options.extra_headers, .privileged_headers = options.privileged_headers, }; - errdefer req.deinit(); - - return req; } pub const FetchOptions = struct { - server_header_buffer: ?[]u8 = null, + /// `null` means it will be heap-allocated. + redirect_buffer: ?[]u8 = null, + /// `null` means it will be heap-allocated. + decompress_buffer: ?[]u8 = null, redirect_behavior: ?Request.RedirectBehavior = null, - - /// If the server sends a body, it will be appended to this ArrayList. - /// `max_append_size` provides an upper limit for how much they can grow. - response_storage: ResponseStorage = .ignore, - max_append_size: ?usize = null, + /// If the server sends a body, it will be stored here. + response_storage: ?ResponseStorage = null, location: Location, method: ?http.Method = null, @@ -1725,11 +1721,11 @@ pub const FetchOptions = struct { uri: Uri, }; - pub const ResponseStorage = union(enum) { - ignore, - /// Only the existing capacity will be used. - static: *std.ArrayListUnmanaged(u8), - dynamic: *std.ArrayList(u8), + pub const ResponseStorage = struct { + list: *std.ArrayListUnmanaged(u8), + /// If null then only the existing capacity will be used. 
+ allocator: ?Allocator = null, + append_limit: std.io.Limit = .unlimited, }; }; @@ -1737,23 +1733,29 @@ pub const FetchResult = struct { status: http.Status, }; +pub const FetchError = Uri.ParseError || RequestError || Request.ReceiveHeadError || error{ + StreamTooLong, + /// TODO provide optional diagnostics when this occurs or break into more error codes + WriteFailed, + UnsupportedCompressionMethod, +}; + /// Perform a one-shot HTTP request with the provided options. /// /// This function is threadsafe. -pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { +pub fn fetch(client: *Client, options: FetchOptions) FetchError!FetchResult { const uri = switch (options.location) { .url => |u| try Uri.parse(u), .uri => |u| u, }; - var server_header_buffer: [16 * 1024]u8 = undefined; - const method: http.Method = options.method orelse if (options.payload != null) .POST else .GET; - var req = try open(client, method, uri, .{ - .server_header_buffer = options.server_header_buffer orelse &server_header_buffer, - .redirect_behavior = options.redirect_behavior orelse - if (options.payload == null) @enumFromInt(3) else .unhandled, + const redirect_behavior: Request.RedirectBehavior = options.redirect_behavior orelse + if (options.payload == null) @enumFromInt(3) else .unhandled; + + var req = try request(client, method, uri, .{ + .redirect_behavior = redirect_behavior, .headers = options.headers, .extra_headers = options.extra_headers, .privileged_headers = options.privileged_headers, @@ -1761,44 +1763,70 @@ pub fn fetch(client: *Client, options: FetchOptions) !FetchResult { }); defer req.deinit(); - if (options.payload) |payload| req.transfer_encoding = .{ .content_length = payload.len }; + if (options.payload) |payload| { + req.transfer_encoding = .{ .content_length = payload.len }; + var body = try req.sendBody(&.{}); + try body.writer.writeAll(payload); + try body.end(); + } else { + try req.sendBodiless(); + } - try req.send(); + const redirect_buffer: []u8 = if (redirect_behavior == .unhandled) &.{} else options.redirect_buffer orelse + try client.allocator.alloc(u8, 8 * 1024); + defer if (options.redirect_buffer == null) client.allocator.free(redirect_buffer); - if (options.payload) |payload| try req.writeAll(payload); + var response = try req.receiveHead(redirect_buffer); - try req.finish(); - try req.wait(); + const storage = options.response_storage orelse { + const reader = response.reader(&.{}); + _ = reader.discardRemaining() catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + }; + return .{ .status = response.head.status }; + }; - switch (options.response_storage) { - .ignore => { - // Take advantage of request internals to discard the response body - // and make the connection available for another request. - req.response.skip = true; - assert(try req.transferRead(&.{}) == 0); // No buffer is necessary when skipping. 
- }, - .dynamic => |list| { - const max_append_size = options.max_append_size orelse 2 * 1024 * 1024; - try req.reader().readAllArrayList(list, max_append_size); - }, - .static => |list| { - const buf = b: { - const buf = list.unusedCapacitySlice(); - if (options.max_append_size) |len| { - if (len < buf.len) break :b buf[0..len]; - } - break :b buf; - }; - list.items.len += try req.reader().readAll(buf); - }, + const decompress_buffer: []u8 = switch (response.head.content_encoding) { + .identity => &.{}, + .zstd => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.zstd.default_window_len), + .deflate, .gzip => options.decompress_buffer orelse try client.allocator.alloc(u8, std.compress.flate.max_window_len), + .compress => return error.UnsupportedCompressionMethod, + }; + defer if (options.decompress_buffer == null) client.allocator.free(decompress_buffer); + + var decompressor: http.Decompressor = undefined; + const reader = response.readerDecompressing(&decompressor, decompress_buffer); + const list = storage.list; + + if (storage.allocator) |allocator| { + reader.appendRemaining(allocator, null, list, storage.append_limit) catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + else => |e| return e, + }; + } else { + const buf = storage.append_limit.slice(list.unusedCapacitySlice()); + list.items.len += reader.readSliceShort(buf) catch |err| switch (err) { + error.ReadFailed => return response.bodyErr().?, + }; } - return .{ - .status = req.response.status, - }; + return .{ .status = response.head.status }; +} + +pub fn sameParentDomain(parent_host: []const u8, child_host: []const u8) bool { + if (!std.ascii.endsWithIgnoreCase(child_host, parent_host)) return false; + if (child_host.len == parent_host.len) return true; + if (parent_host.len > child_host.len) return false; + return child_host[child_host.len - parent_host.len - 1] == '.'; +} + +test sameParentDomain { + try testing.expect(!sameParentDomain("foo.com", "bar.com")); + try testing.expect(sameParentDomain("foo.com", "foo.com")); + try testing.expect(sameParentDomain("foo.com", "bar.foo.com")); + try testing.expect(!sameParentDomain("bar.foo.com", "foo.com")); } test { _ = Response; - _ = &initDefaultProxies; } diff --git a/lib/std/http/Server.zig b/lib/std/http/Server.zig index 8997828f6252..9574a4dc6a90 100644 --- a/lib/std/http/Server.zig +++ b/lib/std/http/Server.zig @@ -1,139 +1,69 @@ -//! Blocking HTTP server implementation. -//! Handles a single connection's lifecycle. - -connection: net.Server.Connection, -/// Keeps track of whether the Server is ready to accept a new request on the -/// same connection, and makes invalid API usage cause assertion failures -/// rather than HTTP protocol violations. -state: State, -/// User-provided buffer that must outlive this Server. -/// Used to store the client's entire HTTP header. -read_buffer: []u8, -/// Amount of available data inside read_buffer. -read_buffer_len: usize, -/// Index into `read_buffer` of the first byte of the next HTTP request. -next_request_start: usize, - -pub const State = enum { - /// The connection is available to be used for the first time, or reused. - ready, - /// An error occurred in `receiveHead`. - receiving_head, - /// A Request object has been obtained and from there a Response can be - /// opened. - received_head, - /// The client is uploading something to this Server. 
- receiving_body, - /// The connection is eligible for another HTTP request, however the client - /// and server did not negotiate a persistent connection. - closing, -}; +//! Handles a single connection lifecycle. + +const std = @import("../std.zig"); +const http = std.http; +const mem = std.mem; +const Uri = std.Uri; +const assert = std.debug.assert; +const testing = std.testing; +const Writer = std.Io.Writer; +const Reader = std.Io.Reader; + +const Server = @This(); + +/// Data from the HTTP server to the HTTP client. +out: *Writer, +reader: http.Reader, /// Initialize an HTTP server that can respond to multiple requests on the same /// connection. +/// +/// The buffer of `in` must be large enough to store the client's entire HTTP +/// header, otherwise `receiveHead` returns `error.HttpHeadersOversize`. +/// /// The returned `Server` is ready for `receiveHead` to be called. -pub fn init(connection: net.Server.Connection, read_buffer: []u8) Server { +pub fn init(in: *Reader, out: *Writer) Server { return .{ - .connection = connection, - .state = .ready, - .read_buffer = read_buffer, - .read_buffer_len = 0, - .next_request_start = 0, + .reader = .{ + .in = in, + .state = .ready, + // Populated when `http.Reader.bodyReader` is called. + .interface = undefined, + }, + .out = out, }; } -pub const ReceiveHeadError = error{ - /// Client sent too many bytes of HTTP headers. - /// The HTTP specification suggests to respond with a 431 status code - /// before closing the connection. - HttpHeadersOversize, +pub const ReceiveHeadError = http.Reader.HeadError || error{ /// Client sent headers that did not conform to the HTTP protocol. + /// + /// To find out more detailed diagnostics, `Request.head_buffer` can be + /// passed directly to `Request.Head.parse`. HttpHeadersInvalid, - /// A low level I/O error occurred trying to read the headers. - HttpHeadersUnreadable, - /// Partial HTTP request was received but the connection was closed before - /// fully receiving the headers. - HttpRequestTruncated, - /// The client sent 0 bytes of headers before closing the stream. - /// In other words, a keep-alive connection was finally closed. - HttpConnectionClosing, }; -/// The header bytes reference the read buffer that Server was initialized with -/// and remain alive until the next call to receiveHead. pub fn receiveHead(s: *Server) ReceiveHeadError!Request { - assert(s.state == .ready); - s.state = .received_head; - errdefer s.state = .receiving_head; - - // In case of a reused connection, move the next request's bytes to the - // beginning of the buffer. 
- if (s.next_request_start > 0) { - if (s.read_buffer_len > s.next_request_start) { - rebase(s, 0); - } else { - s.read_buffer_len = 0; - } - } - - var hp: http.HeadParser = .{}; - - if (s.read_buffer_len > 0) { - const bytes = s.read_buffer[0..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, end); - } - - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = s.connection.stream.read(buf) catch - return error.HttpHeadersUnreadable; - if (read_n == 0) { - if (s.read_buffer_len > 0) { - return error.HttpRequestTruncated; - } else { - return error.HttpConnectionClosing; - } - } - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) - return finishReceivingHead(s, s.read_buffer_len - bytes.len + end); - } -} - -fn finishReceivingHead(s: *Server, head_end: usize) ReceiveHeadError!Request { + const head_buffer = try s.reader.receiveHead(); return .{ .server = s, - .head_end = head_end, - .head = Request.Head.parse(s.read_buffer[0..head_end]) catch - return error.HttpHeadersInvalid, - .reader_state = undefined, + .head_buffer = head_buffer, + // No need to track the returned error here since users can repeat the + // parse with the header buffer to get detailed diagnostics. + .head = Request.Head.parse(head_buffer) catch return error.HttpHeadersInvalid, }; } pub const Request = struct { server: *Server, - /// Index into Server's read_buffer. - head_end: usize, + /// Pointers in this struct are invalidated when the request body stream is + /// initialized. head: Head, - reader_state: union { - remaining_content_length: u64, - chunk_parser: http.ChunkParser, - }, - - pub const Compression = union(enum) { - pub const DeflateDecompressor = std.compress.zlib.Decompressor(std.io.AnyReader); - pub const GzipDecompressor = std.compress.gzip.Decompressor(std.io.AnyReader); - - deflate: std.compress.flate.Decompress, - gzip: std.compress.flate.Decompress, - zstd: std.compress.zstd.Decompress, - none: void, + head_buffer: []const u8, + respond_err: ?RespondError = null, + + pub const RespondError = error{ + /// The request contained an `expect` header with an unrecognized value. 
+ HttpExpectationFailed, }; pub const Head = struct { @@ -146,7 +76,6 @@ pub const Request = struct { transfer_encoding: http.TransferEncoding, transfer_compression: http.ContentEncoding, keep_alive: bool, - compression: Compression, pub const ParseError = error{ UnknownHttpMethod, @@ -168,10 +97,9 @@ pub const Request = struct { const method_end = mem.indexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; - if (method_end > 24) return error.HttpHeadersInvalid; - const method_str = first_line[0..method_end]; - const method: http.Method = @enumFromInt(http.Method.parse(method_str)); + const method = std.meta.stringToEnum(http.Method, first_line[0..method_end]) orelse + return error.UnknownHttpMethod; const version_start = mem.lastIndexOfScalar(u8, first_line, ' ') orelse return error.HttpHeadersInvalid; @@ -200,7 +128,6 @@ pub const Request = struct { .@"HTTP/1.0" => false, .@"HTTP/1.1" => true, }, - .compression = .none, }; while (it.next()) |line| { @@ -230,7 +157,7 @@ pub const Request = struct { const trimmed = mem.trim(u8, header_value, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| { + if (http.ContentEncoding.fromString(trimmed)) |ce| { head.transfer_compression = ce; } else { return error.HttpTransferEncodingUnsupported; @@ -255,7 +182,7 @@ pub const Request = struct { if (next) |second| { const trimmed_second = mem.trim(u8, second, " "); - if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| { + if (http.ContentEncoding.fromString(trimmed_second)) |transfer| { if (head.transfer_compression != .identity) return error.HttpHeadersInvalid; // double compression is not supported head.transfer_compression = transfer; @@ -296,10 +223,19 @@ pub const Request = struct { inline fn int64(array: *const [8]u8) u64 { return @bitCast(array.*); } + + /// Help the programmer avoid bugs by calling this when the string + /// memory of `Head` becomes invalidated. + fn invalidateStrings(h: *Head) void { + h.target = undefined; + if (h.expect) |*s| s.* = undefined; + if (h.content_type) |*s| s.* = undefined; + } }; - pub fn iterateHeaders(r: *Request) http.HeaderIterator { - return http.HeaderIterator.init(r.server.read_buffer[0..r.head_end]); + pub fn iterateHeaders(r: *const Request) http.HeaderIterator { + assert(r.server.reader.state == .received_head); + return http.HeaderIterator.init(r.head_buffer); } test iterateHeaders { @@ -310,22 +246,19 @@ pub const Request = struct { "TRansfer-encoding:\tdeflate, chunked \r\n" ++ "connectioN:\t keep-alive \r\n\r\n"; - var read_buffer: [500]u8 = undefined; - @memcpy(read_buffer[0..request_bytes.len], request_bytes); - var server: Server = .{ - .connection = undefined, - .state = .ready, - .read_buffer = &read_buffer, - .read_buffer_len = request_bytes.len, - .next_request_start = 0, + .reader = .{ + .in = undefined, + .state = .received_head, + .interface = undefined, + }, + .out = undefined, }; var request: Request = .{ .server = &server, - .head_end = request_bytes.len, .head = undefined, - .reader_state = undefined, + .head_buffer = @constCast(request_bytes), }; var it = request.iterateHeaders(); @@ -384,16 +317,22 @@ pub const Request = struct { /// no error is surfaced. /// /// Asserts status is not `continue`. - /// Asserts there are at most 25 extra_headers. /// Asserts that "\r\n" does not occur in any header name or value. 
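Reviewer note: typical handler usage keeps the same shape as before; only the plumbing underneath changed to the buffered `*Writer`. A sketch (handler wrapper and header value are illustrative):

```zig
const std = @import("std");

/// Sketch: respond to a request with a fixed body and one extra header.
fn handle(request: *std.http.Server.Request) !void {
    try request.respond("Hello, World!\n", .{
        .extra_headers = &.{
            .{ .name = "content-type", .value = "text/plain" },
        },
    });
}
```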
pub fn respond( request: *Request, content: []const u8, options: RespondOptions, - ) Response.WriteError!void { - const max_extra_headers = 25; + ) ExpectContinueError!void { + try respondUnflushed(request, content, options); + try request.server.out.flush(); + } + + pub fn respondUnflushed( + request: *Request, + content: []const u8, + options: RespondOptions, + ) ExpectContinueError!void { assert(options.status != .@"continue"); - assert(options.extra_headers.len <= max_extra_headers); if (std.debug.runtime_safety) { for (options.extra_headers) |header| { assert(header.name.len != 0); @@ -402,6 +341,7 @@ pub const Request = struct { assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null); } } + try writeExpectContinue(request); const transfer_encoding_none = (options.transfer_encoding orelse .chunked) == .none; const server_keep_alive = !transfer_encoding_none and options.keep_alive; @@ -409,130 +349,42 @@ pub const Request = struct { const phrase = options.reason orelse options.status.phrase() orelse ""; - var first_buffer: [500]u8 = undefined; - var h = std.ArrayListUnmanaged(u8).initBuffer(&first_buffer); - if (request.head.expect != null) { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. - h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - try request.server.connection.stream.writeAll(h.items); - return; - } - h.fixedWriter().print("{s} {d} {s}\r\n", .{ + const out = request.server.out; + try out.print("{s} {d} {s}\r\n", .{ @tagName(options.version), @intFromEnum(options.status), phrase, - }) catch unreachable; + }); switch (options.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), + .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"), } if (options.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { .none => {}, - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), + .chunked => try out.writeAll("transfer-encoding: chunked\r\n"), } else { - h.fixedWriter().print("content-length: {d}\r\n", .{content.len}) catch unreachable; + try out.print("content-length: {d}\r\n", .{content.len}); } - var chunk_header_buffer: [18]u8 = undefined; - var iovecs: [max_extra_headers * 4 + 3]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = h.items.ptr, - .len = h.items.len, - }; - iovecs_len += 1; - for (options.extra_headers) |header| { - iovecs[iovecs_len] = .{ - .base = header.name.ptr, - .len = header.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (header.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = header.value.ptr, - .len = header.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; + var vecs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&vecs); } - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; + try out.writeAll("\r\n"); if (request.head.method != .HEAD) { const is_chunked = 
(options.transfer_encoding orelse .none) == .chunked; if (is_chunked) { - if (content.len > 0) { - const chunk_header = std.fmt.bufPrint( - &chunk_header_buffer, - "{x}\r\n", - .{content.len}, - ) catch unreachable; - - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "0\r\n\r\n", - .len = 5, - }; - iovecs_len += 1; + if (content.len > 0) try out.print("{x}\r\n{s}\r\n", .{ content.len, content }); + try out.writeAll("0\r\n\r\n"); } else if (content.len > 0) { - iovecs[iovecs_len] = .{ - .base = content.ptr, - .len = content.len, - }; - iovecs_len += 1; + try out.writeAll(content); } } - - try request.server.connection.stream.writevAll(iovecs[0..iovecs_len]); } pub const RespondStreamingOptions = struct { - /// An externally managed slice of memory used to batch bytes before - /// sending. `respondStreaming` asserts this is large enough to store - /// the full HTTP response head. - /// - /// Must outlive the returned Response. - send_buffer: []u8, /// If provided, the response will use the content-length header; /// otherwise it will use transfer-encoding: chunked. content_length: ?u64 = null, @@ -540,254 +392,227 @@ pub const Request = struct { respond_options: RespondOptions = .{}, }; - /// The header is buffered but not sent until Response.flush is called. + /// The header is not guaranteed to be sent until `BodyWriter.flush` or + /// `BodyWriter.end` is called. /// /// If the request contains a body and the connection is to be reused, /// discards the request body, leaving the Server in the `ready` state. If /// this discarding fails, the connection is marked as not to be reused and /// no error is surfaced. /// - /// HEAD requests are handled transparently by setting a flag on the - /// returned Response to omit the body. However it may be worth noticing + /// HEAD requests are handled transparently by setting the + /// `BodyWriter.elide` flag on the returned `BodyWriter`, causing + /// the response stream to omit the body. However, it may be worth noticing /// that flag and skipping any expensive work that would otherwise need to /// be done to satisfy the request. /// - /// Asserts `send_buffer` is large enough to store the entire response header. /// Asserts status is not `continue`. - pub fn respondStreaming(request: *Request, options: RespondStreamingOptions) Response { + pub fn respondStreaming( + request: *Request, + buffer: []u8, + options: RespondStreamingOptions, + ) ExpectContinueError!http.BodyWriter { + try writeExpectContinue(request); const o = options.respond_options; assert(o.status != .@"continue"); const transfer_encoding_none = (o.transfer_encoding orelse .chunked) == .none; const server_keep_alive = !transfer_encoding_none and o.keep_alive; const keep_alive = request.discardBody(server_keep_alive); const phrase = o.reason orelse o.status.phrase() orelse ""; + const out = request.server.out; - var h = std.ArrayListUnmanaged(u8).initBuffer(options.send_buffer); - - const elide_body = if (request.head.expect != null) eb: { - // reader() and hence discardBody() above sets expect to null if it - // is handled. So the fact that it is not null here means unhandled. 
- h.appendSliceAssumeCapacity("HTTP/1.1 417 Expectation Failed\r\n"); - if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"); - h.appendSliceAssumeCapacity("content-length: 0\r\n\r\n"); - break :eb true; - } else eb: { - h.fixedWriter().print("{s} {d} {s}\r\n", .{ - @tagName(o.version), @intFromEnum(o.status), phrase, - }) catch unreachable; - - switch (o.version) { - .@"HTTP/1.0" => if (keep_alive) h.appendSliceAssumeCapacity("connection: keep-alive\r\n"), - .@"HTTP/1.1" => if (!keep_alive) h.appendSliceAssumeCapacity("connection: close\r\n"), - } + try out.print("{s} {d} {s}\r\n", .{ + @tagName(o.version), @intFromEnum(o.status), phrase, + }); - if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { - .chunked => h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"), - .none => {}, - } else if (options.content_length) |len| { - h.fixedWriter().print("content-length: {d}\r\n", .{len}) catch unreachable; - } else { - h.appendSliceAssumeCapacity("transfer-encoding: chunked\r\n"); - } + switch (o.version) { + .@"HTTP/1.0" => if (keep_alive) try out.writeAll("connection: keep-alive\r\n"), + .@"HTTP/1.1" => if (!keep_alive) try out.writeAll("connection: close\r\n"), + } - for (o.extra_headers) |header| { - assert(header.name.len != 0); - h.appendSliceAssumeCapacity(header.name); - h.appendSliceAssumeCapacity(": "); - h.appendSliceAssumeCapacity(header.value); - h.appendSliceAssumeCapacity("\r\n"); - } + if (o.transfer_encoding) |transfer_encoding| switch (transfer_encoding) { + .chunked => try out.writeAll("transfer-encoding: chunked\r\n"), + .none => {}, + } else if (options.content_length) |len| { + try out.print("content-length: {d}\r\n", .{len}); + } else { + try out.writeAll("transfer-encoding: chunked\r\n"); + } - h.appendSliceAssumeCapacity("\r\n"); - break :eb request.head.method == .HEAD; - }; + for (o.extra_headers) |header| { + assert(header.name.len != 0); + var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&bufs); + } - return .{ - .stream = request.server.connection.stream, - .send_buffer = options.send_buffer, - .send_buffer_start = 0, - .send_buffer_end = h.items.len, - .transfer_encoding = if (o.transfer_encoding) |te| switch (te) { - .chunked => .chunked, - .none => .none, - } else if (options.content_length) |len| .{ - .content_length = len, - } else .chunked, - .elide_body = elide_body, - .chunk_len = 0, + try out.writeAll("\r\n"); + const elide_body = request.head.method == .HEAD; + const state: http.BodyWriter.State = if (o.transfer_encoding) |te| switch (te) { + .chunked => .{ .chunked = .init }, + .none => .none, + } else if (options.content_length) |len| .{ + .content_length = len, + } else .{ .chunked = .init }; + + return if (elide_body) .{ + .http_protocol_output = request.server.out, + .state = state, + .writer = .{ + .buffer = buffer, + .vtable = &.{ + .drain = http.BodyWriter.elidingDrain, + .sendFile = http.BodyWriter.elidingSendFile, + }, + }, + } else .{ + .http_protocol_output = request.server.out, + .state = state, + .writer = .{ + .buffer = buffer, + .vtable = switch (state) { + .none => &.{ + .drain = http.BodyWriter.noneDrain, + .sendFile = http.BodyWriter.noneSendFile, + }, + .content_length => &.{ + .drain = http.BodyWriter.contentLengthDrain, + .sendFile = http.BodyWriter.contentLengthSendFile, + }, + .chunked => &.{ + .drain = http.BodyWriter.chunkedDrain, + .sendFile = http.BodyWriter.chunkedSendFile, + }, + .end => unreachable, + }, + }, }; } - pub const 
ReadError = net.Stream.ReadError || error{ - HttpChunkInvalid, - HttpHeadersOversize, + pub const UpgradeRequest = union(enum) { + websocket: ?[]const u8, + other: []const u8, + none, }; - fn read_cl(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @ptrCast(@alignCast(@constCast(context))); - const s = request.server; - - const remaining_content_length = &request.reader_state.remaining_content_length; - if (remaining_content_length.* == 0) { - s.state = .ready; - return 0; + /// Does not invalidate `request.head`. + pub fn upgradeRequested(request: *const Request) UpgradeRequest { + switch (request.head.version) { + .@"HTTP/1.0" => return .none, + .@"HTTP/1.1" => if (request.head.method != .GET) return .none, } - assert(s.state == .receiving_body); - const available = try fill(s, request.head_end); - const len = @min(remaining_content_length.*, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - remaining_content_length.* -= len; - s.next_request_start += len; - if (remaining_content_length.* == 0) - s.state = .ready; - return len; - } - - fn fill(s: *Server, head_end: usize) ReadError![]u8 { - const available = s.read_buffer[s.next_request_start..s.read_buffer_len]; - if (available.len > 0) return available; - s.next_request_start = head_end; - s.read_buffer_len = head_end + try s.connection.stream.read(s.read_buffer[head_end..]); - return s.read_buffer[head_end..s.read_buffer_len]; - } - fn read_chunked(context: *const anyopaque, buffer: []u8) ReadError!usize { - const request: *Request = @ptrCast(@alignCast(@constCast(context))); - const s = request.server; - - const cp = &request.reader_state.chunk_parser; - const head_end = request.head_end; - - // Protect against returning 0 before the end of stream. - var out_end: usize = 0; - while (out_end == 0) { - switch (cp.state) { - .invalid => return 0, - .data => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const len = @min(cp.chunk_len, available.len, buffer.len); - @memcpy(buffer[0..len], available[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += len; - continue; - }, - else => { - assert(s.state == .receiving_body); - const available = try fill(s, head_end); - const n = cp.feed(available); - switch (cp.state) { - .invalid => return error.HttpChunkInvalid, - .data => { - if (cp.chunk_len == 0) { - // The next bytes in the stream are trailers, - // or \r\n to indicate end of chunked body. - // - // This function must append the trailers at - // head_end so that headers and trailers are - // together. - // - // Since returning 0 would indicate end of - // stream, this function must read all the - // trailers before returning. 
- if (s.next_request_start > head_end) rebase(s, head_end); - var hp: http.HeadParser = .{}; - { - const bytes = s.read_buffer[head_end..s.read_buffer_len]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - while (true) { - const buf = s.read_buffer[s.read_buffer_len..]; - if (buf.len == 0) - return error.HttpHeadersOversize; - const read_n = try s.connection.stream.read(buf); - s.read_buffer_len += read_n; - const bytes = buf[0..read_n]; - const end = hp.feed(bytes); - if (hp.state == .finished) { - cp.state = .invalid; - s.state = .ready; - s.next_request_start = s.read_buffer_len - bytes.len + end; - return out_end; - } - } - } - const data = available[n..]; - const len = @min(cp.chunk_len, data.len, buffer.len); - @memcpy(buffer[0..len], data[0..len]); - cp.chunk_len -= len; - if (cp.chunk_len == 0) - cp.state = .data_suffix; - out_end += len; - s.next_request_start += n + len; - continue; - }, - else => continue, - } - }, + var sec_websocket_key: ?[]const u8 = null; + var upgrade_name: ?[]const u8 = null; + var it = request.iterateHeaders(); + while (it.next()) |header| { + if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) { + sec_websocket_key = header.value; + } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) { + upgrade_name = header.value; } } - return out_end; + + const name = upgrade_name orelse return .none; + if (std.ascii.eqlIgnoreCase(name, "websocket")) return .{ .websocket = sec_websocket_key }; + return .{ .other = name }; } - pub const ReaderError = Response.WriteError || error{ - /// The client sent an expect HTTP header value other than - /// "100-continue". - HttpExpectationFailed, + pub const WebSocketOptions = struct { + /// The value from `UpgradeRequest.websocket` (sec-websocket-key header value). + key: []const u8, + reason: ?[]const u8 = null, + extra_headers: []const http.Header = &.{}, }; + /// The header is not guaranteed to be sent until `WebSocket.flush` is + /// called on the returned struct. 
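Reviewer note: a sketch of the intended upgrade path, combining `upgradeRequested` with the `respondWebSocket` function defined below. The error chosen for a missing key and the rejection response are assumptions, not part of the patch:

```zig
const std = @import("std");

/// Sketch: upgrade to a WebSocket when requested, otherwise reject.
fn maybeUpgrade(request: *std.http.Server.Request) !?std.http.Server.WebSocket {
    switch (request.upgradeRequested()) {
        .websocket => |maybe_key| {
            // A websocket upgrade without a sec-websocket-key is malformed.
            const key = maybe_key orelse return error.HttpHeadersInvalid;
            return try request.respondWebSocket(.{ .key = key });
        },
        .other, .none => {
            try request.respond("expected websocket upgrade\n", .{ .status = .bad_request });
            return null;
        },
    }
}
```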
+ pub fn respondWebSocket(request: *Request, options: WebSocketOptions) ExpectContinueError!WebSocket { + if (request.head.expect != null) return error.HttpExpectationFailed; + + const out = request.server.out; + const version: http.Version = .@"HTTP/1.1"; + const status: http.Status = .switching_protocols; + const phrase = options.reason orelse status.phrase() orelse ""; + + assert(request.head.version == version); + assert(request.head.method == .GET); + + var sha1 = std.crypto.hash.Sha1.init(.{}); + sha1.update(options.key); + sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined; + sha1.final(&digest); + try out.print("{s} {d} {s}\r\n", .{ @tagName(version), @intFromEnum(status), phrase }); + try out.writeAll("connection: upgrade\r\nupgrade: websocket\r\nsec-websocket-accept: "); + const base64_digest = try out.writableArray(28); + assert(std.base64.standard.Encoder.encode(base64_digest, &digest).len == base64_digest.len); + out.advance(base64_digest.len); + try out.writeAll("\r\n"); + + for (options.extra_headers) |header| { + assert(header.name.len != 0); + var bufs: [4][]const u8 = .{ header.name, ": ", header.value, "\r\n" }; + try out.writeVecAll(&bufs); + } + + try out.writeAll("\r\n"); + + return .{ + .input = request.server.reader.in, + .output = request.server.out, + .key = options.key, + }; + } + /// In the case that the request contains "expect: 100-continue", this /// function writes the continuation header, which means it can fail with a /// write error. After sending the continuation header, it sets the /// request's expect field to `null`. /// /// Asserts that this function is only called once. - pub fn reader(request: *Request) ReaderError!std.io.AnyReader { - const s = request.server; - assert(s.state == .received_head); - s.state = .receiving_body; - s.next_request_start = request.head_end; - - if (request.head.expect) |expect| { - if (mem.eql(u8, expect, "100-continue")) { - try request.server.connection.stream.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); - request.head.expect = null; - } else { - return error.HttpExpectationFailed; - } - } + /// + /// See `readerExpectNone` for an infallible alternative that cannot write + /// to the server output stream. + pub fn readerExpectContinue(request: *Request, buffer: []u8) ExpectContinueError!*Reader { + const flush = request.head.expect != null; + try writeExpectContinue(request); + if (flush) try request.server.out.flush(); + return readerExpectNone(request, buffer); + } - switch (request.head.transfer_encoding) { - .chunked => { - request.reader_state = .{ .chunk_parser = http.ChunkParser.init }; - return .{ - .readFn = read_chunked, - .context = request, - }; - }, - .none => { - request.reader_state = .{ - .remaining_content_length = request.head.content_length orelse 0, - }; - return .{ - .readFn = read_cl, - .context = request, - }; - }, - } + /// Asserts the expect header is `null`. The caller must handle the + /// expectation manually and then set the value to `null` prior to calling + /// this function. + /// + /// Asserts that this function is only called once. + /// + /// Invalidates the string memory inside `Head`. 
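Reviewer note: typical body consumption goes through the fallible sibling `readerExpectContinue` defined above. A sketch, with the buffer size and the append limit as illustrative assumptions:

```zig
const std = @import("std");

/// Sketch: buffer an entire request body through the new *Reader interface.
fn readBody(request: *std.http.Server.Request, gpa: std.mem.Allocator) ![]u8 {
    var buffer: [4096]u8 = undefined;
    // Writes "100 Continue" first if the client asked for it.
    const body_reader = try request.readerExpectContinue(&buffer);
    return body_reader.allocRemaining(gpa, .limited(1024 * 1024));
}
```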
+ pub fn readerExpectNone(request: *Request, buffer: []u8) *Reader { + assert(request.server.reader.state == .received_head); + assert(request.head.expect == null); + request.head.invalidateStrings(); + if (!request.head.method.requestHasBody()) return .ending; + return request.server.reader.bodyReader(buffer, request.head.transfer_encoding, request.head.content_length); + } + + pub const ExpectContinueError = error{ + /// Failed to write "HTTP/1.1 100 Continue\r\n\r\n" to the stream. + WriteFailed, + /// The client sent an expect HTTP header value other than + /// "100-continue". + HttpExpectationFailed, + }; + + pub fn writeExpectContinue(request: *Request) ExpectContinueError!void { + const expect = request.head.expect orelse return; + if (!mem.eql(u8, expect, "100-continue")) return error.HttpExpectationFailed; + try request.server.out.writeAll("HTTP/1.1 100 Continue\r\n\r\n"); + request.head.expect = null; } /// Returns whether the connection should remain persistent. - /// If it would fail, it instead sets the Server state to `receiving_body` + /// + /// If it would fail, it instead sets the Server state to receiving body /// and returns false. fn discardBody(request: *Request, keep_alive: bool) bool { // Prepare to receive another request on the same connection. @@ -798,350 +623,180 @@ pub const Request = struct { // or the request body. // If the connection won't be kept alive, then none of this matters // because the connection will be severed after the response is sent. - const s = request.server; - if (keep_alive and request.head.keep_alive) switch (s.state) { + const r = &request.server.reader; + if (keep_alive and request.head.keep_alive) switch (r.state) { .received_head => { - const r = request.reader() catch return false; - _ = r.discard() catch return false; - assert(s.state == .ready); + if (request.head.method.requestHasBody()) { + assert(request.head.transfer_encoding != .none or request.head.content_length != null); + const reader_interface = request.readerExpectContinue(&.{}) catch return false; + _ = reader_interface.discardRemaining() catch return false; + assert(r.state == .ready); + } else { + r.state = .ready; + } return true; }, - .receiving_body, .ready => return true, + .body_remaining_content_length, .body_remaining_chunk_len, .body_none, .ready => return true, else => unreachable, }; // Avoid clobbering the state in case a reading stream already exists. - switch (s.state) { - .received_head => s.state = .closing, + switch (r.state) { + .received_head => r.state = .closing, else => {}, } return false; } }; -pub const Response = struct { - stream: net.Stream, - send_buffer: []u8, - /// Index of the first byte in `send_buffer`. - /// This is 0 unless a short write happens in `write`. - send_buffer_start: usize, - /// Index of the last byte + 1 in `send_buffer`. - send_buffer_end: usize, - /// `null` means transfer-encoding: chunked. - /// As a debugging utility, counts down to zero as bytes are written. - transfer_encoding: TransferEncoding, - elide_body: bool, - /// Indicates how much of the end of the `send_buffer` corresponds to a - /// chunk. This amount of data will be wrapped by an HTTP chunk header. - chunk_len: usize, - - pub const TransferEncoding = union(enum) { - /// End of connection signals the end of the stream. - none, - /// As a debugging utility, counts down to zero as bytes are written. - content_length: u64, - /// Each chunk is wrapped in a header and trailer. 
- chunked, +/// See https://tools.ietf.org/html/rfc6455 +pub const WebSocket = struct { + key: []const u8, + input: *Reader, + output: *Writer, + + pub const Header0 = packed struct(u8) { + opcode: Opcode, + rsv3: u1 = 0, + rsv2: u1 = 0, + rsv1: u1 = 0, + fin: bool, }; - pub const WriteError = net.Stream.WriteError; - - /// When using content-length, asserts that the amount of data sent matches - /// the value sent in the header, then calls `flush`. - /// Otherwise, transfer-encoding: chunked is being used, and it writes the - /// end-of-stream message, then flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn end(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .content_length => |len| { - assert(len == 0); // Trips when end() called before all bytes written. - try flush_cl(r); - }, - .none => { - try flush_cl(r); - }, - .chunked => { - try flush_chunked(r, &.{}); - }, - } - r.* = undefined; - } - - pub const EndChunkedOptions = struct { - trailers: []const http.Header = &.{}, + pub const Header1 = packed struct(u8) { + payload_len: enum(u7) { + len16 = 126, + len64 = 127, + _, + }, + mask: bool, }; - /// Asserts that the Response is using transfer-encoding: chunked. - /// Writes the end-of-stream message and any optional trailers, then - /// flushes the stream to the system. - /// Respects the value of `elide_body` to omit all data after the headers. - /// Asserts there are at most 25 trailers. - pub fn endChunked(r: *Response, options: EndChunkedOptions) WriteError!void { - assert(r.transfer_encoding == .chunked); - try flush_chunked(r, options.trailers); - r.* = undefined; - } - - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - /// May return 0, which does not indicate end of stream. The caller decides - /// when the end of stream occurs by calling `end`. - pub fn write(r: *Response, bytes: []const u8) WriteError!usize { - switch (r.transfer_encoding) { - .content_length, .none => return write_cl(r, bytes), - .chunked => return write_chunked(r, bytes), - } - } - - fn write_cl(context: *const anyopaque, bytes: []const u8) WriteError!usize { - const r: *Response = @ptrCast(@alignCast(@constCast(context))); + pub const Opcode = enum(u4) { + continuation = 0, + text = 1, + binary = 2, + connection_close = 8, + ping = 9, + /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional + /// heartbeat. A response to an unsolicited Pong frame is not expected." + pong = 10, + _, + }; - var trash: u64 = std.math.maxInt(u64); - const len = switch (r.transfer_encoding) { - .content_length => |*len| len, - else => &trash, - }; + pub const ReadSmallTextMessageError = error{ + ConnectionClose, + UnexpectedOpCode, + MessageTooBig, + MissingMaskBit, + ReadFailed, + EndOfStream, + }; - if (r.elide_body) { - len.* -= bytes.len; - return bytes.len; - } + pub const SmallMessage = struct { + /// Can be text, binary, or ping. + opcode: Opcode, + data: []u8, + }; - if (bytes.len + r.send_buffer_end > r.send_buffer.len) { - const send_buffer_len = r.send_buffer_end - r.send_buffer_start; - var iovecs: [2]std.posix.iovec_const = .{ - .{ - .base = r.send_buffer.ptr + r.send_buffer_start, - .len = send_buffer_len, - }, - .{ - .base = bytes.ptr, - .len = bytes.len, - }, - }; - const n = try r.stream.writev(&iovecs); - - if (n >= send_buffer_len) { - // It was enough to reset the buffer. 
-                r.send_buffer_start = 0;
-                r.send_buffer_end = 0;
-                const bytes_n = n - send_buffer_len;
-                len.* -= bytes_n;
-                return bytes_n;
+    /// Reads the next message from the WebSocket stream, failing if the
+    /// message does not fit into the input buffer. The returned memory points
+    /// into the input buffer and is invalidated on the next read.
+    pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage {
+        const in = ws.input;
+        while (true) {
+            const header = try in.takeArray(2);
+            const h0: Header0 = @bitCast(header[0]);
+            const h1: Header1 = @bitCast(header[1]);
+
+            switch (h0.opcode) {
+                .text, .binary, .pong, .ping => {},
+                .connection_close => return error.ConnectionClose,
+                .continuation => return error.UnexpectedOpCode,
+                _ => return error.UnexpectedOpCode,
             }
-            // It didn't even make it through the existing buffer, let
-            // alone the new bytes provided.
-            r.send_buffer_start += n;
-            return 0;
-        }
-
-        // All bytes can be stored in the remaining space of the buffer.
-        @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes);
-        r.send_buffer_end += bytes.len;
-        len.* -= bytes.len;
-        return bytes.len;
-    }
+            if (!h0.fin) return error.MessageTooBig;
+            if (!h1.mask) return error.MissingMaskBit;

-    fn write_chunked(context: *const anyopaque, bytes: []const u8) WriteError!usize {
-        const r: *Response = @ptrCast(@alignCast(@constCast(context)));
-        assert(r.transfer_encoding == .chunked);
-
-        if (r.elide_body)
-            return bytes.len;
-
-        if (bytes.len + r.send_buffer_end > r.send_buffer.len) {
-            const send_buffer_len = r.send_buffer_end - r.send_buffer_start;
-            const chunk_len = r.chunk_len + bytes.len;
-            var header_buf: [18]u8 = undefined;
-            const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{chunk_len}) catch unreachable;
-
-            var iovecs: [5]std.posix.iovec_const = .{
-                .{
-                    .base = r.send_buffer.ptr + r.send_buffer_start,
-                    .len = send_buffer_len - r.chunk_len,
-                },
-                .{
-                    .base = chunk_header.ptr,
-                    .len = chunk_header.len,
-                },
-                .{
-                    .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len,
-                    .len = r.chunk_len,
-                },
-                .{
-                    .base = bytes.ptr,
-                    .len = bytes.len,
-                },
-                .{
-                    .base = "\r\n",
-                    .len = 2,
-                },
+            const len: usize = switch (h1.payload_len) {
+                .len16 => try in.takeInt(u16, .big),
+                .len64 => std.math.cast(usize, try in.takeInt(u64, .big)) orelse return error.MessageTooBig,
+                else => @intFromEnum(h1.payload_len),
+            };
+            if (len > in.buffer.len) return error.MessageTooBig;
+            const mask: u32 = @bitCast((try in.takeArray(4)).*);
+            const payload = try in.take(len);
+
+            // Skip pongs.
+            if (h0.opcode == .pong) continue;
+
+            // Unmask whole 32-bit words first; the tail of the payload that
+            // does not fill a word is unmasked byte-by-byte below.
+            const floored_len = (payload.len / 4) * 4;
+            const u32_payload: []align(1) u32 = @ptrCast(payload[0..floored_len]);
+            for (u32_payload) |*elem| elem.* ^= mask;
+            const mask_bytes: []const u8 = @ptrCast(&mask);
+            for (payload[floored_len..], mask_bytes[0 .. payload.len - floored_len]) |*leftover, m|
+                leftover.* ^= m;
+
+            return .{
+                .opcode = h0.opcode,
+                .data = payload,
+            };
-            // TODO make this writev instead of writevAll, which involves
-            // complicating the logic of this function.
-            try r.stream.writevAll(&iovecs);
-            r.send_buffer_start = 0;
-            r.send_buffer_end = 0;
-            r.chunk_len = 0;
-            return bytes.len;
         }
-
-        // All bytes can be stored in the remaining space of the buffer.
- @memcpy(r.send_buffer[r.send_buffer_end..][0..bytes.len], bytes); - r.send_buffer_end += bytes.len; - r.chunk_len += bytes.len; - return bytes.len; } - /// If using content-length, asserts that writing these bytes to the client - /// would not exceed the content-length value sent in the HTTP header. - pub fn writeAll(r: *Response, bytes: []const u8) WriteError!void { - var index: usize = 0; - while (index < bytes.len) { - index += try write(r, bytes[index..]); - } + pub fn writeMessage(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { + var bufs: [1][]const u8 = .{data}; + try writeMessageVecUnflushed(ws, &bufs, op); + try ws.output.flush(); } - /// Sends all buffered data to the client. - /// This is redundant after calling `end`. - /// Respects the value of `elide_body` to omit all data after the headers. - pub fn flush(r: *Response) WriteError!void { - switch (r.transfer_encoding) { - .none, .content_length => return flush_cl(r), - .chunked => return flush_chunked(r, null), - } + pub fn writeMessageUnflushed(ws: *WebSocket, data: []const u8, op: Opcode) Writer.Error!void { + var bufs: [1][]const u8 = .{data}; + try writeMessageVecUnflushed(ws, &bufs, op); } - fn flush_cl(r: *Response) WriteError!void { - try r.stream.writeAll(r.send_buffer[r.send_buffer_start..r.send_buffer_end]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; + pub fn writeMessageVec(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void { + try writeMessageVecUnflushed(ws, data, op); + try ws.output.flush(); } - fn flush_chunked(r: *Response, end_trailers: ?[]const http.Header) WriteError!void { - const max_trailers = 25; - if (end_trailers) |trailers| assert(trailers.len <= max_trailers); - assert(r.transfer_encoding == .chunked); - - const http_headers = r.send_buffer[r.send_buffer_start .. 
r.send_buffer_end - r.chunk_len]; - - if (r.elide_body) { - try r.stream.writeAll(http_headers); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; - return; - } - - var header_buf: [18]u8 = undefined; - const chunk_header = std.fmt.bufPrint(&header_buf, "{x}\r\n", .{r.chunk_len}) catch unreachable; - - var iovecs: [max_trailers * 4 + 5]std.posix.iovec_const = undefined; - var iovecs_len: usize = 0; - - iovecs[iovecs_len] = .{ - .base = http_headers.ptr, - .len = http_headers.len, + pub fn writeMessageVecUnflushed(ws: *WebSocket, data: [][]const u8, op: Opcode) Writer.Error!void { + const total_len = l: { + var total_len: u64 = 0; + for (data) |iovec| total_len += iovec.len; + break :l total_len; }; - iovecs_len += 1; - - if (r.chunk_len > 0) { - iovecs[iovecs_len] = .{ - .base = chunk_header.ptr, - .len = chunk_header.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = r.send_buffer.ptr + r.send_buffer_end - r.chunk_len, - .len = r.chunk_len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - if (end_trailers) |trailers| { - iovecs[iovecs_len] = .{ - .base = "0\r\n", - .len = 3, - }; - iovecs_len += 1; - - for (trailers) |trailer| { - iovecs[iovecs_len] = .{ - .base = trailer.name.ptr, - .len = trailer.name.len, - }; - iovecs_len += 1; - - iovecs[iovecs_len] = .{ - .base = ": ", - .len = 2, - }; - iovecs_len += 1; - - if (trailer.value.len != 0) { - iovecs[iovecs_len] = .{ - .base = trailer.value.ptr, - .len = trailer.value.len, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; - } - - iovecs[iovecs_len] = .{ - .base = "\r\n", - .len = 2, - }; - iovecs_len += 1; + const out = ws.output; + try out.writeByte(@bitCast(@as(Header0, .{ + .opcode = op, + .fin = true, + }))); + switch (total_len) { + 0...125 => try out.writeByte(@bitCast(@as(Header1, .{ + .payload_len = @enumFromInt(total_len), + .mask = false, + }))), + 126...0xffff => { + try out.writeByte(@bitCast(@as(Header1, .{ + .payload_len = .len16, + .mask = false, + }))); + try out.writeInt(u16, @intCast(total_len), .big); + }, + else => { + try out.writeByte(@bitCast(@as(Header1, .{ + .payload_len = .len64, + .mask = false, + }))); + try out.writeInt(u64, total_len, .big); + }, } - - try r.stream.writevAll(iovecs[0..iovecs_len]); - r.send_buffer_start = 0; - r.send_buffer_end = 0; - r.chunk_len = 0; + try out.writeVecAll(data); } - pub fn writer(r: *Response) std.io.AnyWriter { - return .{ - .writeFn = switch (r.transfer_encoding) { - .none, .content_length => write_cl, - .chunked => write_chunked, - }, - .context = r, - }; + pub fn flush(ws: *WebSocket) Writer.Error!void { + try ws.output.flush(); } }; - -fn rebase(s: *Server, index: usize) void { - const leftover = s.read_buffer[s.next_request_start..s.read_buffer_len]; - const dest = s.read_buffer[index..][0..leftover.len]; - if (leftover.len <= s.next_request_start - index) { - @memcpy(dest, leftover); - } else { - mem.copyBackwards(u8, dest, leftover); - } - s.read_buffer_len = index + leftover.len; -} - -const std = @import("../std.zig"); -const http = std.http; -const mem = std.mem; -const net = std.net; -const Uri = std.Uri; -const assert = std.debug.assert; -const testing = std.testing; - -const Server = @This(); diff --git a/lib/std/http/WebSocket.zig b/lib/std/http/WebSocket.zig deleted file mode 100644 index b9a66cdbd660..000000000000 --- a/lib/std/http/WebSocket.zig +++ /dev/null @@ -1,246 +0,0 @@ -//! 
See https://tools.ietf.org/html/rfc6455 - -const builtin = @import("builtin"); -const std = @import("std"); -const WebSocket = @This(); -const assert = std.debug.assert; -const native_endian = builtin.cpu.arch.endian(); - -key: []const u8, -request: *std.http.Server.Request, -recv_fifo: std.fifo.LinearFifo(u8, .Slice), -reader: std.io.AnyReader, -response: std.http.Server.Response, -/// Number of bytes that have been peeked but not discarded yet. -outstanding_len: usize, - -pub const InitError = error{WebSocketUpgradeMissingKey} || - std.http.Server.Request.ReaderError; - -pub fn init( - request: *std.http.Server.Request, - send_buffer: []u8, - recv_buffer: []align(4) u8, -) InitError!?WebSocket { - switch (request.head.version) { - .@"HTTP/1.0" => return null, - .@"HTTP/1.1" => if (request.head.method != .GET) return null, - } - - var sec_websocket_key: ?[]const u8 = null; - var upgrade_websocket: bool = false; - var it = request.iterateHeaders(); - while (it.next()) |header| { - if (std.ascii.eqlIgnoreCase(header.name, "sec-websocket-key")) { - sec_websocket_key = header.value; - } else if (std.ascii.eqlIgnoreCase(header.name, "upgrade")) { - if (!std.ascii.eqlIgnoreCase(header.value, "websocket")) - return null; - upgrade_websocket = true; - } - } - if (!upgrade_websocket) - return null; - - const key = sec_websocket_key orelse return error.WebSocketUpgradeMissingKey; - - var sha1 = std.crypto.hash.Sha1.init(.{}); - sha1.update(key); - sha1.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - var digest: [std.crypto.hash.Sha1.digest_length]u8 = undefined; - sha1.final(&digest); - var base64_digest: [28]u8 = undefined; - assert(std.base64.standard.Encoder.encode(&base64_digest, &digest).len == base64_digest.len); - - request.head.content_length = std.math.maxInt(u64); - - return .{ - .key = key, - .recv_fifo = std.fifo.LinearFifo(u8, .Slice).init(recv_buffer), - .reader = try request.reader(), - .response = request.respondStreaming(.{ - .send_buffer = send_buffer, - .respond_options = .{ - .status = .switching_protocols, - .extra_headers = &.{ - .{ .name = "upgrade", .value = "websocket" }, - .{ .name = "connection", .value = "upgrade" }, - .{ .name = "sec-websocket-accept", .value = &base64_digest }, - }, - .transfer_encoding = .none, - }, - }), - .request = request, - .outstanding_len = 0, - }; -} - -pub const Header0 = packed struct(u8) { - opcode: Opcode, - rsv3: u1 = 0, - rsv2: u1 = 0, - rsv1: u1 = 0, - fin: bool, -}; - -pub const Header1 = packed struct(u8) { - payload_len: enum(u7) { - len16 = 126, - len64 = 127, - _, - }, - mask: bool, -}; - -pub const Opcode = enum(u4) { - continuation = 0, - text = 1, - binary = 2, - connection_close = 8, - ping = 9, - /// "A Pong frame MAY be sent unsolicited. This serves as a unidirectional - /// heartbeat. A response to an unsolicited Pong frame is not expected." - pong = 10, - _, -}; - -pub const ReadSmallTextMessageError = error{ - ConnectionClose, - UnexpectedOpCode, - MessageTooBig, - MissingMaskBit, -} || RecvError; - -pub const SmallMessage = struct { - /// Can be text, binary, or ping. - opcode: Opcode, - data: []u8, -}; - -/// Reads the next message from the WebSocket stream, failing if the message does not fit -/// into `recv_buffer`. 
-pub fn readSmallMessage(ws: *WebSocket) ReadSmallTextMessageError!SmallMessage { - while (true) { - const header_bytes = (try recv(ws, 2))[0..2]; - const h0: Header0 = @bitCast(header_bytes[0]); - const h1: Header1 = @bitCast(header_bytes[1]); - - switch (h0.opcode) { - .text, .binary, .pong, .ping => {}, - .connection_close => return error.ConnectionClose, - .continuation => return error.UnexpectedOpCode, - _ => return error.UnexpectedOpCode, - } - - if (!h0.fin) return error.MessageTooBig; - if (!h1.mask) return error.MissingMaskBit; - - const len: usize = switch (h1.payload_len) { - .len16 => try recvReadInt(ws, u16), - .len64 => std.math.cast(usize, try recvReadInt(ws, u64)) orelse return error.MessageTooBig, - else => @intFromEnum(h1.payload_len), - }; - if (len > ws.recv_fifo.buf.len) return error.MessageTooBig; - - const mask: u32 = @bitCast((try recv(ws, 4))[0..4].*); - const payload = try recv(ws, len); - - // Skip pongs. - if (h0.opcode == .pong) continue; - - // The last item may contain a partial word of unused data. - const floored_len = (payload.len / 4) * 4; - const u32_payload: []align(1) u32 = @alignCast(std.mem.bytesAsSlice(u32, payload[0..floored_len])); - for (u32_payload) |*elem| elem.* ^= mask; - const mask_bytes = std.mem.asBytes(&mask)[0 .. payload.len - floored_len]; - for (payload[floored_len..], mask_bytes) |*leftover, m| leftover.* ^= m; - - return .{ - .opcode = h0.opcode, - .data = payload, - }; - } -} - -const RecvError = std.http.Server.Request.ReadError || error{EndOfStream}; - -fn recv(ws: *WebSocket, len: usize) RecvError![]u8 { - ws.recv_fifo.discard(ws.outstanding_len); - assert(len <= ws.recv_fifo.buf.len); - if (len > ws.recv_fifo.count) { - const small_buf = ws.recv_fifo.writableSlice(0); - const needed = len - ws.recv_fifo.count; - const buf = if (small_buf.len >= needed) small_buf else b: { - ws.recv_fifo.realign(); - break :b ws.recv_fifo.writableSlice(0); - }; - const n = try @as(RecvError!usize, @errorCast(ws.reader.readAtLeast(buf, needed))); - if (n < needed) return error.EndOfStream; - ws.recv_fifo.update(n); - } - ws.outstanding_len = len; - // TODO: improve the std lib API so this cast isn't necessary. 
- return @constCast(ws.recv_fifo.readableSliceOfLen(len)); -} - -fn recvReadInt(ws: *WebSocket, comptime I: type) !I { - const unswapped: I = @bitCast((try recv(ws, @sizeOf(I)))[0..@sizeOf(I)].*); - return switch (native_endian) { - .little => @byteSwap(unswapped), - .big => unswapped, - }; -} - -pub const WriteError = std.http.Server.Response.WriteError; - -pub fn writeMessage(ws: *WebSocket, message: []const u8, opcode: Opcode) WriteError!void { - const iovecs: [1]std.posix.iovec_const = .{ - .{ .base = message.ptr, .len = message.len }, - }; - return writeMessagev(ws, &iovecs, opcode); -} - -pub fn writeMessagev(ws: *WebSocket, message: []const std.posix.iovec_const, opcode: Opcode) WriteError!void { - const total_len = l: { - var total_len: u64 = 0; - for (message) |iovec| total_len += iovec.len; - break :l total_len; - }; - - var header_buf: [2 + 8]u8 = undefined; - header_buf[0] = @bitCast(@as(Header0, .{ - .opcode = opcode, - .fin = true, - })); - const header = switch (total_len) { - 0...125 => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = @enumFromInt(total_len), - .mask = false, - })); - break :blk header_buf[0..2]; - }, - 126...0xffff => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = .len16, - .mask = false, - })); - std.mem.writeInt(u16, header_buf[2..4], @intCast(total_len), .big); - break :blk header_buf[0..4]; - }, - else => blk: { - header_buf[1] = @bitCast(@as(Header1, .{ - .payload_len = .len64, - .mask = false, - })); - std.mem.writeInt(u64, header_buf[2..10], total_len, .big); - break :blk header_buf[0..10]; - }, - }; - - const response = &ws.response; - try response.writeAll(header); - for (message) |iovec| - try response.writeAll(iovec.base[0..iovec.len]); - try response.flush(); -} diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig deleted file mode 100644 index 797ed989adce..000000000000 --- a/lib/std/http/protocol.zig +++ /dev/null @@ -1,464 +0,0 @@ -const std = @import("../std.zig"); -const builtin = @import("builtin"); -const testing = std.testing; -const mem = std.mem; - -const assert = std.debug.assert; - -pub const State = enum { - invalid, - - // Begin header and trailer parsing states. - - start, - seen_n, - seen_r, - seen_rn, - seen_rnr, - finished, - - // Begin transfer-encoding: chunked parsing states. - - chunk_head_size, - chunk_head_ext, - chunk_head_r, - chunk_data, - chunk_data_suffix, - chunk_data_suffix_r, - - /// Returns true if the parser is in a content state (ie. not waiting for more headers). - pub fn isContent(self: State) bool { - return switch (self) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => false, - .finished, .chunk_head_size, .chunk_head_ext, .chunk_head_r, .chunk_data, .chunk_data_suffix, .chunk_data_suffix_r => true, - }; - } -}; - -pub const HeadersParser = struct { - state: State = .start, - /// A fixed buffer of len `max_header_bytes`. - /// Pointers into this buffer are not stable until after a message is complete. - header_bytes_buffer: []u8, - header_bytes_len: u32, - next_chunk_length: u64, - /// `false`: headers. `true`: trailers. - done: bool, - - /// Initializes the parser with a provided buffer `buf`. - pub fn init(buf: []u8) HeadersParser { - return .{ - .header_bytes_buffer = buf, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - /// Reinitialize the parser. - /// Asserts the parser is in the "done" state. 
- pub fn reset(hp: *HeadersParser) void { - assert(hp.done); - hp.* = .{ - .state = .start, - .header_bytes_buffer = hp.header_bytes_buffer, - .header_bytes_len = 0, - .done = false, - .next_chunk_length = 0, - }; - } - - pub fn get(hp: HeadersParser) []u8 { - return hp.header_bytes_buffer[0..hp.header_bytes_len]; - } - - pub fn findHeadersEnd(r: *HeadersParser, bytes: []const u8) u32 { - var hp: std.http.HeadParser = .{ - .state = switch (r.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - else => unreachable, - }, - }; - const result = hp.feed(bytes); - r.state = switch (hp.state) { - .start => .start, - .seen_n => .seen_n, - .seen_r => .seen_r, - .seen_rn => .seen_rn, - .seen_rnr => .seen_rnr, - .finished => .finished, - }; - return @intCast(result); - } - - pub fn findChunkedLen(r: *HeadersParser, bytes: []const u8) u32 { - var cp: std.http.ChunkParser = .{ - .state = switch (r.state) { - .chunk_head_size => .head_size, - .chunk_head_ext => .head_ext, - .chunk_head_r => .head_r, - .chunk_data => .data, - .chunk_data_suffix => .data_suffix, - .chunk_data_suffix_r => .data_suffix_r, - .invalid => .invalid, - else => unreachable, - }, - .chunk_len = r.next_chunk_length, - }; - const result = cp.feed(bytes); - r.state = switch (cp.state) { - .head_size => .chunk_head_size, - .head_ext => .chunk_head_ext, - .head_r => .chunk_head_r, - .data => .chunk_data, - .data_suffix => .chunk_data_suffix, - .data_suffix_r => .chunk_data_suffix_r, - .invalid => .invalid, - }; - r.next_chunk_length = cp.chunk_len; - return @intCast(result); - } - - /// Returns whether or not the parser has finished parsing a complete - /// message. A message is only complete after the entire body has been read - /// and any trailing headers have been parsed. - pub fn isComplete(r: *HeadersParser) bool { - return r.done and r.state == .finished; - } - - pub const CheckCompleteHeadError = error{HttpHeadersOversize}; - - /// Pushes `in` into the parser. Returns the number of bytes consumed by - /// the header. Any header bytes are appended to `header_bytes_buffer`. - pub fn checkCompleteHead(hp: *HeadersParser, in: []const u8) CheckCompleteHeadError!u32 { - if (hp.state.isContent()) return 0; - - const i = hp.findHeadersEnd(in); - const data = in[0..i]; - if (hp.header_bytes_len + data.len > hp.header_bytes_buffer.len) - return error.HttpHeadersOversize; - - @memcpy(hp.header_bytes_buffer[hp.header_bytes_len..][0..data.len], data); - hp.header_bytes_len += @intCast(data.len); - - return i; - } - - pub const ReadError = error{ - HttpChunkInvalid, - }; - - /// Reads the body of the message into `buffer`. Returns the number of - /// bytes placed in the buffer. - /// - /// If `skip` is true, the buffer will be unused and the body will be skipped. - /// - /// See `std.http.Client.Connection for an example of `conn`. 
- pub fn read(r: *HeadersParser, conn: anytype, buffer: []u8, skip: bool) !usize { - assert(r.state.isContent()); - if (r.done) return 0; - - var out_index: usize = 0; - while (true) { - switch (r.state) { - .invalid, .start, .seen_n, .seen_r, .seen_rn, .seen_rnr => unreachable, - .finished => { - const data_avail = r.next_chunk_length; - - if (skip) { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return out_index; - } else if (out_index < buffer.len) { - const out_avail = buffer.len - out_index; - - const can_read = @as(usize, @intCast(@min(data_avail, out_avail))); - const nread = try conn.read(buffer[0..can_read]); - r.next_chunk_length -= nread; - - if (r.next_chunk_length == 0 or nread == 0) r.done = true; - - return nread; - } else { - return out_index; - } - }, - .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const i = r.findChunkedLen(conn.peek()); - conn.drop(@intCast(i)); - - switch (r.state) { - .invalid => return error.HttpChunkInvalid, - .chunk_data => if (r.next_chunk_length == 0) { - if (std.mem.eql(u8, conn.peek(), "\r\n")) { - r.state = .finished; - conn.drop(2); - } else { - // The trailer section is formatted identically - // to the header section. - r.state = .seen_rn; - } - r.done = true; - - return out_index; - }, - else => return out_index, - } - - continue; - }, - .chunk_data => { - const data_avail = r.next_chunk_length; - const out_avail = buffer.len - out_index; - - if (skip) { - conn.fill() catch |err| switch (err) { - error.EndOfStream => { - r.done = true; - return 0; - }, - else => |e| return e, - }; - - const nread = @min(conn.peek().len, data_avail); - conn.drop(@intCast(nread)); - r.next_chunk_length -= nread; - } else if (out_avail > 0) { - const can_read: usize = @intCast(@min(data_avail, out_avail)); - const nread = try conn.read(buffer[out_index..][0..can_read]); - r.next_chunk_length -= nread; - out_index += nread; - } - - if (r.next_chunk_length == 0) { - r.state = .chunk_data_suffix; - continue; - } - - return out_index; - }, - } - } - } -}; - -inline fn int16(array: *const [2]u8) u16 { - return @as(u16, @bitCast(array.*)); -} - -inline fn int24(array: *const [3]u8) u24 { - return @as(u24, @bitCast(array.*)); -} - -inline fn int32(array: *const [4]u8) u32 { - return @as(u32, @bitCast(array.*)); -} - -inline fn intShift(comptime T: type, x: anytype) T { - switch (@import("builtin").cpu.arch.endian()) { - .little => return @as(T, @truncate(x >> (@bitSizeOf(@TypeOf(x)) - @bitSizeOf(T)))), - .big => return @as(T, @truncate(x)), - } -} - -/// A buffered (and peekable) Connection. 
-const MockBufferedConnection = struct { - pub const buffer_size = 0x2000; - - conn: std.io.FixedBufferStream([]const u8), - buf: [buffer_size]u8 = undefined, - start: u16 = 0, - end: u16 = 0, - - pub fn fill(conn: *MockBufferedConnection) ReadError!void { - if (conn.end != conn.start) return; - - const nread = try conn.conn.read(conn.buf[0..]); - if (nread == 0) return error.EndOfStream; - conn.start = 0; - conn.end = @as(u16, @truncate(nread)); - } - - pub fn peek(conn: *MockBufferedConnection) []const u8 { - return conn.buf[conn.start..conn.end]; - } - - pub fn drop(conn: *MockBufferedConnection, num: u16) void { - conn.start += num; - } - - pub fn readAtLeast(conn: *MockBufferedConnection, buffer: []u8, len: usize) ReadError!usize { - var out_index: u16 = 0; - while (out_index < len) { - const available = conn.end - conn.start; - const left = buffer.len - out_index; - - if (available > 0) { - const can_read = @as(u16, @truncate(@min(available, left))); - - @memcpy(buffer[out_index..][0..can_read], conn.buf[conn.start..][0..can_read]); - out_index += can_read; - conn.start += can_read; - - continue; - } - - if (left > conn.buf.len) { - // skip the buffer if the output is large enough - return conn.conn.read(buffer[out_index..]); - } - - try conn.fill(); - } - - return out_index; - } - - pub fn read(conn: *MockBufferedConnection, buffer: []u8) ReadError!usize { - return conn.readAtLeast(buffer, 1); - } - - pub const ReadError = std.io.FixedBufferStream([]const u8).ReadError || error{EndOfStream}; - pub const Reader = std.io.GenericReader(*MockBufferedConnection, ReadError, read); - - pub fn reader(conn: *MockBufferedConnection) Reader { - return Reader{ .context = conn }; - } - - pub fn writeAll(conn: *MockBufferedConnection, buffer: []const u8) WriteError!void { - return conn.conn.writeAll(buffer); - } - - pub fn write(conn: *MockBufferedConnection, buffer: []const u8) WriteError!usize { - return conn.conn.write(buffer); - } - - pub const WriteError = std.io.FixedBufferStream([]const u8).WriteError; - pub const Writer = std.io.GenericWriter(*MockBufferedConnection, WriteError, write); - - pub fn writer(conn: *MockBufferedConnection) Writer { - return Writer{ .context = conn }; - } -}; - -test "HeadersParser.read length" { - // mock BufferedConnection for read - var headers_buf: [256]u8 = undefined; - - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nHello"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - var buf: [8]u8 = undefined; - - r.next_chunk_length = 5; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - 
conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n", r.get()); -} - -test "HeadersParser.read chunked trailer" { - // mock BufferedConnection for read - - var headers_buf: [256]u8 = undefined; - var r = HeadersParser.init(&headers_buf); - const data = "GET / HTTP/1.1\r\nHost: localhost\r\n\r\n2\r\nHe\r\n2\r\nll\r\n1\r\no\r\n0\r\nContent-Type: text/plain\r\n\r\n"; - - var conn: MockBufferedConnection = .{ - .conn = std.io.fixedBufferStream(data), - }; - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - var buf: [8]u8 = undefined; - - r.state = .chunk_head_size; - const len = try r.read(&conn, &buf, false); - try std.testing.expectEqual(@as(usize, 5), len); - try std.testing.expectEqualStrings("Hello", buf[0..len]); - - while (true) { // read headers - try conn.fill(); - - const nchecked = try r.checkCompleteHead(conn.peek()); - conn.drop(@intCast(nchecked)); - - if (r.state.isContent()) break; - } - - try std.testing.expectEqualStrings("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nContent-Type: text/plain\r\n\r\n", r.get()); -} diff --git a/lib/std/http/test.zig b/lib/std/http/test.zig index 33bc2eb19128..556afc092fbe 100644 --- a/lib/std/http/test.zig +++ b/lib/std/http/test.zig @@ -10,32 +10,33 @@ const expectError = std.testing.expectError; test "trailers" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1024]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [1024]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try serve(&request); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); } } fn serve(request: *http.Server.Request) !void { try expectEqualStrings(request.head.target, "/trailer"); - var send_buffer: [1024]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, - }); - try response.writeAll("Hello, "); + var response = try request.respondStreaming(&.{}, .{}); + try response.writer.writeAll("Hello, "); try response.flush(); - try response.writeAll("World!\n"); + try response.writer.writeAll("World!\n"); try response.flush(); try response.endChunked(.{ .trailers = &.{ @@ -58,34 +59,32 @@ test "trailers" { const uri = try std.Uri.parse(location); { - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var 
req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); - - const body = try req.reader().readAllAlloc(gpa, 8192); - defer gpa.free(body); + try req.sendBodiless(); + var response = try req.receiveHead(&.{}); - try expectEqualStrings("Hello, World!\n", body); - - var it = req.response.iterateHeaders(); { + var it = response.head.iterateHeaders(); const header = it.next().?; - try expect(!it.is_trailer); try expectEqualStrings("transfer-encoding", header.name); try expectEqualStrings("chunked", header.value); + try expectEqual(null, it.next()); } + + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); + defer gpa.free(body); + + try expectEqualStrings("Hello, World!\n", body); + { + var it = response.iterateTrailers(); const header = it.next().?; - try expect(it.is_trailer); try expectEqualStrings("X-Checksum", header.name); try expectEqualStrings("aaaa", header.value); + try expectEqual(null, it.next()); } - try expectEqual(null, it.next()); } // connection has been kept alive @@ -94,19 +93,24 @@ test "trailers" { test "HTTP server handles a chunked transfer coding request" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) !void { - var header_buffer: [8192]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); - - var server = http.Server.init(conn, &header_buffer); + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [8192]u8 = undefined; + var send_buffer: [500]u8 = undefined; + const connection = try net_server.accept(); + defer connection.stream.close(); + + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expect(request.head.transfer_encoding == .chunked); var buf: [128]u8 = undefined; - const n = try (try request.reader()).readAll(&buf); - try expect(mem.eql(u8, buf[0..n], "ABCD")); + var br = try request.readerExpectContinue(&.{}); + const n = try br.readSliceShort(&buf); + try expectEqualStrings("ABCD", buf[0..n]); try request.respond("message from server!\n", .{ .extra_headers = &.{ @@ -154,16 +158,20 @@ test "HTTP server handles a chunked transfer coding request" { test "echo content server" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var read_buffer: [1024]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [100]u8 = undefined; - accept: while (true) { - const conn = try net_server.accept(); - defer conn.stream.close(); + accept: while (!test_server.shutting_down) { + const connection = try net_server.accept(); + defer connection.stream.close(); - var http_server = http.Server.init(conn, &read_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface); - while (http_server.state == .ready) { + while (http_server.reader.state == .ready) { var request = http_server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => continue :accept, else => |e| return e, @@ -173,8 +181,12 @@ test "echo content server" { } if 
(request.head.expect) |expect_header_value| {
                     if (mem.eql(u8, expect_header_value, "garbage")) {
-                        try expectError(error.HttpExpectationFailed, request.reader());
-                        try request.respond("", .{ .keep_alive = false });
+                        try expectError(error.HttpExpectationFailed, request.readerExpectContinue(&.{}));
+                        request.head.expect = null;
+                        try request.respond("", .{
+                            .keep_alive = false,
+                            .status = .expectation_failed,
+                        });
                         continue;
                     }
                 }
@@ -195,16 +207,16 @@ test "echo content server" {
                     //    request.head.target,
                     //});

-                const body = try (try request.reader()).readAllAlloc(std.testing.allocator, 8192);
+                try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
+                try expectEqualStrings("text/plain", request.head.content_type.?);
+
+                // Head strings are invalidated from this point on.
+                const body = try (try request.readerExpectContinue(&.{})).allocRemaining(std.testing.allocator, .unlimited);
                 defer std.testing.allocator.free(body);

-                try expect(mem.startsWith(u8, request.head.target, "/echo-content"));
                 try expectEqualStrings("Hello, World!\n", body);
-                try expectEqualStrings("text/plain", request.head.content_type.?);

-                var send_buffer: [100]u8 = undefined;
-                var response = request.respondStreaming(.{
-                    .send_buffer = &send_buffer,
+                var response = try request.respondStreaming(&.{}, .{
                     .content_length = switch (request.head.transfer_encoding) {
                         .chunked => null,
                         .none => len: {
@@ -213,9 +225,8 @@ test "echo content server" {
                         },
                     },
                 });
-                try response.flush(); // Test an early flush to send the HTTP headers before the body.
-                const w = response.writer();
+                const w = &response.writer;
                 try w.writeAll("Hello, ");
                 try w.writeAll("World!\n");
                 try response.end();
@@ -241,35 +252,35 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" {
     // In this case, the response is expected to stream until the connection is
     // closed, indicating the end of the body.
const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1000]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1000]u8 = undefined; + var send_buffer: [500]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); - try expectEqual(.ready, server.state); + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/foo"); - var send_buffer: [500]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var buf: [30]u8 = undefined; + var response = try request.respondStreaming(&buf, .{ .respond_options = .{ .transfer_encoding = .none, }, }); - var total: usize = 0; + const w = &response.writer; for (0..500) |i| { - var buf: [30]u8 = undefined; - const line = try std.fmt.bufPrint(&buf, "{d}, ah ha ha!\n", .{i}); - try response.writeAll(line); - total += line.len; + try w.print("{d}, ah ha ha!\n", .{i}); } - try expectEqual(7390, total); + try w.flush(); try response.end(); - try expectEqual(.closing, server.state); + try expectEqual(.closing, server.reader.state); } } }); @@ -284,7 +295,7 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = stream.reader(&tiny_buffer); - const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192)); + const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); defer gpa.free(response); var expected_response = std.ArrayList(u8).init(gpa); @@ -308,15 +319,20 @@ test "Server.Request.respondStreaming non-chunked, unknown content-length" { test "receiving arbitrary http headers from the client" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var read_buffer: [666]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [666]u8 = undefined; + var send_buffer: [777]u8 = undefined; var remaining: usize = 1; while (remaining != 0) : (remaining -= 1) { - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &read_buffer); - try expectEqual(.ready, server.state); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); + + try expectEqual(.ready, server.reader.state); var request = try server.receiveHead(); try expectEqualStrings("/bar", request.head.target); var it = request.iterateHeaders(); @@ -350,7 +366,7 @@ test "receiving arbitrary http headers from the client" { var tiny_buffer: [1]u8 = undefined; // allows allocRemaining to detect limit exceeded var stream_reader = 
stream.reader(&tiny_buffer); - const response = try stream_reader.interface().allocRemaining(gpa, .limited(8192)); + const response = try stream_reader.interface().allocRemaining(gpa, .unlimited); defer gpa.free(response); var expected_response = std.ArrayList(u8).init(gpa); @@ -368,19 +384,21 @@ test "general client/server API coverage" { return error.SkipZigTest; } - const global = struct { - var handle_new_requests = true; - }; const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var client_header_buffer: [1024]u8 = undefined; - outer: while (global.handle_new_requests) { + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [100]u8 = undefined; + + outer: while (!test_server.shutting_down) { var connection = try net_server.accept(); defer connection.stream.close(); - var http_server = http.Server.init(connection, &client_header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var http_server = http.Server.init(connection_br.interface(), &connection_bw.interface); - while (http_server.state == .ready) { + while (http_server.reader.state == .ready) { var request = http_server.receiveHead() catch |err| switch (err) { error.HttpConnectionClosing => continue :outer, else => |e| return e, @@ -393,21 +411,19 @@ test "general client/server API coverage" { fn handleRequest(request: *http.Server.Request, listen_port: u16) !void { const log = std.log.scoped(.server); + const gpa = std.testing.allocator; - log.info("{f} {s} {s}", .{ - request.head.method, @tagName(request.head.version), request.head.target, - }); + log.info("{t} {t} {s}", .{ request.head.method, request.head.version, request.head.target }); + const target = try gpa.dupe(u8, request.head.target); + defer gpa.free(target); - const gpa = std.testing.allocator; - const body = try (try request.reader()).readAllAlloc(gpa, 8192); + const reader = (try request.readerExpectContinue(&.{})); + const body = try reader.allocRemaining(gpa, .unlimited); defer gpa.free(body); - var send_buffer: [100]u8 = undefined; - - if (mem.startsWith(u8, request.head.target, "/get")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, - .content_length = if (mem.indexOf(u8, request.head.target, "?chunked") == null) + if (mem.startsWith(u8, target, "/get")) { + var response = try request.respondStreaming(&.{}, .{ + .content_length = if (mem.indexOf(u8, target, "?chunked") == null) 14 else null, @@ -417,27 +433,27 @@ test "general client/server API coverage" { }, }, }); - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("World!\n"); try response.end(); // Writing again would cause an assertion failure. - } else if (mem.startsWith(u8, request.head.target, "/large")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + } else if (mem.startsWith(u8, target, "/large")) { + var response = try request.respondStreaming(&.{}, .{ .content_length = 14 * 1024 + 14 * 10, }); try response.flush(); // Test an early flush to send the HTTP headers before the body. 
- const w = response.writer(); + const w = &response.writer; var i: u32 = 0; while (i < 5) : (i += 1) { try w.writeAll("Hello, World!\n"); } - try w.writeAll("Hello, World!\n" ** 1024); + var vec: [1][]const u8 = .{"Hello, World!\n"}; + try w.writeSplatAll(&vec, 1024); i = 0; while (i < 5) : (i += 1) { @@ -445,9 +461,8 @@ test "general client/server API coverage" { } try response.end(); - } else if (mem.eql(u8, request.head.target, "/redirect/1")) { - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + } else if (mem.eql(u8, target, "/redirect/1")) { + var response = try request.respondStreaming(&.{}, .{ .respond_options = .{ .status = .found, .extra_headers = &.{ @@ -456,18 +471,18 @@ test "general client/server API coverage" { }, }); - const w = response.writer(); + const w = &response.writer; try w.writeAll("Hello, "); try w.writeAll("Redirected!\n"); try response.end(); - } else if (mem.eql(u8, request.head.target, "/redirect/2")) { + } else if (mem.eql(u8, target, "/redirect/2")) { try request.respond("Hello, Redirected!\n", .{ .status = .found, .extra_headers = &.{ .{ .name = "location", .value = "/redirect/1" }, }, }); - } else if (mem.eql(u8, request.head.target, "/redirect/3")) { + } else if (mem.eql(u8, target, "/redirect/3")) { const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/redirect/2", .{ listen_port, }); @@ -479,23 +494,23 @@ test "general client/server API coverage" { .{ .name = "location", .value = location }, }, }); - } else if (mem.eql(u8, request.head.target, "/redirect/4")) { + } else if (mem.eql(u8, target, "/redirect/4")) { try request.respond("Hello, Redirected!\n", .{ .status = .found, .extra_headers = &.{ .{ .name = "location", .value = "/redirect/3" }, }, }); - } else if (mem.eql(u8, request.head.target, "/redirect/5")) { + } else if (mem.eql(u8, target, "/redirect/5")) { try request.respond("Hello, Redirected!\n", .{ .status = .found, .extra_headers = &.{ .{ .name = "location", .value = "/%2525" }, }, }); - } else if (mem.eql(u8, request.head.target, "/%2525")) { + } else if (mem.eql(u8, target, "/%2525")) { try request.respond("Encoded redirect successful!\n", .{}); - } else if (mem.eql(u8, request.head.target, "/redirect/invalid")) { + } else if (mem.eql(u8, target, "/redirect/invalid")) { const invalid_port = try getUnusedTcpPort(); const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}", .{invalid_port}); defer gpa.free(location); @@ -506,7 +521,7 @@ test "general client/server API coverage" { .{ .name = "location", .value = location }, }, }); - } else if (mem.eql(u8, request.head.target, "/empty")) { + } else if (mem.eql(u8, target, "/empty")) { try request.respond("", .{ .extra_headers = &.{ .{ .name = "empty", .value = "" }, @@ -524,17 +539,13 @@ test "general client/server API coverage" { return s.listen_address.in.getPort(); } }); - defer { - global.handle_new_requests = false; - test_server.destroy(); - } + defer test_server.destroy(); const log = std.log.scoped(.client); const gpa = std.testing.allocator; var client: http.Client = .{ .allocator = gpa }; - errdefer client.deinit(); - // defer client.deinit(); handled below + defer client.deinit(); const port = test_server.port(); @@ -544,20 +555,19 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = 
undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + + try expectEqualStrings("text/plain", response.head.content_type.?); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); } // connection has been kept alive @@ -569,16 +579,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192 * 1024); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqual(@as(usize, 14 * 1024 + 14 * 10), body.len); @@ -593,21 +601,20 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.HEAD, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.HEAD, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + + try expectEqualStrings("text/plain", response.head.content_type.?); + try expectEqual(14, response.head.content_length.?); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", req.response.content_type.?); - try expectEqual(14, req.response.content_length.?); } // connection has been kept alive @@ -619,20 +626,19 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + + try expectEqualStrings("text/plain", response.head.content_type.?); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); } // connection has been kept alive @@ -644,21 +650,20 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.HEAD, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try 
client.request(.HEAD, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + try expectEqualStrings("text/plain", response.head.content_type.?); + try expect(response.head.transfer_encoding == .chunked); + + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("", body); - try expectEqualStrings("text/plain", req.response.content_type.?); - try expect(req.response.transfer_encoding == .chunked); } // connection has been kept alive @@ -670,21 +675,21 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{ .keep_alive = false, }); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + + try expectEqualStrings("text/plain", response.head.content_type.?); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); - try expectEqualStrings("text/plain", req.response.content_type.?); } // connection has been closed @@ -696,32 +701,32 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{ .extra_headers = &.{ .{ .name = "empty", .value = "" }, }, }); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - try std.testing.expectEqual(.ok, req.response.status); - - const body = try req.reader().readAllAlloc(gpa, 8192); - defer gpa.free(body); + try std.testing.expectEqual(.ok, response.head.status); - try expectEqualStrings("", body); - - var it = req.response.iterateHeaders(); + var it = response.head.iterateHeaders(); { const header = it.next().?; try expect(!it.is_trailer); try expectEqualStrings("content-length", header.name); try expectEqualStrings("0", header.value); } + + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); + defer gpa.free(body); + + try expectEqualStrings("", body); + { const header = it.next().?; try expect(!it.is_trailer); @@ -740,16 +745,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -764,16 +767,14 @@ test 
"general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -788,16 +789,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -812,17 +811,17 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - req.wait() catch |err| switch (err) { + try req.sendBodiless(); + if (req.receiveHead(&redirect_buffer)) |_| { + return error.TestFailed; + } else |err| switch (err) { error.TooManyHttpRedirects => {}, else => return err, - }; + } } { // redirect to encoded url @@ -831,16 +830,14 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Encoded redirect successful!\n", body); @@ -855,14 +852,12 @@ test "general client/server API coverage" { const uri = try std.Uri.parse(location); log.info("{s}", .{location}); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - const result = req.wait(); + try req.sendBodiless(); + const result = req.receiveHead(&redirect_buffer); // a proxy without an upstream is likely to return a 5xx status. 
if (client.http_proxy == null) { @@ -872,77 +867,40 @@ test "general client/server API coverage" { // connection has been kept alive try expect(client.http_proxy != null or client.connection_pool.free_len == 1); - - { // issue 16282 *** This test leaves the client in an invalid state, it must be last *** - const location = try std.fmt.allocPrint(gpa, "http://127.0.0.1:{d}/get", .{port}); - defer gpa.free(location); - const uri = try std.Uri.parse(location); - - const total_connections = client.connection_pool.free_size + 64; - var requests = try gpa.alloc(http.Client.Request, total_connections); - defer gpa.free(requests); - - var header_bufs = std.ArrayList([]u8).init(gpa); - defer header_bufs.deinit(); - defer for (header_bufs.items) |item| gpa.free(item); - - for (0..total_connections) |i| { - const headers_buf = try gpa.alloc(u8, 1024); - try header_bufs.append(headers_buf); - var req = try client.open(.GET, uri, .{ - .server_header_buffer = headers_buf, - }); - req.response.parser.done = true; - req.connection.?.closing = false; - requests[i] = req; - } - - for (0..total_connections) |i| { - requests[i].deinit(); - } - - // free connections should be full now - try expect(client.connection_pool.free_len == client.connection_pool.free_size); - } - - client.deinit(); - - { - global.handle_new_requests = false; - - const conn = try std.net.tcpConnectToAddress(test_server.net_server.listen_address); - conn.close(); - } } test "Server streams both reading and writing" { const test_server = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [1024]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [1024]u8 = undefined; + var send_buffer: [777]u8 = undefined; - var server = http.Server.init(conn, &header_buffer); - var request = try server.receiveHead(); - const reader = try request.reader(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var send_buffer: [777]u8 = undefined; - var response = request.respondStreaming(.{ - .send_buffer = &send_buffer, + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); + var request = try server.receiveHead(); + var read_buffer: [100]u8 = undefined; + var br = try request.readerExpectContinue(&read_buffer); + var response = try request.respondStreaming(&.{}, .{ .respond_options = .{ .transfer_encoding = .none, // Causes keep_alive=false }, }); - const writer = response.writer(); + const w = &response.writer; while (true) { try response.flush(); - var buf: [100]u8 = undefined; - const n = try reader.read(&buf); - if (n == 0) break; - const sub_buf = buf[0..n]; - for (sub_buf) |*b| b.* = std.ascii.toUpper(b.*); - try writer.writeAll(sub_buf); + const buf = br.peekGreedy(1) catch |err| switch (err) { + error.EndOfStream => break, + error.ReadFailed => return error.ReadFailed, + }; + br.toss(buf.len); + for (buf) |*b| b.* = std.ascii.toUpper(b.*); + try w.writeAll(buf); } try response.end(); } @@ -952,27 +910,24 @@ test "Server streams both reading and writing" { var client: http.Client = .{ .allocator = std.testing.allocator }; defer client.deinit(); - var server_header_buffer: [555]u8 = undefined; - var req = try client.open(.POST, .{ + var redirect_buffer: [555]u8 = 
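The rewritten server loop above no longer copies into a scratch array: peekGreedy(1) exposes the reader's own buffered bytes in place, and toss() marks them consumed. The same loop extracted as a standalone helper, assuming only the std.Io.Reader/Writer behavior this diff relies on:

const std = @import("std");

// Sketch: uppercase a stream in place using the new buffered-reader primitives.
fn upperCaseStream(br: *std.Io.Reader, w: *std.Io.Writer) !void {
    while (true) {
        // peekGreedy(1) returns at least one buffered byte without consuming it;
        // the returned slice may be mutated in place.
        const buf = br.peekGreedy(1) catch |err| switch (err) {
            error.EndOfStream => break,
            error.ReadFailed => return error.ReadFailed,
        };
        br.toss(buf.len); // consume the peeked bytes
        for (buf) |*b| b.* = std.ascii.toUpper(b.*);
        try w.writeAll(buf);
    }
    try w.flush();
}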
@@ -952,27 +910,24 @@ test "Server streams both reading and writing" {
     var client: http.Client = .{ .allocator = std.testing.allocator };
     defer client.deinit();
-    var server_header_buffer: [555]u8 = undefined;
-    var req = try client.open(.POST, .{
+    var redirect_buffer: [555]u8 = undefined;
+    var req = try client.request(.POST, .{
         .scheme = "http",
         .host = .{ .raw = "127.0.0.1" },
         .port = test_server.port(),
         .path = .{ .percent_encoded = "/" },
-    }, .{
-        .server_header_buffer = &server_header_buffer,
-    });
+    }, .{});
     defer req.deinit();
     req.transfer_encoding = .chunked;
-    try req.send();
-    try req.wait();
-
-    try req.writeAll("one ");
-    try req.writeAll("fish");
+    var body_writer = try req.sendBody(&.{});
+    var response = try req.receiveHead(&redirect_buffer);
-    try req.finish();
+    try body_writer.writer.writeAll("one ");
+    try body_writer.writer.writeAll("fish");
+    try body_writer.end();
-    const body = try req.reader().readAllAlloc(std.testing.allocator, 8192);
+    const body = try response.reader(&.{}).allocRemaining(std.testing.allocator, .unlimited);
     defer std.testing.allocator.free(body);
     try expectEqualStrings("ONE FISH", body);
@@ -987,9 +942,8 @@ fn echoTests(client: *http.Client, port: u16) !void {
         defer gpa.free(location);
         const uri = try std.Uri.parse(location);
-        var server_header_buffer: [1024]u8 = undefined;
-        var req = try client.open(.POST, uri, .{
-            .server_header_buffer = &server_header_buffer,
+        var redirect_buffer: [1024]u8 = undefined;
+        var req = try client.request(.POST, uri, .{
             .extra_headers = &.{
                 .{ .name = "content-type", .value = "text/plain" },
             },
@@ -998,14 +952,14 @@ fn echoTests(client: *http.Client, port: u16) !void {
         req.transfer_encoding = .{ .content_length = 14 };
-        try req.send();
-        try req.writeAll("Hello, ");
-        try req.writeAll("World!\n");
-        try req.finish();
+        var body_writer = try req.sendBody(&.{});
+        try body_writer.writer.writeAll("Hello, ");
+        try body_writer.writer.writeAll("World!\n");
+        try body_writer.end();
-        try req.wait();
+        var response = try req.receiveHead(&redirect_buffer);
-        const body = try req.reader().readAllAlloc(gpa, 8192);
+        const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited);
         defer gpa.free(body);
         try expectEqualStrings("Hello, World!\n", body);
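Request bodies now go through the body writer returned by sendBody; end() finalizes the message (for chunked transfer encoding that is the zero-length terminating chunk). A hedged sketch of the pattern the echo tests above use; the helper name postHello is hypothetical:

const std = @import("std");

// Sketch: send a fixed-length body with the new API, then read the response head.
fn postHello(req: *std.http.Client.Request, redirect_buffer: []u8) !std.http.Client.Response {
    req.transfer_encoding = .{ .content_length = 14 };
    var body_writer = try req.sendBody(&.{}); // &.{} = no extra send buffer
    try body_writer.writer.writeAll("Hello, ");
    try body_writer.writer.writeAll("World!\n");
    try body_writer.end(); // flushes and finishes the body
    return req.receiveHead(redirect_buffer);
}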
"content-type", .value = "text/plain" }, }, - .response_storage = .{ .dynamic = &body }, + .response_storage = .{ .allocator = gpa, .list = &body }, }); try expectEqual(.ok, res.status); try expectEqualStrings("Hello, World!\n", body.items); @@ -1074,9 +1027,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "expect", .value = "100-continue" }, .{ .name = "content-type", .value = "text/plain" }, @@ -1086,15 +1038,15 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.writeAll("Hello, "); - try req.writeAll("World!\n"); - try req.finish(); + var body_writer = try req.sendBody(&.{}); + try body_writer.writer.writeAll("Hello, "); + try body_writer.writer.writeAll("World!\n"); + try body_writer.end(); - try req.wait(); - try expectEqual(.ok, req.response.status); + var response = try req.receiveHead(&redirect_buffer); + try expectEqual(.ok, response.head.status); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try response.reader(&.{}).allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("Hello, World!\n", body); @@ -1105,9 +1057,8 @@ fn echoTests(client: *http.Client, port: u16) !void { defer gpa.free(location); const uri = try std.Uri.parse(location); - var server_header_buffer: [1024]u8 = undefined; - var req = try client.open(.POST, uri, .{ - .server_header_buffer = &server_header_buffer, + var redirect_buffer: [1024]u8 = undefined; + var req = try client.request(.POST, uri, .{ .extra_headers = &.{ .{ .name = "content-type", .value = "text/plain" }, .{ .name = "expect", .value = "garbage" }, @@ -1117,23 +1068,24 @@ fn echoTests(client: *http.Client, port: u16) !void { req.transfer_encoding = .chunked; - try req.send(); - try req.wait(); - try expectEqual(.expectation_failed, req.response.status); + var body_writer = try req.sendBody(&.{}); + try body_writer.flush(); + var response = try req.receiveHead(&redirect_buffer); + try expectEqual(.expectation_failed, response.head.status); + _ = try response.reader(&.{}).discardRemaining(); } - - _ = try client.fetch(.{ - .location = .{ - .url = try std.fmt.bufPrint(&location_buffer, "http://127.0.0.1:{d}/end", .{port}), - }, - }); } const TestServer = struct { + shutting_down: bool, server_thread: std.Thread, net_server: std.net.Server, fn destroy(self: *@This()) void { + self.shutting_down = true; + const conn = std.net.tcpConnectToAddress(self.net_server.listen_address) catch @panic("shutdown failure"); + conn.close(); + self.server_thread.join(); self.net_server.deinit(); std.testing.allocator.destroy(self); @@ -1153,20 +1105,27 @@ fn createTestServer(S: type) !*TestServer { const address = try std.net.Address.parseIp("127.0.0.1", 0); const test_server = try std.testing.allocator.create(TestServer); - test_server.net_server = try address.listen(.{ .reuse_address = true }); - test_server.server_thread = try std.Thread.spawn(.{}, S.run, .{&test_server.net_server}); + test_server.* = .{ + .net_server = try address.listen(.{ .reuse_address = true }), + .server_thread = try std.Thread.spawn(.{}, S.run, .{test_server}), + .shutting_down = false, + }; return test_server; } test "redirect to different 
connection" { const test_server_new = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [888]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [888]u8 = undefined; + var send_buffer: [777]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/ok"); try request.respond("good job, you pass", .{}); @@ -1180,18 +1139,22 @@ test "redirect to different connection" { global.other_port = test_server_new.port(); const test_server_orig = try createTestServer(struct { - fn run(net_server: *std.net.Server) anyerror!void { - var header_buffer: [999]u8 = undefined; + fn run(test_server: *TestServer) anyerror!void { + const net_server = &test_server.net_server; + var recv_buffer: [999]u8 = undefined; var send_buffer: [100]u8 = undefined; - const conn = try net_server.accept(); - defer conn.stream.close(); + const connection = try net_server.accept(); + defer connection.stream.close(); - const new_loc = try std.fmt.bufPrint(&send_buffer, "http://127.0.0.1:{d}/ok", .{ + var loc_buf: [50]u8 = undefined; + const new_loc = try std.fmt.bufPrint(&loc_buf, "http://127.0.0.1:{d}/ok", .{ global.other_port.?, }); - var server = http.Server.init(conn, &header_buffer); + var connection_br = connection.stream.reader(&recv_buffer); + var connection_bw = connection.stream.writer(&send_buffer); + var server = http.Server.init(connection_br.interface(), &connection_bw.interface); var request = try server.receiveHead(); try expectEqualStrings(request.head.target, "/help"); try request.respond("", .{ @@ -1216,16 +1179,15 @@ test "redirect to different connection" { const uri = try std.Uri.parse(location); { - var server_header_buffer: [666]u8 = undefined; - var req = try client.open(.GET, uri, .{ - .server_header_buffer = &server_header_buffer, - }); + var redirect_buffer: [666]u8 = undefined; + var req = try client.request(.GET, uri, .{}); defer req.deinit(); - try req.send(); - try req.wait(); + try req.sendBodiless(); + var response = try req.receiveHead(&redirect_buffer); + var reader = response.reader(&.{}); - const body = try req.reader().readAllAlloc(gpa, 8192); + const body = try reader.allocRemaining(gpa, .unlimited); defer gpa.free(body); try expectEqualStrings("good job, you pass", body); diff --git a/lib/std/net.zig b/lib/std/net.zig index 38c078519490..f63aa6dd0ad6 100644 --- a/lib/std/net.zig +++ b/lib/std/net.zig @@ -1944,7 +1944,7 @@ pub const Stream = struct { pub const Error = ReadError; pub fn getStream(r: *const Reader) Stream { - return r.stream; + return r.net_stream; } pub fn getError(r: *const Reader) ?Error { diff --git a/lib/std/std.zig b/lib/std/std.zig index aaae4c2eba45..5cca56262a45 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -57,7 +57,6 @@ pub const debug = @import("debug.zig"); pub const dwarf = @import("dwarf.zig"); pub const elf = @import("elf.zig"); pub const enums = @import("enums.zig"); -pub const fifo = @import("fifo.zig"); pub const fmt = @import("fmt.zig"); pub const fs = 
@import("fs.zig"); pub const gpu = @import("gpu.zig"); diff --git a/src/Package/Fetch.zig b/src/Package/Fetch.zig index 03593326c485..fd8c26f1e6c6 100644 --- a/src/Package/Fetch.zig +++ b/src/Package/Fetch.zig @@ -385,21 +385,23 @@ pub fn run(f: *Fetch) RunError!void { var resource: Resource = .{ .dir = dir }; return f.runResource(path_or_url, &resource, null); } else |dir_err| { + var server_header_buffer: [init_resource_buffer_size]u8 = undefined; + const file_err = if (dir_err == error.NotDir) e: { if (fs.cwd().openFile(path_or_url, .{})) |file| { - var resource: Resource = .{ .file = file }; + var resource: Resource = .{ .file = file.reader(&server_header_buffer) }; return f.runResource(path_or_url, &resource, null); } else |err| break :e err; } else dir_err; const uri = std.Uri.parse(path_or_url) catch |uri_err| { return f.fail(0, try eb.printString( - "'{s}' could not be recognized as a file path ({s}) or an URL ({s})", - .{ path_or_url, @errorName(file_err), @errorName(uri_err) }, + "'{s}' could not be recognized as a file path ({t}) or an URL ({t})", + .{ path_or_url, file_err, uri_err }, )); }; - var server_header_buffer: [header_buffer_size]u8 = undefined; - var resource = try f.initResource(uri, &server_header_buffer); + var resource: Resource = undefined; + try f.initResource(uri, &resource, &server_header_buffer); return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, null); } }, @@ -464,8 +466,9 @@ pub fn run(f: *Fetch) RunError!void { f.location_tok, try eb.printString("invalid URI: {s}", .{@errorName(err)}), ); - var server_header_buffer: [header_buffer_size]u8 = undefined; - var resource = try f.initResource(uri, &server_header_buffer); + var buffer: [init_resource_buffer_size]u8 = undefined; + var resource: Resource = undefined; + try f.initResource(uri, &resource, &buffer); return f.runResource(try uri.path.toRawMaybeAlloc(arena), &resource, remote.hash); } @@ -866,8 +869,8 @@ fn fail(f: *Fetch, msg_tok: std.zig.Ast.TokenIndex, msg_str: u32) RunError { } const Resource = union(enum) { - file: fs.File, - http_request: std.http.Client.Request, + file: fs.File.Reader, + http_request: HttpRequest, git: Git, dir: fs.Dir, @@ -877,10 +880,16 @@ const Resource = union(enum) { want_oid: git.Oid, }; + const HttpRequest = struct { + request: std.http.Client.Request, + response: std.http.Client.Response, + buffer: []u8, + }; + fn deinit(resource: *Resource) void { switch (resource.*) { - .file => |*file| file.close(), - .http_request => |*req| req.deinit(), + .file => |*file_reader| file_reader.file.close(), + .http_request => |*http_request| http_request.request.deinit(), .git => |*git_resource| { git_resource.fetch_stream.deinit(); git_resource.session.deinit(); @@ -890,21 +899,13 @@ const Resource = union(enum) { resource.* = undefined; } - fn reader(resource: *Resource) std.io.AnyReader { - return .{ - .context = resource, - .readFn = read, - }; - } - - fn read(context: *const anyopaque, buffer: []u8) anyerror!usize { - const resource: *Resource = @ptrCast(@alignCast(@constCast(context))); - switch (resource.*) { - .file => |*f| return f.read(buffer), - .http_request => |*r| return r.read(buffer), - .git => |*g| return g.fetch_stream.read(buffer), + fn reader(resource: *Resource) *std.Io.Reader { + return switch (resource.*) { + .file => |*file_reader| return &file_reader.interface, + .http_request => |*http_request| return http_request.response.reader(http_request.buffer), + .git => |*g| return &g.fetch_stream.reader, .dir => unreachable, - } + }; } }; @@ -967,20 
+968,22 @@ const FileType = enum { } }; -const header_buffer_size = 16 * 1024; +const init_resource_buffer_size = git.Packet.max_data_length; -fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Resource { +fn initResource(f: *Fetch, uri: std.Uri, resource: *Resource, reader_buffer: []u8) RunError!void { const gpa = f.arena.child_allocator; const arena = f.arena.allocator(); const eb = &f.error_bundle; if (ascii.eqlIgnoreCase(uri.scheme, "file")) { const path = try uri.path.toRawMaybeAlloc(arena); - return .{ .file = f.parent_package_root.openFile(path, .{}) catch |err| { - return f.fail(f.location_tok, try eb.printString("unable to open '{f}{s}': {s}", .{ - f.parent_package_root, path, @errorName(err), + const file = f.parent_package_root.openFile(path, .{}) catch |err| { + return f.fail(f.location_tok, try eb.printString("unable to open '{f}{s}': {t}", .{ + f.parent_package_root, path, err, })); - } }; + }; + resource.* = .{ .file = file.reader(reader_buffer) }; + return; } const http_client = f.job_queue.http_client; @@ -988,37 +991,35 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re if (ascii.eqlIgnoreCase(uri.scheme, "http") or ascii.eqlIgnoreCase(uri.scheme, "https")) { - var req = http_client.open(.GET, uri, .{ - .server_header_buffer = server_header_buffer, - }) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to connect to server: {s}", - .{@errorName(err)}, - )); - }; - errdefer req.deinit(); // releases more than memory - - req.send() catch |err| { - return f.fail(f.location_tok, try eb.printString( - "HTTP request failed: {s}", - .{@errorName(err)}, - )); - }; - req.wait() catch |err| { - return f.fail(f.location_tok, try eb.printString( - "invalid HTTP response: {s}", - .{@errorName(err)}, - )); + resource.* = .{ .http_request = .{ + .request = http_client.request(.GET, uri, .{}) catch |err| + return f.fail(f.location_tok, try eb.printString("unable to connect to server: {t}", .{err})), + .response = undefined, + .buffer = reader_buffer, + } }; + const request = &resource.http_request.request; + errdefer request.deinit(); + + request.sendBodiless() catch |err| + return f.fail(f.location_tok, try eb.printString("HTTP request failed: {t}", .{err})); + + var redirect_buffer: [1024]u8 = undefined; + const response = &resource.http_request.response; + response.* = request.receiveHead(&redirect_buffer) catch |err| switch (err) { + error.ReadFailed => { + return f.fail(f.location_tok, try eb.printString("HTTP response read failure: {t}", .{ + request.connection.?.getReadError().?, + })); + }, + else => |e| return f.fail(f.location_tok, try eb.printString("invalid HTTP response: {t}", .{e})), }; - if (req.response.status != .ok) { - return f.fail(f.location_tok, try eb.printString( - "bad HTTP response code: '{d} {s}'", - .{ @intFromEnum(req.response.status), req.response.status.phrase() orelse "" }, - )); - } + if (response.head.status != .ok) return f.fail(f.location_tok, try eb.printString( + "bad HTTP response code: '{d} {s}'", + .{ response.head.status, response.head.status.phrase() orelse "" }, + )); - return .{ .http_request = req }; + return; } if (ascii.eqlIgnoreCase(uri.scheme, "git+http") or @@ -1026,7 +1027,7 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re { var transport_uri = uri; transport_uri.scheme = uri.scheme["git+".len..]; - var session = git.Session.init(gpa, http_client, transport_uri, server_header_buffer) catch |err| { + var session = 
git.Session.init(gpa, http_client, transport_uri, reader_buffer) catch |err| { return f.fail(f.location_tok, try eb.printString( "unable to discover remote git server capabilities: {s}", .{@errorName(err)}, @@ -1042,16 +1043,12 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re const want_ref_head = try std.fmt.allocPrint(arena, "refs/heads/{s}", .{want_ref}); const want_ref_tag = try std.fmt.allocPrint(arena, "refs/tags/{s}", .{want_ref}); - var ref_iterator = session.listRefs(.{ + var ref_iterator: git.Session.RefIterator = undefined; + session.listRefs(&ref_iterator, .{ .ref_prefixes = &.{ want_ref, want_ref_head, want_ref_tag }, .include_peeled = true, - .server_header_buffer = server_header_buffer, - }) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to list refs: {s}", - .{@errorName(err)}, - )); - }; + .buffer = reader_buffer, + }) catch |err| return f.fail(f.location_tok, try eb.printString("unable to list refs: {t}", .{err})); defer ref_iterator.deinit(); while (ref_iterator.next() catch |err| { return f.fail(f.location_tok, try eb.printString( @@ -1089,25 +1086,21 @@ fn initResource(f: *Fetch, uri: std.Uri, server_header_buffer: []u8) RunError!Re var want_oid_buf: [git.Oid.max_formatted_length]u8 = undefined; _ = std.fmt.bufPrint(&want_oid_buf, "{f}", .{want_oid}) catch unreachable; - var fetch_stream = session.fetch(&.{&want_oid_buf}, server_header_buffer) catch |err| { - return f.fail(f.location_tok, try eb.printString( - "unable to create fetch stream: {s}", - .{@errorName(err)}, - )); + var fetch_stream: git.Session.FetchStream = undefined; + session.fetch(&fetch_stream, &.{&want_oid_buf}, reader_buffer) catch |err| { + return f.fail(f.location_tok, try eb.printString("unable to create fetch stream: {t}", .{err})); }; errdefer fetch_stream.deinit(); - return .{ .git = .{ + resource.* = .{ .git = .{ .session = session, .fetch_stream = fetch_stream, .want_oid = want_oid, } }; + return; } - return f.fail(f.location_tok, try eb.printString( - "unsupported URL scheme: {s}", - .{uri.scheme}, - )); + return f.fail(f.location_tok, try eb.printString("unsupported URL scheme: {s}", .{uri.scheme})); } fn unpackResource( @@ -1121,9 +1114,11 @@ fn unpackResource( .file => FileType.fromPath(uri_path) orelse return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path})), - .http_request => |req| ft: { + .http_request => |*http_request| ft: { + const head = &http_request.response.head; + // Content-Type takes first precedence. - const content_type = req.response.content_type orelse + const content_type = head.content_type orelse return f.fail(f.location_tok, try eb.addString("missing 'Content-Type' header")); // Extract the MIME type, ignoring charset and boundary directives @@ -1165,7 +1160,7 @@ fn unpackResource( } // Next, the filename from 'content-disposition: attachment' takes precedence. - if (req.response.content_disposition) |cd_header| { + if (head.content_disposition) |cd_header| { break :ft FileType.fromContentDisposition(cd_header) orelse { return f.fail(f.location_tok, try eb.printString( "unsupported Content-Disposition header value: '{s}' for Content-Type=application/octet-stream", @@ -1176,10 +1171,7 @@ fn unpackResource( // Finally, the path from the URI is used. 
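initResource now fills in a caller-provided *Resource instead of returning one by value, apparently because the http_request variant holds a Response tied to the Request stored in the same struct, so the value must be constructed at its final address (the same out-pointer pattern appears in git.zig's listRefs and fetch below). A hypothetical illustration of why such self-referential values cannot be returned by value:

const std = @import("std");

// Hypothetical type, for illustration only: `window` points into `storage`,
// so a by-value copy would leave it dangling. Construct in place instead.
const Framed = struct {
    storage: [64]u8,
    window: []u8,

    fn init(out: *Framed) void {
        out.storage = undefined;
        out.window = out.storage[0..0]; // interior pointer: valid only at this address
    }
};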
@@ -1121,9 +1114,11 @@ fn unpackResource(
         .file => FileType.fromPath(uri_path) orelse
             return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path})),
-        .http_request => |req| ft: {
+        .http_request => |*http_request| ft: {
+            const head = &http_request.response.head;
+
             // Content-Type takes first precedence.
-            const content_type = req.response.content_type orelse
+            const content_type = head.content_type orelse
                 return f.fail(f.location_tok, try eb.addString("missing 'Content-Type' header"));
             // Extract the MIME type, ignoring charset and boundary directives
@@ -1165,7 +1160,7 @@ fn unpackResource(
             }
             // Next, the filename from 'content-disposition: attachment' takes precedence.
-            if (req.response.content_disposition) |cd_header| {
+            if (head.content_disposition) |cd_header| {
                 break :ft FileType.fromContentDisposition(cd_header) orelse {
                     return f.fail(f.location_tok, try eb.printString(
                         "unsupported Content-Disposition header value: '{s}' for Content-Type=application/octet-stream",
@@ -1176,10 +1171,7 @@ fn unpackResource(
             // Finally, the path from the URI is used.
             break :ft FileType.fromPath(uri_path) orelse {
-                return f.fail(f.location_tok, try eb.printString(
-                    "unknown file type: '{s}'",
-                    .{uri_path},
-                ));
+                return f.fail(f.location_tok, try eb.printString("unknown file type: '{s}'", .{uri_path}));
             };
         },
@@ -1187,10 +1179,9 @@ fn unpackResource(
         .dir => |dir| {
             f.recursiveDirectoryCopy(dir, tmp_directory.handle) catch |err| {
-                return f.fail(f.location_tok, try eb.printString(
-                    "unable to copy directory '{s}': {s}",
-                    .{ uri_path, @errorName(err) },
-                ));
+                return f.fail(f.location_tok, try eb.printString("unable to copy directory '{s}': {t}", .{
+                    uri_path, err,
+                }));
             };
             return .{};
         },
@@ -1198,27 +1189,17 @@ fn unpackResource(
     switch (file_type) {
         .tar => {
-            var adapter_buffer: [1024]u8 = undefined;
-            var adapter = resource.reader().adaptToNewApi(&adapter_buffer);
-            return unpackTarball(f, tmp_directory.handle, &adapter.new_interface);
+            return unpackTarball(f, tmp_directory.handle, resource.reader());
         },
         .@"tar.gz" => {
-            var adapter_buffer: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined;
-            var adapter = resource.reader().adaptToNewApi(&adapter_buffer);
             var flate_buffer: [std.compress.flate.max_window_len]u8 = undefined;
-            var decompress: std.compress.flate.Decompress = .init(&adapter.new_interface, .gzip, &flate_buffer);
+            var decompress: std.compress.flate.Decompress = .init(resource.reader(), .gzip, &flate_buffer);
             return try unpackTarball(f, tmp_directory.handle, &decompress.reader);
         },
         .@"tar.xz" => {
             const gpa = f.arena.child_allocator;
-            const reader = resource.reader();
-            var br = std.io.bufferedReaderSize(std.crypto.tls.max_ciphertext_record_len, reader);
-            var dcp = std.compress.xz.decompress(gpa, br.reader()) catch |err| {
-                return f.fail(f.location_tok, try eb.printString(
-                    "unable to decompress tarball: {s}",
-                    .{@errorName(err)},
-                ));
-            };
+            var dcp = std.compress.xz.decompress(gpa, resource.reader().adaptToOldInterface()) catch |err|
+                return f.fail(f.location_tok, try eb.printString("unable to decompress tarball: {t}", .{err}));
             defer dcp.deinit();
             var adapter_buffer: [1024]u8 = undefined;
             var adapter = dcp.reader().adaptToNewApi(&adapter_buffer);
@@ -1227,9 +1208,7 @@ fn unpackResource(
         .@"tar.zst" => {
             const window_size = std.compress.zstd.default_window_len;
             const window_buffer = try f.arena.allocator().create([window_size]u8);
-            var adapter_buffer: [std.crypto.tls.max_ciphertext_record_len]u8 = undefined;
-            var adapter = resource.reader().adaptToNewApi(&adapter_buffer);
-            var decompress: std.compress.zstd.Decompress = .init(&adapter.new_interface, window_buffer, .{
+            var decompress: std.compress.zstd.Decompress = .init(resource.reader(), window_buffer, .{
                 .verify_checksum = false,
             });
             return try unpackTarball(f, tmp_directory.handle, &decompress.reader);
@@ -1237,12 +1216,15 @@ fn unpackResource(
         .git_pack => return unpackGitPack(f, tmp_directory.handle, &resource.git) catch |err| switch (err) {
             error.FetchFailed => return error.FetchFailed,
             error.OutOfMemory => return error.OutOfMemory,
-            else => |e| return f.fail(f.location_tok, try eb.printString(
-                "unable to unpack git files: {s}",
-                .{@errorName(e)},
+            else => |e| return f.fail(f.location_tok, try eb.printString("unable to unpack git files: {t}", .{e})),
+        },
+        .zip => return unzip(f, tmp_directory.handle, resource.reader()) catch |err| switch (err) {
+            error.ReadFailed => return f.fail(f.location_tok, try eb.printString(
+                "failed reading resource: {t}",
+                .{err},
             )),
+            else => |e| return e,
         },
-        .zip => return try unzip(f, tmp_directory.handle, resource.reader()),
     }
 }
@@ -1277,99 +1259,69 @@ fn unpackTarball(f: *Fetch, out_dir: fs.Dir, reader: *std.Io.Reader) RunError!Un
     return res;
 }

-fn unzip(f: *Fetch, out_dir: fs.Dir, reader: anytype) RunError!UnpackResult {
+fn unzip(f: *Fetch, out_dir: fs.Dir, reader: *std.Io.Reader) error{ ReadFailed, OutOfMemory, FetchFailed }!UnpackResult {
     // We write the entire contents to a file first because zip files
     // must be processed back to front and they could be too large to
     // load into memory.
     const cache_root = f.job_queue.global_cache;
-
-    // TODO: the downside of this solution is if we get a failure/crash/oom/power out
-    // during this process, we leave behind a zip file that would be
-    // difficult to know if/when it can be cleaned up.
-    // Might be worth it to use a mechanism that enables other processes
-    // to see if the owning process of a file is still alive (on linux this
-    // can be done with file locks).
-    // Coupled with this mechansism, we could also use slots (i.e. zig-cache/tmp/0,
-    // zig-cache/tmp/1, etc) which would mean that subsequent runs would
-    // automatically clean up old dead files.
-    // This could all be done with a simple TmpFile abstraction.
     const prefix = "tmp/";
     const suffix = ".zip";
-
-    const random_bytes_count = 20;
-    const random_path_len = comptime std.fs.base64_encoder.calcSize(random_bytes_count);
-    var zip_path: [prefix.len + random_path_len + suffix.len]u8 = undefined;
-    @memcpy(zip_path[0..prefix.len], prefix);
-    @memcpy(zip_path[prefix.len + random_path_len ..], suffix);
-    {
-        var random_bytes: [random_bytes_count]u8 = undefined;
-        std.crypto.random.bytes(&random_bytes);
-        _ = std.fs.base64_encoder.encode(
-            zip_path[prefix.len..][0..random_path_len],
-            &random_bytes,
-        );
-    }
-
-    defer cache_root.handle.deleteFile(&zip_path) catch {};
     const eb = &f.error_bundle;
-
-    {
-        var zip_file = cache_root.handle.createFile(
-            &zip_path,
-            .{},
-        ) catch |err| return f.fail(f.location_tok, try eb.printString(
-            "failed to create tmp zip file: {s}",
-            .{@errorName(err)},
-        ));
-        defer zip_file.close();
-        var buf: [4096]u8 = undefined;
-        while (true) {
-            const len = reader.readAll(&buf) catch |err| return f.fail(f.location_tok, try eb.printString(
-                "read zip stream failed: {s}",
-                .{@errorName(err)},
-            ));
-            if (len == 0) break;
-            zip_file.deprecatedWriter().writeAll(buf[0..len]) catch |err| return f.fail(f.location_tok, try eb.printString(
-                "write temporary zip file failed: {s}",
-                .{@errorName(err)},
-            ));
-        }
-    }
+    const random_len = @sizeOf(u64) * 2;
+
+    var zip_path: [prefix.len + random_len + suffix.len]u8 = undefined;
+    zip_path[0..prefix.len].* = prefix.*;
+    zip_path[prefix.len + random_len ..].* = suffix.*;
+
+    var zip_file = while (true) {
+        const random_integer = std.crypto.random.int(u64);
+        zip_path[prefix.len..][0..random_len].* = std.fmt.hex(random_integer);
+
+        break cache_root.handle.createFile(&zip_path, .{
+            .exclusive = true,
+            .read = true,
+        }) catch |err| switch (err) {
+            error.PathAlreadyExists => continue,
+            else => |e| return f.fail(
+                f.location_tok,
+                try eb.printString("failed to create temporary zip file: {t}", .{e}),
+            ),
+        };
+    };
+    defer zip_file.close();
+    var zip_file_buffer: [4096]u8 = undefined;
+    var zip_file_reader = b: {
+        var zip_file_writer = zip_file.writer(&zip_file_buffer);
+
+        _ = reader.streamRemaining(&zip_file_writer.interface) catch |err| switch (err) {
+            error.ReadFailed => return error.ReadFailed,
+            error.WriteFailed => return f.fail(
                f.location_tok,
                try eb.printString("failed writing temporary zip file: {t}", .{err}),
            ),
        };
+        zip_file_writer.interface.flush() catch |err| return f.fail(
            f.location_tok,
            try eb.printString("failed writing temporary zip file: {t}", .{err}),
        );
+        break :b zip_file_writer.moveToReader();
    };

     var diagnostics: std.zip.Diagnostics = .{ .allocator = f.arena.allocator() }; // no need to deinit since we are using an arena allocator
-    {
-        var zip_file = cache_root.handle.openFile(
-            &zip_path,
-            .{},
-        ) catch |err| return f.fail(f.location_tok, try eb.printString(
-            "failed to open temporary zip file: {s}",
-            .{@errorName(err)},
-        ));
-        defer zip_file.close();
-
-        var zip_file_buffer: [1024]u8 = undefined;
-        var zip_file_reader = zip_file.reader(&zip_file_buffer);
-
-        std.zip.extract(out_dir, &zip_file_reader, .{
-            .allow_backslashes = true,
-            .diagnostics = &diagnostics,
-        }) catch |err| return f.fail(f.location_tok, try eb.printString(
-            "zip extract failed: {s}",
-            .{@errorName(err)},
-        ));
-    }
+    zip_file_reader.seekTo(0) catch |err|
+        return f.fail(f.location_tok, try eb.printString("failed to seek temporary zip file: {t}", .{err}));
+    std.zip.extract(out_dir, &zip_file_reader, .{
+        .allow_backslashes = true,
+        .diagnostics = &diagnostics,
+    }) catch |err| return f.fail(f.location_tok, try eb.printString("zip extract failed: {t}", .{err}));

-    cache_root.handle.deleteFile(&zip_path) catch |err| return f.fail(f.location_tok, try eb.printString(
-        "delete temporary zip failed: {s}",
-        .{@errorName(err)},
-    ));
+    cache_root.handle.deleteFile(&zip_path) catch |err|
+        return f.fail(f.location_tok, try eb.printString("delete temporary zip failed: {t}", .{err}));

-    const res: UnpackResult = .{ .root_dir = diagnostics.root_dir };
-    return res;
+    return .{ .root_dir = diagnostics.root_dir };
 }

 fn unpackGitPack(f: *Fetch, out_dir: fs.Dir, resource: *Resource.Git) anyerror!UnpackResult {
@@ -1387,10 +1339,13 @@ fn unpackGitPack(f: *Fetch, out_dir: fs.Dir, resource: *Resource.Git) anyerror!U
     var pack_file = try pack_dir.createFile("pkg.pack", .{ .read = true });
     defer pack_file.close();
     var pack_file_buffer: [4096]u8 = undefined;
-    var fifo = std.fifo.LinearFifo(u8, .{ .Slice = {} }).init(&pack_file_buffer);
-    try fifo.pump(resource.fetch_stream.reader(), pack_file.deprecatedWriter());
-
-    var pack_file_reader = pack_file.reader(&pack_file_buffer);
+    var pack_file_reader = b: {
+        var pack_file_writer = pack_file.writer(&pack_file_buffer);
+        const fetch_reader = &resource.fetch_stream.reader;
+        _ = try fetch_reader.streamRemaining(&pack_file_writer.interface);
+        try pack_file_writer.interface.flush();
+        break :b pack_file_writer.moveToReader();
+    };

     var index_file = try pack_dir.createFile("pkg.idx", .{ .read = true });
     defer index_file.close();
diff --git a/src/Package/Fetch/git.zig b/src/Package/Fetch/git.zig
index 6ff951014b57..390b977c3a40 100644
--- a/src/Package/Fetch/git.zig
+++ b/src/Package/Fetch/git.zig
@@ -585,17 +585,17 @@ const ObjectCache = struct {
 /// [protocol-common](https://git-scm.com/docs/protocol-common). The special
 /// meanings of the delimiter and response-end packets are documented in
 /// [protocol-v2](https://git-scm.com/docs/protocol-v2).
-const Packet = union(enum) {
+pub const Packet = union(enum) {
     flush,
     delimiter,
     response_end,
     data: []const u8,

-    const max_data_length = 65516;
+    pub const max_data_length = 65516;

     /// Reads a packet in pkt-line format.
-    fn read(reader: anytype, buf: *[max_data_length]u8) !Packet {
-        const length = std.fmt.parseUnsigned(u16, &try reader.readBytesNoEof(4), 16) catch return error.InvalidPacket;
+    fn read(reader: *std.Io.Reader) !Packet {
+        const length = std.fmt.parseUnsigned(u16, try reader.take(4), 16) catch return error.InvalidPacket;
         switch (length) {
             0 => return .flush,
             1 => return .delimiter,
@@ -603,13 +603,11 @@ const Packet = union(enum) {
             3 => return error.InvalidPacket,
             else => if (length - 4 > max_data_length) return error.InvalidPacket,
         }
-        const data = buf[0 .. length - 4];
-        try reader.readNoEof(data);
-        return .{ .data = data };
+        return .{ .data = try reader.take(length - 4) };
     }

     /// Writes a packet in pkt-line format.
-    fn write(packet: Packet, writer: anytype) !void {
+    fn write(packet: Packet, writer: *std.Io.Writer) !void {
         switch (packet) {
             .flush => try writer.writeAll("0000"),
             .delimiter => try writer.writeAll("0001"),
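Packet.read above parses pkt-line framing: four ASCII hex digits give the total line length including the four-digit header itself, and lengths 0, 1, and 2 are reserved for the flush, delimiter, and response-end packets. A small worked example of the encoding (the buffer size is arbitrary; print on the new std.Io.Writer is assumed to format as std.fmt does):

const std = @import("std");

test "pkt-line framing" {
    // "done\n" is 5 bytes; add 4 for the length header itself -> 9 -> "0009".
    var buf: [32]u8 = undefined;
    var w: std.Io.Writer = .fixed(&buf);
    try w.print("{x:0>4}", .{"done\n".len + 4});
    try w.writeAll("done\n");
    try std.testing.expectEqualStrings("0009done\n", w.buffered());
}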
@@ -657,8 +655,10 @@ pub const Session = struct {
         allocator: Allocator,
         transport: *std.http.Client,
         uri: std.Uri,
-        http_headers_buffer: []u8,
+        /// Asserted to be at least `Packet.max_data_length`
+        response_buffer: []u8,
     ) !Session {
+        assert(response_buffer.len >= Packet.max_data_length);
         var session: Session = .{
             .transport = transport,
             .location = try .init(allocator, uri),
@@ -668,7 +668,8 @@ pub const Session = struct {
             .allocator = allocator,
         };
         errdefer session.deinit();
-        var capability_iterator = try session.getCapabilities(http_headers_buffer);
+        var capability_iterator: CapabilityIterator = undefined;
+        try session.getCapabilities(&capability_iterator, response_buffer);
         defer capability_iterator.deinit();
         while (try capability_iterator.next()) |capability| {
             if (mem.eql(u8, capability.key, "agent")) {
@@ -743,7 +744,8 @@ pub const Session = struct {
     ///
     /// The `session.location` is updated if the server returns a redirect, so
     /// that subsequent session functions do not need to handle redirects.
-    fn getCapabilities(session: *Session, http_headers_buffer: []u8) !CapabilityIterator {
+    fn getCapabilities(session: *Session, it: *CapabilityIterator, response_buffer: []u8) !void {
+        assert(response_buffer.len >= Packet.max_data_length);
         var info_refs_uri = session.location.uri;
         {
             const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -757,19 +759,22 @@ pub const Session = struct {
         info_refs_uri.fragment = null;

         const max_redirects = 3;
-        var request = try session.transport.open(.GET, info_refs_uri, .{
-            .redirect_behavior = @enumFromInt(max_redirects),
-            .server_header_buffer = http_headers_buffer,
-            .extra_headers = &.{
-                .{ .name = "Git-Protocol", .value = "version=2" },
-            },
-        });
-        errdefer request.deinit();
-        try request.send();
-        try request.finish();
+        it.* = .{
+            .request = try session.transport.request(.GET, info_refs_uri, .{
+                .redirect_behavior = .init(max_redirects),
+                .extra_headers = &.{
+                    .{ .name = "Git-Protocol", .value = "version=2" },
+                },
+            }),
+            .reader = undefined,
+        };
+        errdefer it.deinit();
+        const request = &it.request;
+        try request.sendBodiless();

-        try request.wait();
-        if (request.response.status != .ok) return error.ProtocolError;
+        var redirect_buffer: [1024]u8 = undefined;
+        var response = try request.receiveHead(&redirect_buffer);
+        if (response.head.status != .ok) return error.ProtocolError;

         const any_redirects_occurred = request.redirect_behavior.remaining() < max_redirects;
         if (any_redirects_occurred) {
             const request_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -784,8 +789,7 @@ pub const Session = struct {
             session.location = new_location;
         }

-        const reader = request.reader();
-        var buf: [Packet.max_data_length]u8 = undefined;
+        it.reader = response.reader(response_buffer);
         var state: enum { response_start, response_content } = .response_start;
         while (true) {
             // Some Git servers (at least GitHub) include an additional
@@ -795,15 +799,15 @@ pub const Session = struct {
             // Thus, we need to skip any such useless additional responses
             // before we get the one we're actually looking for. The responses
             // will be delimited by flush packets.
-            const packet = Packet.read(reader, &buf) catch |e| switch (e) {
+            const packet = Packet.read(it.reader) catch |err| switch (err) {
                 error.EndOfStream => return error.UnsupportedProtocol, // 'version 2' packet not found
-                else => |other| return other,
+                else => |e| return e,
             };
             switch (packet) {
                 .flush => state = .response_start,
                 .data => |data| switch (state) {
                     .response_start => if (mem.eql(u8, Packet.normalizeText(data), "version 2")) {
-                        return .{ .request = request };
+                        return;
                     } else {
                         state = .response_content;
                     },
@@ -816,7 +820,7 @@ pub const Session = struct {
     const CapabilityIterator = struct {
         request: std.http.Client.Request,
-        buf: [Packet.max_data_length]u8 = undefined,
+        reader: *std.Io.Reader,

         const Capability = struct {
             key: []const u8,
@@ -830,13 +834,13 @@ pub const Session = struct {
             }
         };

-        fn deinit(iterator: *CapabilityIterator) void {
-            iterator.request.deinit();
-            iterator.* = undefined;
+        fn deinit(it: *CapabilityIterator) void {
+            it.request.deinit();
+            it.* = undefined;
         }

-        fn next(iterator: *CapabilityIterator) !?Capability {
-            switch (try Packet.read(iterator.request.reader(), &iterator.buf)) {
+        fn next(it: *CapabilityIterator) !?Capability {
+            switch (try Packet.read(it.reader)) {
                 .flush => return null,
                 .data => |data| return Capability.parse(Packet.normalizeText(data)),
                 else => return error.UnexpectedPacket,
@@ -854,11 +858,13 @@ pub const Session = struct {
         include_symrefs: bool = false,
         /// Whether to include the peeled object ID for returned tag refs.
         include_peeled: bool = false,
-        server_header_buffer: []u8,
+        /// Asserted to be at least `Packet.max_data_length`.
+        buffer: []u8,
     };

     /// Returns an iterator over refs known to the server.
-    pub fn listRefs(session: Session, options: ListRefsOptions) !RefIterator {
+    pub fn listRefs(session: Session, it: *RefIterator, options: ListRefsOptions) !void {
+        assert(options.buffer.len >= Packet.max_data_length);
         var upload_pack_uri = session.location.uri;
         {
             const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -871,59 +877,56 @@ pub const Session = struct {
         upload_pack_uri.query = null;
         upload_pack_uri.fragment = null;

-        var body: std.ArrayListUnmanaged(u8) = .empty;
-        defer body.deinit(session.allocator);
-        const body_writer = body.writer(session.allocator);
-        try Packet.write(.{ .data = "command=ls-refs\n" }, body_writer);
+        var body: std.Io.Writer = .fixed(options.buffer);
+        try Packet.write(.{ .data = "command=ls-refs\n" }, &body);
         if (session.supports_agent) {
-            try Packet.write(.{ .data = agent_capability }, body_writer);
+            try Packet.write(.{ .data = agent_capability }, &body);
         }
         {
-            const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)});
+            const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={t}\n", .{
+                session.object_format,
+            });
             defer session.allocator.free(object_format_packet);
-            try Packet.write(.{ .data = object_format_packet }, body_writer);
+            try Packet.write(.{ .data = object_format_packet }, &body);
         }
-        try Packet.write(.delimiter, body_writer);
+        try Packet.write(.delimiter, &body);
         for (options.ref_prefixes) |ref_prefix| {
             const ref_prefix_packet = try std.fmt.allocPrint(session.allocator, "ref-prefix {s}\n", .{ref_prefix});
             defer session.allocator.free(ref_prefix_packet);
-            try Packet.write(.{ .data = ref_prefix_packet }, body_writer);
+            try Packet.write(.{ .data = ref_prefix_packet }, &body);
         }
         if (options.include_symrefs) {
-            try Packet.write(.{ .data = "symrefs\n" }, body_writer);
+            try Packet.write(.{ .data = "symrefs\n" }, &body);
         }
         if (options.include_peeled) {
-            try Packet.write(.{ .data = "peel\n" }, body_writer);
+            try Packet.write(.{ .data = "peel\n" }, &body);
         }
-        try Packet.write(.flush, body_writer);
-
-        var request = try session.transport.open(.POST, upload_pack_uri, .{
-            .redirect_behavior = .unhandled,
-            .server_header_buffer = options.server_header_buffer,
-            .extra_headers = &.{
-                .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
-                .{ .name = "Git-Protocol", .value = "version=2" },
-            },
-        });
-        errdefer request.deinit();
-        request.transfer_encoding = .{ .content_length = body.items.len };
-        try request.send();
-        try request.writeAll(body.items);
-        try request.finish();
-
-        try request.wait();
-        if (request.response.status != .ok) return error.ProtocolError;
-
-        return .{
+        try Packet.write(.flush, &body);
+
+        it.* = .{
+            .request = try session.transport.request(.POST, upload_pack_uri, .{
+                .redirect_behavior = .unhandled,
+                .extra_headers = &.{
+                    .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
+                    .{ .name = "Git-Protocol", .value = "version=2" },
+                },
+            }),
+            .reader = undefined,
             .format = session.object_format,
-            .request = request,
         };
+        const request = &it.request;
+        errdefer request.deinit();
+        try request.sendBodyComplete(body.buffered());
+
+        var response = try request.receiveHead(options.buffer);
+        if (response.head.status != .ok) return error.ProtocolError;
+        it.reader = response.reader(options.buffer);
     }

     pub const RefIterator = struct {
         format: Oid.Format,
         request: std.http.Client.Request,
-        buf: [Packet.max_data_length]u8 = undefined,
+        reader: *std.Io.Reader,

         pub const Ref = struct {
             oid: Oid,
@@ -937,13 +940,13 @@ pub const Session = struct {
             iterator.* = undefined;
         }

-        pub fn next(iterator: *RefIterator) !?Ref {
-            switch (try Packet.read(iterator.request.reader(), &iterator.buf)) {
+        pub fn next(it: *RefIterator) !?Ref {
+            switch (try Packet.read(it.reader)) {
                 .flush => return null,
                 .data => |data| {
                     const ref_data = Packet.normalizeText(data);
                     const oid_sep_pos = mem.indexOfScalar(u8, ref_data, ' ') orelse return error.InvalidRefPacket;
-                    const oid = Oid.parse(iterator.format, data[0..oid_sep_pos]) catch return error.InvalidRefPacket;
+                    const oid = Oid.parse(it.format, data[0..oid_sep_pos]) catch return error.InvalidRefPacket;
                     const name_sep_pos = mem.indexOfScalarPos(u8, ref_data, oid_sep_pos + 1, ' ') orelse ref_data.len;
                     const name = ref_data[oid_sep_pos + 1 .. name_sep_pos];
@@ -957,7 +960,7 @@ pub const Session = struct {
                         if (mem.startsWith(u8, attribute, "symref-target:")) {
                             symref_target = attribute["symref-target:".len..];
                         } else if (mem.startsWith(u8, attribute, "peeled:")) {
-                            peeled = Oid.parse(iterator.format, attribute["peeled:".len..]) catch return error.InvalidRefPacket;
+                            peeled = Oid.parse(it.format, attribute["peeled:".len..]) catch return error.InvalidRefPacket;
                         }
                         last_sep_pos = next_sep_pos;
                     }
@@ -973,9 +976,12 @@
     /// performed if the server supports it.
     pub fn fetch(
         session: Session,
+        fs: *FetchStream,
         wants: []const []const u8,
-        http_headers_buffer: []u8,
-    ) !FetchStream {
+        /// Asserted to be at least `Packet.max_data_length`.
+        response_buffer: []u8,
+    ) !void {
+        assert(response_buffer.len >= Packet.max_data_length);
         var upload_pack_uri = session.location.uri;
         {
             const session_uri_path = try std.fmt.allocPrint(session.allocator, "{f}", .{
@@ -988,63 +994,71 @@ pub const Session = struct {
         upload_pack_uri.query = null;
         upload_pack_uri.fragment = null;

-        var body: std.ArrayListUnmanaged(u8) = .empty;
-        defer body.deinit(session.allocator);
-        const body_writer = body.writer(session.allocator);
-        try Packet.write(.{ .data = "command=fetch\n" }, body_writer);
+        var body: std.Io.Writer = .fixed(response_buffer);
+        try Packet.write(.{ .data = "command=fetch\n" }, &body);
         if (session.supports_agent) {
-            try Packet.write(.{ .data = agent_capability }, body_writer);
+            try Packet.write(.{ .data = agent_capability }, &body);
         }
         {
             const object_format_packet = try std.fmt.allocPrint(session.allocator, "object-format={s}\n", .{@tagName(session.object_format)});
             defer session.allocator.free(object_format_packet);
-            try Packet.write(.{ .data = object_format_packet }, body_writer);
+            try Packet.write(.{ .data = object_format_packet }, &body);
         }
-        try Packet.write(.delimiter, body_writer);
+        try Packet.write(.delimiter, &body);
         // Our packfile parser supports the OFS_DELTA object type
-        try Packet.write(.{ .data = "ofs-delta\n" }, body_writer);
+        try Packet.write(.{ .data = "ofs-delta\n" }, &body);
         // We do not currently convey server progress information to the user
-        try Packet.write(.{ .data = "no-progress\n" }, body_writer);
+        try Packet.write(.{ .data = "no-progress\n" }, &body);
         if (session.supports_shallow) {
-            try Packet.write(.{ .data = "deepen 1\n" }, body_writer);
+            try Packet.write(.{ .data = "deepen 1\n" }, &body);
         }
         for (wants) |want| {
             var buf: [Packet.max_data_length]u8 = undefined;
             const arg = std.fmt.bufPrint(&buf, "want {s}\n", .{want}) catch unreachable;
-            try Packet.write(.{ .data = arg }, body_writer);
+            try Packet.write(.{ .data = arg }, &body);
         }
-        try Packet.write(.{ .data = "done\n" }, body_writer);
-        try Packet.write(.flush, body_writer);
-
-        var request = try session.transport.open(.POST, upload_pack_uri, .{
-            .redirect_behavior = .not_allowed,
-            .server_header_buffer = http_headers_buffer,
-            .extra_headers = &.{
-                .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
-                .{ .name = "Git-Protocol", .value = "version=2" },
-            },
-        });
+        try Packet.write(.{ .data = "done\n" }, &body);
+        try Packet.write(.flush, &body);
+
+        fs.* = .{
+            .request = try session.transport.request(.POST, upload_pack_uri, .{
+                .redirect_behavior = .not_allowed,
+                .extra_headers = &.{
+                    .{ .name = "Content-Type", .value = "application/x-git-upload-pack-request" },
+                    .{ .name = "Git-Protocol", .value = "version=2" },
+                },
+            }),
+            .input = undefined,
+            .reader = undefined,
+            .remaining_len = undefined,
+        };
+        const request = &fs.request;
         errdefer request.deinit();
-        request.transfer_encoding = .{ .content_length = body.items.len };
-        try request.send();
-        try request.writeAll(body.items);
-        try request.finish();
-        try request.wait();
-        if (request.response.status != .ok) return error.ProtocolError;
+        try request.sendBodyComplete(body.buffered());
+
+        var response = try request.receiveHead(&.{});
+        if (response.head.status != .ok) return error.ProtocolError;

-        const reader = request.reader();
+        const reader = response.reader(response_buffer);
         // We are not interested in any of the sections of the returned fetch
         // data other than the packfile section, since we aren't doing anything
         // complex like ref negotiation (this is a fresh clone).
         var state: enum { section_start, section_content } = .section_start;
         while (true) {
-            var buf: [Packet.max_data_length]u8 = undefined;
-            const packet = try Packet.read(reader, &buf);
+            const packet = try Packet.read(reader);
             switch (state) {
                 .section_start => switch (packet) {
                     .data => |data| if (mem.eql(u8, Packet.normalizeText(data), "packfile")) {
-                        return .{ .request = request };
+                        fs.input = reader;
+                        fs.reader = .{
+                            .buffer = &.{},
+                            .vtable = &.{ .stream = FetchStream.stream },
+                            .seek = 0,
+                            .end = 0,
+                        };
+                        fs.remaining_len = 0;
+                        return;
                     } else {
                         state = .section_content;
                     },
@@ -1061,20 +1075,23 @@ pub const Session = struct {
     pub const FetchStream = struct {
         request: std.http.Client.Request,
-        buf: [Packet.max_data_length]u8 = undefined,
-        pos: usize = 0,
-        len: usize = 0,
+        input: *std.Io.Reader,
+        reader: std.Io.Reader,
+        err: ?Error = null,
+        remaining_len: usize,

-        pub fn deinit(stream: *FetchStream) void {
-            stream.request.deinit();
+        pub fn deinit(fs: *FetchStream) void {
+            fs.request.deinit();
         }

-        pub const ReadError = std.http.Client.Request.ReadError || error{
+        pub const Error = error{
             InvalidPacket,
             ProtocolError,
             UnexpectedPacket,
+            WriteFailed,
+            ReadFailed,
+            EndOfStream,
         };
-        pub const Reader = std.io.GenericReader(*FetchStream, ReadError, read);

         const StreamCode = enum(u8) {
             pack_data = 1,
@@ -1083,33 +1100,41 @@ pub const Session = struct {
             _,
         };

-        pub fn reader(stream: *FetchStream) Reader {
-            return .{ .context = stream };
-        }
-
-        pub fn read(stream: *FetchStream, buf: []u8) !usize {
-            if (stream.pos == stream.len) {
+        pub fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
+            const fs: *FetchStream = @alignCast(@fieldParentPtr("reader", r));
+            const input = fs.input;
+            if (fs.remaining_len == 0) {
                 while (true) {
-                    switch (try Packet.read(stream.request.reader(), &stream.buf)) {
-                        .flush => return 0,
+                    switch (Packet.read(input) catch |err| {
+                        fs.err = err;
+                        return error.ReadFailed;
+                    }) {
+                        .flush => return error.EndOfStream,
                         .data => |data| if (data.len > 1) switch (@as(StreamCode, @enumFromInt(data[0]))) {
                             .pack_data => {
-                                stream.pos = 1;
-                                stream.len = data.len;
+                                input.toss(1);
+                                fs.remaining_len = data.len;
                                 break;
                             },
-                            .fatal_error => return error.ProtocolError,
+                            .fatal_error => {
+                                fs.err = error.ProtocolError;
+                                return error.ReadFailed;
+                            },
                             else => {},
                         },
-                        else => return error.UnexpectedPacket,
+                        else => {
+                            fs.err = error.UnexpectedPacket;
+                            return error.ReadFailed;
+                        },
                     }
                 }
             }
-
-            const size = @min(buf.len, stream.len - stream.pos);
-            @memcpy(buf[0..size], stream.buf[stream.pos .. stream.pos + size]);
-            stream.pos += size;
-            return size;
+            const buf = limit.slice(try w.writableSliceGreedy(1));
+            const n = @min(buf.len, fs.remaining_len);
+            @memcpy(buf[0..n], input.buffered()[0..n]);
+            input.toss(n);
+            fs.remaining_len -= n;
+            return n;
         }
     };
};
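FetchStream is the template for implementing the new std.Io.Reader interface by hand: embed a Reader whose vtable points at a stream function, then recover the containing struct with @fieldParentPtr. A minimal sketch of the same shape (CountingReader is hypothetical, and it assumes Reader.stream dispatches through the vtable the way the rest of this diff implies):

const std = @import("std");

// Sketch: a pass-through std.Io.Reader that counts the bytes it relays.
const CountingReader = struct {
    input: *std.Io.Reader,
    bytes_seen: usize = 0,
    reader: std.Io.Reader,

    fn init(input: *std.Io.Reader) CountingReader {
        return .{
            .input = input,
            .reader = .{
                .buffer = &.{}, // unbuffered: every request lands in stream()
                .vtable = &.{ .stream = stream },
                .seek = 0,
                .end = 0,
            },
        };
    }

    fn stream(r: *std.Io.Reader, w: *std.Io.Writer, limit: std.Io.Limit) std.Io.Reader.StreamError!usize {
        const cr: *CountingReader = @alignCast(@fieldParentPtr("reader", r));
        const n = try cr.input.stream(w, limit); // forward one step of the pipeline
        cr.bytes_seen += n;
        return n;
    }
};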