Skip to content

Commit 5cc130b

Browse files
authored
compiler: implement spec-compliant shifts
Previously, the compiler used LLVM's shift instructions directly, which have UB whenever the shifts are large or negative. This commit adds runtime checks for negative shifts, and handles oversized shifts.
1 parent 91d1a23 commit 5cc130b

File tree

6 files changed

+99
-15
lines changed

6 files changed

+99
-15
lines changed

compiler/asserts.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ func (b *builder) createNilCheck(inst ssa.Value, ptr llvm.Value, blockPrefix str
186186
b.createRuntimeAssert(isnil, blockPrefix, "nilPanic")
187187
}
188188

189+
// createNegativeShiftCheck creates an assertion that panics if the given shift value is negative.
190+
// This function assumes that the shift value is signed.
191+
func (b *builder) createNegativeShiftCheck(shift llvm.Value) {
192+
if b.fn.IsNoBounds() {
193+
// Function disabled bounds checking - skip shift check.
194+
return
195+
}
196+
197+
// isNegative = shift < 0
198+
isNegative := b.CreateICmp(llvm.IntSLT, shift, llvm.ConstInt(shift.Type(), 0, false), "")
199+
b.createRuntimeAssert(isNegative, "shift", "negativeShiftPanic")
200+
}
201+
189202
// createRuntimeAssert is a common function to create a new branch on an assert
190203
// bool, calling an assert func if the assert value is true (1).
191204
func (b *builder) createRuntimeAssert(assert llvm.Value, blockPrefix, assertFunc string) {

compiler/compiler.go

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,7 +1476,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) {
14761476
case *ssa.BinOp:
14771477
x := b.getValue(expr.X)
14781478
y := b.getValue(expr.Y)
1479-
return b.createBinOp(expr.Op, expr.X.Type(), x, y, expr.Pos())
1479+
return b.createBinOp(expr.Op, expr.X.Type(), expr.Y.Type(), x, y, expr.Pos())
14801480
case *ssa.Call:
14811481
return b.createFunctionCall(expr.Common())
14821482
case *ssa.ChangeInterface:
@@ -1925,7 +1925,7 @@ func (b *builder) createExpr(expr ssa.Value) (llvm.Value, error) {
19251925
// same type, even for bitshifts. Also, signedness in Go is encoded in the type
19261926
// and is encoded in the operation in LLVM IR: this is important for some
19271927
// operations such as divide.
1928-
func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, pos token.Pos) (llvm.Value, error) {
1928+
func (b *builder) createBinOp(op token.Token, typ, ytyp types.Type, x, y llvm.Value, pos token.Pos) (llvm.Value, error) {
19291929
switch typ := typ.Underlying().(type) {
19301930
case *types.Basic:
19311931
if typ.Info()&types.IsInteger != 0 {
@@ -1957,32 +1957,49 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p
19571957
case token.XOR: // ^
19581958
return b.CreateXor(x, y, ""), nil
19591959
case token.SHL, token.SHR:
1960+
if ytyp.Underlying().(*types.Basic).Info()&types.IsUnsigned == 0 {
1961+
// Ensure that y is not negative.
1962+
b.createNegativeShiftCheck(y)
1963+
}
1964+
19601965
sizeX := b.targetData.TypeAllocSize(x.Type())
19611966
sizeY := b.targetData.TypeAllocSize(y.Type())
1962-
if sizeX > sizeY {
1963-
// x and y must have equal sizes, make Y bigger in this case.
1964-
// y is unsigned, this has been checked by the Go type checker.
1967+
1968+
// Check if the shift is bigger than the bit-width of the shifted value.
1969+
// This is UB in LLVM, so it needs to be handled seperately.
1970+
// The Go spec indirectly defines the result as 0.
1971+
// Negative shifts are handled earlier, so we can treat y as unsigned.
1972+
overshifted := b.CreateICmp(llvm.IntUGE, y, llvm.ConstInt(y.Type(), 8*sizeX, false), "shift.overflow")
1973+
1974+
// Adjust the size of y to match x.
1975+
switch {
1976+
case sizeX > sizeY:
19651977
y = b.CreateZExt(y, x.Type(), "")
1966-
} else if sizeX < sizeY {
1967-
// What about shifting more than the integer width?
1968-
// I'm not entirely sure what the Go spec is on that, but as
1969-
// Intel CPUs have undefined behavior when shifting more
1970-
// than the integer width I'm assuming it is also undefined
1971-
// in Go.
1978+
case sizeX < sizeY:
1979+
// If it gets truncated, overshifted will be true and it will not matter.
19721980
y = b.CreateTrunc(y, x.Type(), "")
19731981
}
1982+
1983+
// Create a shift operation.
1984+
var val llvm.Value
19741985
switch op {
19751986
case token.SHL: // <<
1976-
return b.CreateShl(x, y, ""), nil
1987+
val = b.CreateShl(x, y, "")
19771988
case token.SHR: // >>
19781989
if signed {
1990+
// Arithmetic right shifts work differently, since shifting a negative number right yields -1.
1991+
// Cap the shift input rather than selecting the output.
1992+
y = b.CreateSelect(overshifted, llvm.ConstInt(y.Type(), 8*sizeX-1, false), y, "shift.offset")
19791993
return b.CreateAShr(x, y, ""), nil
19801994
} else {
1981-
return b.CreateLShr(x, y, ""), nil
1995+
val = b.CreateLShr(x, y, "")
19821996
}
19831997
default:
19841998
panic("unreachable")
19851999
}
2000+
2001+
// Select between the shift result and zero depending on whether there was an overshift.
2002+
return b.CreateSelect(overshifted, llvm.ConstInt(val.Type(), 0, false), val, "shift.result"), nil
19862003
case token.EQL: // ==
19872004
return b.CreateICmp(llvm.IntEQ, x, y, ""), nil
19882005
case token.NEQ: // !=
@@ -2218,7 +2235,7 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p
22182235
for i := 0; i < int(typ.Len()); i++ {
22192236
xField := b.CreateExtractValue(x, i, "")
22202237
yField := b.CreateExtractValue(y, i, "")
2221-
fieldEqual, err := b.createBinOp(token.EQL, typ.Elem(), xField, yField, pos)
2238+
fieldEqual, err := b.createBinOp(token.EQL, typ.Elem(), typ.Elem(), xField, yField, pos)
22222239
if err != nil {
22232240
return llvm.Value{}, err
22242241
}
@@ -2246,7 +2263,7 @@ func (b *builder) createBinOp(op token.Token, typ types.Type, x, y llvm.Value, p
22462263
fieldType := typ.Field(i).Type()
22472264
xField := b.CreateExtractValue(x, i, "")
22482265
yField := b.CreateExtractValue(y, i, "")
2249-
fieldEqual, err := b.createBinOp(token.EQL, fieldType, xField, yField, pos)
2266+
fieldEqual, err := b.createBinOp(token.EQL, fieldType, fieldType, xField, yField, pos)
22502267
if err != nil {
22512268
return llvm.Value{}, err
22522269
}

interp/frame.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,18 @@ func (fr *frame) evalBasicBlock(bb, incoming llvm.BasicBlock, indent string) (re
603603
}
604604
fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateInsertValue(agg.Underlying, val.Value(), int(indices[0]), inst.Name())}
605605
}
606+
case !inst.IsASelectInst().IsNil():
607+
// var result T
608+
// if cond {
609+
// result = x
610+
// } else {
611+
// result = y
612+
// }
613+
// return result
614+
cond := fr.getLocal(inst.Operand(0)).(*LocalValue).Underlying
615+
x := fr.getLocal(inst.Operand(1)).(*LocalValue).Underlying
616+
y := fr.getLocal(inst.Operand(2)).(*LocalValue).Underlying
617+
fr.locals[inst] = &LocalValue{fr.Eval, fr.builder.CreateSelect(cond, x, y, "")}
606618

607619
case !inst.IsAReturnInst().IsNil() && inst.OperandsCount() == 0:
608620
return nil, nil, nil // ret void

src/runtime/panic.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ func chanMakePanic() {
4747
runtimePanic("new channel is too big")
4848
}
4949

50+
// Panic when a shift value is negative.
51+
func negativeShiftPanic() {
52+
runtimePanic("negative shift")
53+
}
54+
5055
func blockingPanic() {
5156
runtimePanic("trying to do blocking operation in exported function")
5257
}

testdata/binop.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ func main() {
6161
println(c128 != 3+2i)
6262
println(c128 != 4+2i)
6363
println(c128 != 3+3i)
64+
65+
println("shifts")
66+
println(shlSimple == 4)
67+
println(shlOverflow == 0)
68+
println(shrSimple == 1)
69+
println(shrOverflow == 0)
70+
println(ashrNeg == -1)
71+
println(ashrOverflow == 0)
72+
println(ashrNegOverflow == -1)
6473
}
6574

6675
var x = true
@@ -87,3 +96,23 @@ type Struct2 struct {
8796
_ float64
8897
i int
8998
}
99+
100+
func shl(x uint, y uint) uint {
101+
return x << y
102+
}
103+
104+
func shr(x uint, y uint) uint {
105+
return x >> y
106+
}
107+
108+
func ashr(x int, y uint) int {
109+
return x >> y
110+
}
111+
112+
var shlSimple = shl(2, 1)
113+
var shlOverflow = shl(2, 1000)
114+
var shrSimple = shr(2, 1)
115+
var shrOverflow = shr(2, 1000000)
116+
var ashrNeg = ashr(-1, 1)
117+
var ashrOverflow = ashr(1, 1000000)
118+
var ashrNegOverflow = ashr(-1, 1000000)

testdata/binop.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,11 @@ false
5454
true
5555
true
5656
true
57+
shifts
58+
true
59+
true
60+
true
61+
true
62+
true
63+
true
64+
true

0 commit comments

Comments
 (0)