Skip to content

Commit 9edfccb

Browse files
jacobly0 authored and mlugg committed
Legalize: implement scalarization of overflow intrinsics
1 parent ec579aa commit 9edfccb

19 files changed

+350
-98
lines changed

lib/std/simd.zig

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,6 @@ pub fn prefixScan(comptime op: std.builtin.ReduceOp, comptime hop: isize, vec: a
455455
}
456456

457457
test "vector prefix scan" {
458-
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
459458
if ((builtin.cpu.arch == .armeb or builtin.cpu.arch == .thumbeb) and builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/22060
460459
if (builtin.cpu.arch == .aarch64_be and builtin.zig_backend == .stage2_llvm) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/21893
461460
if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .hexagon) return error.SkipZigTest;

src/Air/Legalize.zig

Lines changed: 191 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,10 @@ pub const Feature = enum {
3333
scalarize_mod_optimized,
3434
scalarize_max,
3535
scalarize_min,
36+
scalarize_add_with_overflow,
37+
scalarize_sub_with_overflow,
38+
scalarize_mul_with_overflow,
39+
scalarize_shl_with_overflow,
3640
scalarize_bit_and,
3741
scalarize_bit_or,
3842
scalarize_shr,
@@ -129,6 +133,10 @@ pub const Feature = enum {
129133
.mod_optimized => .scalarize_mod_optimized,
130134
.max => .scalarize_max,
131135
.min => .scalarize_min,
136+
.add_with_overflow => .scalarize_add_with_overflow,
137+
.sub_with_overflow => .scalarize_sub_with_overflow,
138+
.mul_with_overflow => .scalarize_mul_with_overflow,
139+
.shl_with_overflow => .scalarize_shl_with_overflow,
132140
.bit_and => .scalarize_bit_and,
133141
.bit_or => .scalarize_bit_or,
134142
.shr => .scalarize_shr,
@@ -279,10 +287,15 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
279287
},
280288
.ptr_add,
281289
.ptr_sub,
282-
.add_with_overflow,
290+
=> {},
291+
inline .add_with_overflow,
283292
.sub_with_overflow,
284293
.mul_with_overflow,
285294
.shl_with_overflow,
295+
=> |air_tag| if (l.features.contains(comptime .scalarize(air_tag))) {
296+
const ty_pl = l.air_instructions.items(.data)[@intFromEnum(inst)].ty_pl;
297+
if (ty_pl.ty.toType().fieldType(0, zcu).isVector(zcu)) continue :inst l.replaceInst(inst, .block, try l.scalarizeOverflowBlockPayload(inst));
298+
},
286299
.alloc,
287300
=> {},
288301
.inferred_alloc,
@@ -518,7 +531,7 @@ fn legalizeBody(l: *Legalize, body_start: usize, body_len: usize) Error!void {
518531
switch (vector_ty.vectorLen(zcu)) {
519532
0 => unreachable,
520533
1 => continue :inst l.replaceInst(inst, .bitcast, .{ .ty_op = .{
521-
.ty = Air.internedToRef(vector_ty.scalarType(zcu).toIntern()),
534+
.ty = Air.internedToRef(vector_ty.childType(zcu).toIntern()),
522535
.operand = reduce.operand,
523536
} }),
524537
else => break :done,
@@ -646,7 +659,7 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime form:
646659
.ty_op => loop.block.add(l, .{
647660
.tag = orig.tag,
648661
.data = .{ .ty_op = .{
649-
.ty = Air.internedToRef(orig.data.ty_op.ty.toType().scalarType(zcu).toIntern()),
662+
.ty = Air.internedToRef(res_ty.childType(zcu).toIntern()),
650663
.operand = loop.block.add(l, .{
651664
.tag = .array_elem_val,
652665
.data = .{ .bin_op = .{
@@ -745,7 +758,7 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime form:
745758
.shuffle_two => unwrapped.operand_a,
746759
};
747760
const operand_a_len = l.typeOf(operand_a).vectorLen(zcu);
748-
const elem_ty = unwrapped.result_ty.scalarType(zcu);
761+
const elem_ty = res_ty.childType(zcu);
749762
var res_elem: Result = .init(l, elem_ty, &loop.block);
750763
res_elem.block = .init(loop.block.stealCapacity(extra_insts));
751764
{
@@ -945,7 +958,7 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime form:
945958
},
946959
.select => {
947960
const extra = l.extraData(Air.Bin, orig.data.pl_op.payload).data;
948-
var res_elem: Result = .init(l, l.typeOf(extra.lhs).scalarType(zcu), &loop.block);
961+
var res_elem: Result = .init(l, l.typeOf(extra.lhs).childType(zcu), &loop.block);
949962
res_elem.block = .init(loop.block.stealCapacity(extra_insts));
950963
{
951964
var select_cond_br: CondBr = .init(l, res_elem.block.add(l, .{
@@ -1043,6 +1056,176 @@ fn scalarizeBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index, comptime form:
10431056
.payload = try l.addBlockBody(res_block.body()),
10441057
} };
10451058
}
1059+
/// Builds the `ty_pl` block payload that performs the overflow arithmetic instruction
/// `orig_inst` one vector element at a time.
///
/// The emitted body allocates a result of the original result type plus a `usize` index,
/// then loops: it loads the index, applies `orig.tag` to the indexed element of each
/// operand, and stores field 1 (overflow bit) and field 0 (wrapped value) of the scalar
/// result into the corresponding fields of the result allocation. Once the index reaches
/// `res_len - 1`, the whole result is loaded and passed to a `br` targeting `orig_inst`.
///
/// Returns the data for the replacement block; errors from type interning, extra-data
/// appends and instruction capacity reservation are propagated via `try`.
fn scalarizeOverflowBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index) Error!Air.Inst.Data {
1060+
const pt = l.pt;
1061+
const zcu = pt.zcu;
1062+
1063+
const orig = l.air_instructions.get(@intFromEnum(orig_inst));
1064+
// `res_ty` is the two-field result type; field 0 is the vector of wrapped results.
const res_ty = l.typeOfIndex(orig_inst);
1065+
const wrapped_res_ty = res_ty.fieldType(0, zcu);
1066+
const wrapped_res_scalar_ty = wrapped_res_ty.childType(zcu);
1067+
const res_len = wrapped_res_ty.vectorLen(zcu);
1068+
1069+
// NOTE(review): 21 presumably equals the exact number of instructions added below
// (including those created by `Loop.init`, `addCmp` and `CondBr.init`) - confirm when editing.
var inst_buf: [21]Air.Inst.Index = undefined;
1070+
// Reserve room for all of them up front.
try l.air_instructions.ensureUnusedCapacity(zcu.gpa, inst_buf.len);
1071+
1072+
var res_block: Block = .init(&inst_buf);
1073+
{
1074+
// Allocation of the full result type; the two field pointers below point into it.
const res_alloc_inst = res_block.add(l, .{
1075+
.tag = .alloc,
1076+
.data = .{ .ty = try pt.singleMutPtrType(res_ty) },
1077+
});
1078+
const ptr_wrapped_res_inst = res_block.add(l, .{
1079+
.tag = .struct_field_ptr_index_0,
1080+
.data = .{ .ty_op = .{
1081+
.ty = Air.internedToRef((try pt.singleMutPtrType(wrapped_res_ty)).toIntern()),
1082+
.operand = res_alloc_inst.toRef(),
1083+
} },
1084+
});
1085+
const ptr_overflow_res_inst = res_block.add(l, .{
1086+
.tag = .struct_field_ptr_index_1,
1087+
.data = .{ .ty_op = .{
1088+
.ty = Air.internedToRef((try pt.singleMutPtrType(res_ty.fieldType(1, zcu))).toIntern()),
1089+
.operand = res_alloc_inst.toRef(),
1090+
} },
1091+
});
1092+
// Loop index, stored as a `usize` and initialized to zero.
const index_alloc_inst = res_block.add(l, .{
1093+
.tag = .alloc,
1094+
.data = .{ .ty = .ptr_usize },
1095+
});
1096+
_ = res_block.add(l, .{
1097+
.tag = .store,
1098+
.data = .{ .bin_op = .{
1099+
.lhs = index_alloc_inst.toRef(),
1100+
.rhs = .zero_usize,
1101+
} },
1102+
});
1103+
1104+
var loop: Loop = .init(l, &res_block);
1105+
// The loop body takes over whatever is left of `inst_buf`.
loop.block = .init(res_block.stealRemainingCapacity());
1106+
{
1107+
const cur_index_inst = loop.block.add(l, .{
1108+
.tag = .load,
1109+
.data = .{ .ty_op = .{
1110+
.ty = .usize_type,
1111+
.operand = index_alloc_inst.toRef(),
1112+
} },
1113+
});
1114+
const extra = l.extraData(Air.Bin, orig.data.ty_pl.payload).data;
1115+
// Scalar form of the original operation: same tag, result type is a tuple of
// (element type, u1), operands are the current elements of the original lhs/rhs.
const res_elem = loop.block.add(l, .{
1116+
.tag = orig.tag,
1117+
.data = .{ .ty_pl = .{
1118+
.ty = Air.internedToRef(try zcu.intern_pool.getTupleType(zcu.gpa, pt.tid, .{
1119+
.types = &.{ wrapped_res_scalar_ty.toIntern(), .u1_type },
1120+
.values = &(.{.none} ** 2),
1121+
})),
1122+
.payload = try l.addExtra(Air.Bin, .{
1123+
.lhs = loop.block.add(l, .{
1124+
.tag = .array_elem_val,
1125+
.data = .{ .bin_op = .{
1126+
.lhs = extra.lhs,
1127+
.rhs = cur_index_inst.toRef(),
1128+
} },
1129+
}).toRef(),
1130+
.rhs = loop.block.add(l, .{
1131+
.tag = .array_elem_val,
1132+
.data = .{ .bin_op = .{
1133+
.lhs = extra.rhs,
1134+
.rhs = cur_index_inst.toRef(),
1135+
} },
1136+
}).toRef(),
1137+
}),
1138+
} },
1139+
});
1140+
// Store field 1 of the scalar result (the `u1` overflow bit) at the current index.
_ = loop.block.add(l, .{
1141+
.tag = .vector_store_elem,
1142+
.data = .{ .vector_store_elem = .{
1143+
.vector_ptr = ptr_overflow_res_inst.toRef(),
1144+
.payload = try l.addExtra(Air.Bin, .{
1145+
.lhs = cur_index_inst.toRef(),
1146+
.rhs = loop.block.add(l, .{
1147+
.tag = .struct_field_val,
1148+
.data = .{ .ty_pl = .{
1149+
.ty = .u1_type,
1150+
.payload = try l.addExtra(Air.StructField, .{
1151+
.struct_operand = res_elem.toRef(),
1152+
.field_index = 1,
1153+
}),
1154+
} },
1155+
}).toRef(),
1156+
}),
1157+
} },
1158+
});
1159+
// Store field 0 of the scalar result (the wrapped value) at the current index.
_ = loop.block.add(l, .{
1160+
.tag = .vector_store_elem,
1161+
.data = .{ .vector_store_elem = .{
1162+
.vector_ptr = ptr_wrapped_res_inst.toRef(),
1163+
.payload = try l.addExtra(Air.Bin, .{
1164+
.lhs = cur_index_inst.toRef(),
1165+
.rhs = loop.block.add(l, .{
1166+
.tag = .struct_field_val,
1167+
.data = .{ .ty_pl = .{
1168+
.ty = Air.internedToRef(wrapped_res_scalar_ty.toIntern()),
1169+
.payload = try l.addExtra(Air.StructField, .{
1170+
.struct_operand = res_elem.toRef(),
1171+
.field_index = 0,
1172+
}),
1173+
} },
1174+
}).toRef(),
1175+
}),
1176+
} },
1177+
});
1178+
1179+
// Condition is `index < res_len - 1`: true means more elements remain after this one.
var loop_cond_br: CondBr = .init(l, (try loop.block.addCmp(
1180+
l,
1181+
.lt,
1182+
cur_index_inst.toRef(),
1183+
try pt.intRef(.usize, res_len - 1),
1184+
.{},
1185+
)).toRef(), &loop.block, .{});
1186+
loop_cond_br.then_block = .init(loop.block.stealRemainingCapacity());
1187+
{
1188+
// Then-branch: write back `index + 1` and repeat the loop.
_ = loop_cond_br.then_block.add(l, .{
1189+
.tag = .store,
1190+
.data = .{ .bin_op = .{
1191+
.lhs = index_alloc_inst.toRef(),
1192+
.rhs = loop_cond_br.then_block.add(l, .{
1193+
.tag = .add,
1194+
.data = .{ .bin_op = .{
1195+
.lhs = cur_index_inst.toRef(),
1196+
.rhs = .one_usize,
1197+
} },
1198+
}).toRef(),
1199+
} },
1200+
});
1201+
_ = loop_cond_br.then_block.add(l, .{
1202+
.tag = .repeat,
1203+
.data = .{ .repeat = .{ .loop_inst = loop.inst } },
1204+
});
1205+
}
1206+
loop_cond_br.else_block = .init(loop_cond_br.then_block.stealRemainingCapacity());
1207+
// Else-branch: load the complete result and `br` to `orig_inst` with it.
_ = loop_cond_br.else_block.add(l, .{
1208+
.tag = .br,
1209+
.data = .{ .br = .{
1210+
.block_inst = orig_inst,
1211+
.operand = loop_cond_br.else_block.add(l, .{
1212+
.tag = .load,
1213+
.data = .{ .ty_op = .{
1214+
.ty = Air.internedToRef(res_ty.toIntern()),
1215+
.operand = res_alloc_inst.toRef(),
1216+
} },
1217+
}).toRef(),
1218+
} },
1219+
});
1220+
try loop_cond_br.finish(l);
1221+
}
1222+
try loop.finish(l);
1223+
}
1224+
// The block's type is the original result type; its body is everything in `res_block`.
return .{ .ty_pl = .{
1225+
.ty = Air.internedToRef(res_ty.toIntern()),
1226+
.payload = try l.addBlockBody(res_block.body()),
1227+
} };
1228+
}
10461229

10471230
fn safeIntcastBlockPayload(l: *Legalize, orig_inst: Air.Inst.Index) Error!Air.Inst.Data {
10481231
const pt = l.pt;
@@ -1535,8 +1718,9 @@ fn addBlockBody(l: *Legalize, body: []const Air.Inst.Index) Error!u32 {
15351718
return @intCast(l.air_extra.items.len);
15361719
}
15371720

1538-
// inline to propagate comptime `tag`s
1539-
inline fn replaceInst(l: *Legalize, inst: Air.Inst.Index, tag: Air.Inst.Tag, data: Air.Inst.Data) Air.Inst.Tag {
1721+
/// Returns `tag` to remind the caller to `continue :inst` the result.
1722+
/// This is inline to propagate the comptime-known `tag`.
1723+
inline fn replaceInst(l: *Legalize, inst: Air.Inst.Index, comptime tag: Air.Inst.Tag, data: Air.Inst.Data) Air.Inst.Tag {
15401724
const orig_ty = if (std.debug.runtime_safety) l.typeOfIndex(inst) else {};
15411725
l.air_instructions.set(@intFromEnum(inst), .{ .tag = tag, .data = data });
15421726
if (std.debug.runtime_safety) assert(l.typeOfIndex(inst).toIntern() == orig_ty.toIntern());

0 commit comments

Comments
 (0)