Skip to content

Commit 64814dc

Browse files
committed
std.compress.flate.Decompress: respect stream limit
1 parent 6caa100 commit 64814dc

File tree

2 files changed

+72
-24
lines changed

2 files changed

+72
-24
lines changed

lib/std/Io/Writer.zig

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,6 +2286,13 @@ pub fn fixedDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usiz
22862286
}
22872287
}
22882288

2289+
/// Drain callback intended for a `Writer` that must never need to drain.
///
/// All three arguments are discarded and the body is a bare `unreachable`, so
/// reaching this function is a programming error rather than a recoverable
/// failure. Only install it (as `.drain` in a vtable) when every write is
/// already known to fit in the writer's buffer.
pub fn unreachableDrain(w: *Writer, data: []const []const u8, splat: usize) Error!usize {
2290+
_ = w;
2291+
_ = data;
2292+
_ = splat;
2293+
// Not an error return: being called at all violates the caller's invariant.
unreachable;
2294+
}
2295+
22892296
/// Provides a `Writer` implementation based on calling `Hasher.update`, sending
22902297
/// all data also to an underlying `Writer`.
22912298
///

lib/std/compress/flate/Decompress.zig

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ const State = union(enum) {
3737
stored_block: u16,
3838
fixed_block,
3939
dynamic_block,
40+
dynamic_block_literal: u8,
41+
dynamic_block_match: u16,
4042
protocol_footer,
4143
end,
4244
};
@@ -63,7 +65,7 @@ const direct_vtable: Reader.VTable = .{
6365
/// `Reader` vtable whose `stream` entry is `streamIndirect`. In this diff its
/// `discard` entry changes from `discard` to `discardIndirect`; the `rebase`
/// and `readVec` entries are unchanged.
const indirect_vtable: Reader.VTable = .{
6466
.stream = streamIndirect,
6567
.rebase = rebaseFallible,
66-
.discard = discard,
68+
.discard = discardIndirect,
6769
.readVec = readVec,
6870
};
6971

@@ -128,6 +130,26 @@ fn discard(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
128130
return n;
129131
}
130132

133+
/// `discard` implementation used by `indirect_vtable`.
///
/// Decompresses more data into the reader's own buffer (starting at `r.end`),
/// then advances `r.seek` past up to `limit` of the bytes now buffered and
/// returns how many bytes were skipped. Errors from `streamFallible` other
/// than `error.WriteFailed` are returned to the caller.
fn discardIndirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
134+
// `r` is the `reader` field of a `Decompress`; recover the parent struct.
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
135+
// Fewer than `flate.history_len` bytes of capacity remain after `r.end`.
// NOTE(review): `rebase` is not visible here; presumably it makes room while
// keeping `history_len` bytes of history. Confirm against its definition.
if (r.end + flate.history_len > r.buffer.len) rebase(r, flate.history_len);
136+
// Temporary writer that appends into the reader's buffer at `r.end`. Its
// drain is `unreachableDrain` because the stream call below is limited to the
// free space in that buffer.
var writer: Writer = .{
137+
.buffer = r.buffer,
138+
.end = r.end,
139+
.vtable = &.{ .drain = Writer.unreachableDrain },
140+
};
141+
{
142+
// Runs when this inner block exits, including the error-return path, so
// `r.end` always picks up the writer's final `end`.
defer r.end = writer.end;
143+
// The limit is exactly the unused capacity of the buffer.
_ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
144+
// Treated as impossible here: the write is bounded by the free space.
error.WriteFailed => unreachable,
145+
else => |e| return e,
146+
};
147+
}
148+
// Skip no more than `limit` and no more than the bytes currently buffered
// between `r.seek` and `r.end`.
const n = limit.minInt(r.end - r.seek);
149+
r.seek += n;
150+
return n;
151+
}
152+
131153
fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
132154
_ = data;
133155
const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
@@ -140,7 +162,7 @@ fn streamIndirectInner(d: *Decompress) Reader.Error!usize {
140162
var writer: Writer = .{
141163
.buffer = r.buffer,
142164
.end = r.end,
143-
.vtable = &.{ .drain = Writer.fixedDrain },
165+
.vtable = &.{ .drain = Writer.unreachableDrain },
144166
};
145167
defer r.end = writer.end;
146168
_ = streamFallible(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
@@ -379,30 +401,49 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
379401
.dynamic_block => {
380402
// In larger archives most blocks are usually dynamic, so
381403
// decompression performance depends on this logic.
382-
while (remaining > 0) {
383-
const sym = try d.decodeSymbol(&d.lit_dec);
384-
385-
switch (sym.kind) {
386-
.literal => {
387-
try w.writeBytePreserve(flate.history_len, sym.symbol);
404+
var sym = try d.decodeSymbol(&d.lit_dec);
405+
sym: switch (sym.kind) {
406+
.literal => {
407+
if (remaining != 0) {
408+
@branchHint(.likely);
388409
remaining -= 1;
389-
},
390-
.match => {
391-
// Decode match backreference <length, distance>
392-
const length = try d.decodeLength(sym.symbol);
393-
const dsm = try d.decodeSymbol(&d.dst_dec);
394-
const distance = try d.decodeDistance(dsm.symbol);
395-
try writeMatch(w, length, distance);
396-
remaining -= length;
397-
},
398-
.end_of_block => {
399-
d.state = if (d.final_block) .protocol_footer else .block_header;
410+
try w.writeBytePreserve(flate.history_len, sym.symbol);
411+
sym = try d.decodeSymbol(&d.lit_dec);
412+
continue :sym sym.kind;
413+
} else {
414+
d.state = .{ .dynamic_block_literal = sym.symbol };
400415
return @intFromEnum(limit) - remaining;
401-
},
402-
}
416+
}
417+
},
418+
.match => {
419+
// Decode match backreference <length, distance>
420+
const length = try d.decodeLength(sym.symbol);
421+
continue :sw .{ .dynamic_block_match = length };
422+
},
423+
.end_of_block => {
424+
d.state = if (d.final_block) .protocol_footer else .block_header;
425+
continue :sw d.state;
426+
},
427+
}
428+
},
429+
.dynamic_block_literal => |symbol| {
430+
assert(remaining != 0);
431+
remaining -= 1;
432+
try w.writeBytePreserve(flate.history_len, symbol);
433+
continue :sw .dynamic_block;
434+
},
435+
.dynamic_block_match => |length| {
436+
if (remaining >= length) {
437+
@branchHint(.likely);
438+
remaining -= length;
439+
const dsm = try d.decodeSymbol(&d.dst_dec);
440+
const distance = try d.decodeDistance(dsm.symbol);
441+
try writeMatch(w, length, distance);
442+
continue :sw .dynamic_block;
443+
} else {
444+
d.state = .{ .dynamic_block_match = length };
445+
return @intFromEnum(limit) - remaining;
403446
}
404-
d.state = .dynamic_block;
405-
return @intFromEnum(limit) - remaining;
406447
},
407448
.protocol_footer => {
408449
switch (d.container_metadata) {
@@ -424,7 +465,7 @@ fn streamInner(d: *Decompress, w: *Writer, limit: std.Io.Limit) (Error || Reader
424465
},
425466
}
426467
d.state = .end;
427-
return 0;
468+
return @intFromEnum(limit) - remaining;
428469
},
429470
.end => return error.EndOfStream,
430471
}

0 commit comments

Comments
 (0)