Skip to content

Commit b48d6ff

Browse files
jacobly0mlugg
authored andcommitted
Legalize: implement scalarization of @select
1 parent 32a57bf commit b48d6ff

File tree

5 files changed

+102
-21
lines changed

5 files changed

+102
-21
lines changed

lib/std/simd.zig

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,6 @@ pub fn countElementsWithValue(vec: anytype, value: std.meta.Child(@TypeOf(vec)))
368368
}
369369

370370
test "vector searching" {
371-
if (builtin.zig_backend == .stage2_x86_64 and
372-
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .ssse3)) return error.SkipZigTest;
373-
374371
const base = @Vector(8, u32){ 6, 4, 7, 4, 4, 2, 3, 7 };
375372

376373
try std.testing.expectEqual(@as(?u3, 1), firstIndexOfValue(base, 4));

src/Air/Legalize.zig

Lines changed: 100 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ pub const Feature = enum {
7474
scalarize_int_from_float,
7575
scalarize_int_from_float_optimized,
7676
scalarize_float_from_int,
77+
scalarize_select,
7778
scalarize_mul_add,
7879

7980
/// Legalize (shift lhs, (splat rhs)) -> (shift lhs, rhs)
@@ -167,6 +168,7 @@ pub const Feature = enum {
167168
.int_from_float => .scalarize_int_from_float,
168169
.int_from_float_optimized => .scalarize_int_from_float_optimized,
169170
.float_from_int => .scalarize_float_from_int,
171+
.select => .scalarize_select,
170172
.mul_add => .scalarize_mul_add,
171173
};
172174
}
@@ -520,7 +522,9 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
520522
},
521523
.splat,
522524
.shuffle,
525+
=> {},
523526
.select,
527+
=> if (l.features.contains(.scalarize_select)) continue :inst try l.scalarize(inst, .select_pl_op_bin),
524528
.memset,
525529
.memset_safe,
526530
.memcpy,
@@ -568,7 +572,7 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
568572
}
569573
}
570574

571-
const ScalarizeDataTag = enum { un_op, ty_op, bin_op, ty_pl_vector_cmp, pl_op_bin };
575+
const ScalarizeDataTag = enum { un_op, ty_op, bin_op, ty_pl_vector_cmp, pl_op_bin, select_pl_op_bin };
572576
inline fn scalarize(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_tag: ScalarizeDataTag) Error!Air.Inst.Tag {
573577
return l.replaceInst(orig_inst, .block, try l.scalarizeBlockPayload(orig_inst, data_tag));
574578
}
@@ -584,6 +588,7 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_
584588
.un_op, .ty_op => 1,
585589
.bin_op, .ty_pl_vector_cmp => 2,
586590
.pl_op_bin => 3,
591+
.select_pl_op_bin => 6,
587592
} + 9
588593
]Air.Inst.Index = undefined;
589594
try l.air_instructions.ensureUnusedCapacity(zcu.gpa, inst_buf.len);
@@ -722,23 +727,67 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime data_
722727
} },
723728
});
724729
},
730+
.select_pl_op_bin => {
731+
const extra = l.extraData(Air.Bin, orig.data.pl_op.payload).data;
732+
var res_elem: Result = .init(l, l.typeOf(extra.lhs).scalarType(zcu), &loop.block);
733+
res_elem.block = .init(loop.block.stealCapacity(6));
734+
{
735+
var select_cond_br: CondBr = .init(l, res_elem.block.add(l, .{
736+
.tag = .array_elem_val,
737+
.data = .{ .bin_op = .{
738+
.lhs = orig.data.pl_op.operand,
739+
.rhs = cur_index_inst.toRef(),
740+
} },
741+
}).toRef(), &res_elem.block, .{});
742+
select_cond_br.then_block = .init(res_elem.block.stealRemainingCapacity());
743+
{
744+
_ = select_cond_br.then_block.add(l, .{
745+
.tag = .br,
746+
.data = .{ .br = .{
747+
.block_inst = res_elem.inst,
748+
.operand = select_cond_br.then_block.add(l, .{
749+
.tag = .array_elem_val,
750+
.data = .{ .bin_op = .{
751+
.lhs = extra.lhs,
752+
.rhs = cur_index_inst.toRef(),
753+
} },
754+
}).toRef(),
755+
} },
756+
});
757+
}
758+
select_cond_br.else_block = .init(select_cond_br.then_block.stealRemainingCapacity());
759+
{
760+
_ = select_cond_br.else_block.add(l, .{
761+
.tag = .br,
762+
.data = .{ .br = .{
763+
.block_inst = res_elem.inst,
764+
.operand = select_cond_br.else_block.add(l, .{
765+
.tag = .array_elem_val,
766+
.data = .{ .bin_op = .{
767+
.lhs = extra.rhs,
768+
.rhs = cur_index_inst.toRef(),
769+
} },
770+
}).toRef(),
771+
} },
772+
});
773+
}
774+
try select_cond_br.finish(l);
775+
}
776+
try res_elem.finish(l);
777+
break :res_elem res_elem.inst;
778+
},
725779
}.toRef(),
726780
}),
727781
} },
728782
});
729783

730-
var loop_cond_br: CondBr = .init(
784+
var loop_cond_br: CondBr = .init(l, (try loop.block.addCmp(
731785
l,
732-
(try loop.block.addCmp(
733-
l,
734-
.lt,
735-
cur_index_inst.toRef(),
736-
try pt.intRef(.usize, res_ty.vectorLen(zcu) - 1),
737-
.{},
738-
)).toRef(),
739-
&loop.block,
786+
.lt,
787+
cur_index_inst.toRef(),
788+
try pt.intRef(.usize, res_ty.vectorLen(zcu) - 1),
740789
.{},
741-
);
790+
)).toRef(), &loop.block, .{});
742791
loop_cond_br.then_block = .init(loop.block.stealRemainingCapacity());
743792
{
744793
_ = loop_cond_br.then_block.add(l, .{
@@ -1138,9 +1187,21 @@ const Block = struct {
11381187
/// This is useful when you've provided a buffer big enough for all your instructions, but you are
11391188
/// now starting a new block and some of them need to live there instead.
11401189
fn stealRemainingCapacity(b: *Block) []Air.Inst.Index {
1141-
const remaining = b.instructions[b.len..];
1142-
b.instructions = b.instructions[0..b.len];
1143-
return remaining;
1190+
return b.stealFrom(b.len);
1191+
}
1192+
1193+
/// Returns `len` elements taken from the unused capacity of `b.instructions`, and shrinks
1194+
/// `b.instructions` down to not include them anymore.
1195+
/// This is useful when you've provided a buffer big enough for all your instructions, but you are
1196+
/// now starting a new block and some of them need to live there instead.
1197+
fn stealCapacity(b: *Block, len: usize) []Air.Inst.Index {
1198+
return b.stealFrom(b.instructions.len - len);
1199+
}
1200+
1201+
fn stealFrom(b: *Block, start: usize) []Air.Inst.Index {
1202+
assert(start >= b.len);
1203+
defer b.instructions.len = start;
1204+
return b.instructions[start..];
11441205
}
11451206

11461207
fn body(b: *const Block) []const Air.Inst.Index {
@@ -1149,6 +1210,31 @@ const Block = struct {
11491210
}
11501211
};
11511212

1213+
const Result = struct {
1214+
inst: Air.Inst.Index,
1215+
block: Block,
1216+
1217+
/// The return value has `block` initialized to `undefined`; it is the caller's reponsibility
1218+
/// to initialize it.
1219+
fn init(l: *Legalize, ty: Type, parent_block: *Block) Result {
1220+
return .{
1221+
.inst = parent_block.add(l, .{
1222+
.tag = .block,
1223+
.data = .{ .ty_pl = .{
1224+
.ty = Air.internedToRef(ty.toIntern()),
1225+
.payload = undefined,
1226+
} },
1227+
}),
1228+
.block = undefined,
1229+
};
1230+
}
1231+
1232+
fn finish(res: Result, l: *Legalize) Error!void {
1233+
const data = &l.air_instructions.items(.data)[@intFromEnum(res.inst)];
1234+
data.ty_pl.payload = try l.addBlockBody(res.block.body());
1235+
}
1236+
};
1237+
11521238
const Loop = struct {
11531239
inst: Air.Inst.Index,
11541240
block: Block,

src/Compilation.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2529,6 +2529,7 @@ pub fn destroy(comp: *Compilation) void {
25292529

25302530
pub fn clearMiscFailures(comp: *Compilation) void {
25312531
comp.alloc_failure_occurred = false;
2532+
comp.link_diags.flags = .{};
25322533
for (comp.misc_failures.values()) |*value| {
25332534
value.deinit(comp.gpa);
25342535
}
@@ -2795,7 +2796,6 @@ pub fn update(comp: *Compilation, main_progress_node: std.Progress.Node) !void {
27952796

27962797
if (anyErrors(comp)) {
27972798
// Skip flushing and keep source files loaded for error reporting.
2798-
comp.link_diags.flags = .{};
27992799
return;
28002800
}
28012801

src/arch/x86_64/CodeGen.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ pub fn legalizeFeatures(target: *const std.Target) *const Air.Legalize.Features
8484
.scalarize_int_from_float = use_old,
8585
.scalarize_int_from_float_optimized = use_old,
8686
.scalarize_float_from_int = use_old,
87+
.scalarize_select = true,
8788
.scalarize_mul_add = use_old,
8889

8990
.unsplat_shift_rhs = false,

test/behavior/select.zig

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ test "@select arrays" {
4141
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
4242
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
4343
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
44-
if (builtin.zig_backend == .stage2_x86_64 and
45-
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .avx2)) return error.SkipZigTest;
4644

4745
try comptime selectArrays();
4846
try selectArrays();
@@ -70,7 +68,6 @@ fn selectArrays() !void {
7068
test "@select compare result" {
7169
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
7270
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
73-
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
7471
if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .hexagon) return error.SkipZigTest;
7572

7673
const S = struct {

0 commit comments

Comments
 (0)