Skip to content

Commit 85854cd

Browse files
aykevldeadprogram
authored andcommitted
compiler: add dereferenceable_or_null attribute where possible
This gives a hint to the compiler that such parameters are either NULL or point to a valid object that can be dereferenced. This is not directly very useful, but is very useful when combined with https://reviews.llvm.org/D60047 to remove the runtime.isnil hack without regressing escape analysis.
1 parent 9800685 commit 85854cd

File tree

4 files changed

+98
-17
lines changed

4 files changed

+98
-17
lines changed

compiler/calls.go

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package compiler
22

33
import (
4+
"go/types"
5+
46
"tinygo.org/x/go-llvm"
57
)
68

@@ -11,6 +13,16 @@ import (
1113
// a struct contains more fields, it is passed as a struct without expanding.
1214
const MaxFieldsPerParam = 3
1315

16+
// paramFlags identifies parameter attributes for flags. Most importantly, it
17+
// determines which parameters are dereferenceable_or_null and which aren't.
18+
type paramFlags uint8
19+
20+
const (
21+
// Parameter may have the deferenceable_or_null attribute. This attribute
22+
// cannot be applied to unsafe.Pointer and to the data pointer of slices.
23+
paramIsDeferenceableOrNull = 1 << iota
24+
)
25+
1426
// createCall creates a new call to runtime.<fnName> with the given arguments.
1527
func (b *builder) createRuntimeCall(fnName string, args []llvm.Value, name string) llvm.Value {
1628
fullName := "runtime." + fnName
@@ -36,19 +48,19 @@ func (b *builder) createCall(fn llvm.Value, args []llvm.Value, name string) llvm
3648

3749
// Expand an argument type to a list that can be used in a function call
3850
// parameter list.
39-
func expandFormalParamType(t llvm.Type) []llvm.Type {
51+
func expandFormalParamType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
4052
switch t.TypeKind() {
4153
case llvm.StructTypeKind:
42-
fields := flattenAggregateType(t)
54+
fields, fieldFlags := flattenAggregateType(t, goType)
4355
if len(fields) <= MaxFieldsPerParam {
44-
return fields
56+
return fields, fieldFlags
4557
} else {
4658
// failed to lower
47-
return []llvm.Type{t}
59+
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
4860
}
4961
default:
5062
// TODO: split small arrays
51-
return []llvm.Type{t}
63+
return []llvm.Type{t}, []paramFlags{getTypeFlags(goType)}
5264
}
5365
}
5466

@@ -79,7 +91,7 @@ func (b *builder) expandFormalParamOffsets(t llvm.Type) []uint64 {
7991
func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
8092
switch v.Type().TypeKind() {
8193
case llvm.StructTypeKind:
82-
fieldTypes := flattenAggregateType(v.Type())
94+
fieldTypes, _ := flattenAggregateType(v.Type(), nil)
8395
if len(fieldTypes) <= MaxFieldsPerParam {
8496
fields := b.flattenAggregate(v)
8597
if len(fields) != len(fieldTypes) {
@@ -98,17 +110,62 @@ func (b *builder) expandFormalParam(v llvm.Value) []llvm.Value {
98110

99111
// Try to flatten a struct type to a list of types. Returns a 1-element slice
100112
// with the passed in type if this is not possible.
101-
func flattenAggregateType(t llvm.Type) []llvm.Type {
113+
func flattenAggregateType(t llvm.Type, goType types.Type) ([]llvm.Type, []paramFlags) {
114+
typeFlags := getTypeFlags(goType)
102115
switch t.TypeKind() {
103116
case llvm.StructTypeKind:
104117
fields := make([]llvm.Type, 0, t.StructElementTypesCount())
105-
for _, subfield := range t.StructElementTypes() {
106-
subfields := flattenAggregateType(subfield)
118+
fieldFlags := make([]paramFlags, 0, cap(fields))
119+
for i, subfield := range t.StructElementTypes() {
120+
subfields, subfieldFlags := flattenAggregateType(subfield, extractSubfield(goType, i))
121+
for i := range subfieldFlags {
122+
subfieldFlags[i] |= typeFlags
123+
}
107124
fields = append(fields, subfields...)
125+
fieldFlags = append(fieldFlags, subfieldFlags...)
108126
}
109-
return fields
127+
return fields, fieldFlags
128+
default:
129+
return []llvm.Type{t}, []paramFlags{typeFlags}
130+
}
131+
}
132+
133+
// getTypeFlags returns the type flags for a given type. It will not recurse
134+
// into sub-types (such as in structs).
135+
func getTypeFlags(t types.Type) paramFlags {
136+
if t == nil {
137+
return 0
138+
}
139+
switch t.Underlying().(type) {
140+
case *types.Pointer:
141+
// Pointers in Go must either point to an object or be nil.
142+
return paramIsDeferenceableOrNull
143+
case *types.Chan, *types.Map:
144+
// Channels and maps are implemented as pointers pointing to some
145+
// object, and follow the same rules as *types.Pointer.
146+
return paramIsDeferenceableOrNull
147+
default:
148+
return 0
149+
}
150+
}
151+
152+
// extractSubfield extracts a field from a struct, or returns null if this is
153+
// not a struct and thus no subfield can be obtained.
154+
func extractSubfield(t types.Type, field int) types.Type {
155+
if t == nil {
156+
return nil
157+
}
158+
switch t := t.Underlying().(type) {
159+
case *types.Struct:
160+
return t.Field(field).Type()
161+
case *types.Interface, *types.Slice, *types.Basic, *types.Signature:
162+
// These Go types are (sometimes) implemented as LLVM structs but can't
163+
// really be split further up in Go (with the possible exception of
164+
// complex numbers).
165+
return nil
110166
default:
111-
return []llvm.Type{t}
167+
// This should be unreachable.
168+
panic("cannot split subfield: " + t.String())
112169
}
113170
}
114171

@@ -169,7 +226,8 @@ func (b *builder) collapseFormalParam(t llvm.Type, fields []llvm.Value) llvm.Val
169226
func (b *builder) collapseFormalParamInternal(t llvm.Type, fields []llvm.Value) (llvm.Value, []llvm.Value) {
170227
switch t.TypeKind() {
171228
case llvm.StructTypeKind:
172-
if len(flattenAggregateType(t)) <= MaxFieldsPerParam {
229+
flattened, _ := flattenAggregateType(t, nil)
230+
if len(flattened) <= MaxFieldsPerParam {
173231
value := llvm.ConstNull(t)
174232
for i, subtyp := range t.StructElementTypes() {
175233
structField, remaining := b.collapseFormalParamInternal(subtyp, fields)

compiler/compiler.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,17 +750,20 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
750750
}
751751

752752
var paramTypes []llvm.Type
753+
var paramTypeVariants []paramFlags
753754
for _, param := range f.Params {
754755
paramType := c.getLLVMType(param.Type())
755-
paramTypeFragments := expandFormalParamType(paramType)
756+
paramTypeFragments, paramTypeFragmentVariants := expandFormalParamType(paramType, param.Type())
756757
paramTypes = append(paramTypes, paramTypeFragments...)
758+
paramTypeVariants = append(paramTypeVariants, paramTypeFragmentVariants...)
757759
}
758760

759761
// Add an extra parameter as the function context. This context is used in
760762
// closures and bound methods, but should be optimized away when not used.
761763
if !f.IsExported() {
762764
paramTypes = append(paramTypes, c.i8ptrType) // context
763765
paramTypes = append(paramTypes, c.i8ptrType) // parent coroutine
766+
paramTypeVariants = append(paramTypeVariants, 0, 0)
764767
}
765768

766769
fnType := llvm.FunctionType(retType, paramTypes, false)
@@ -771,6 +774,23 @@ func (c *compilerContext) createFunctionDeclaration(f *ir.Function) {
771774
f.LLVMFn = llvm.AddFunction(c.mod, name, fnType)
772775
}
773776

777+
dereferenceableOrNullKind := llvm.AttributeKindID("dereferenceable_or_null")
778+
for i, typ := range paramTypes {
779+
if paramTypeVariants[i]&paramIsDeferenceableOrNull == 0 {
780+
continue
781+
}
782+
if typ.TypeKind() == llvm.PointerTypeKind {
783+
el := typ.ElementType()
784+
size := c.targetData.TypeAllocSize(el)
785+
if size == 0 {
786+
// dereferenceable_or_null(0) appears to be illegal in LLVM.
787+
continue
788+
}
789+
dereferenceableOrNull := c.ctx.CreateEnumAttribute(dereferenceableOrNullKind, size)
790+
f.LLVMFn.AddAttributeAtIndex(i+1, dereferenceableOrNull)
791+
}
792+
}
793+
774794
// External/exported functions may not retain pointer values.
775795
// https://golang.org/cmd/cgo/#hdr-Passing_pointers
776796
if f.IsExported() {
@@ -901,7 +921,8 @@ func (b *builder) createFunctionDefinition() {
901921
for _, param := range b.fn.Params {
902922
llvmType := b.getLLVMType(param.Type())
903923
fields := make([]llvm.Value, 0, 1)
904-
for range expandFormalParamType(llvmType) {
924+
fieldFragments, _ := expandFormalParamType(llvmType, nil)
925+
for range fieldFragments {
905926
fields = append(fields, b.fn.LLVMFn.Param(llvmParamIndex))
906927
llvmParamIndex++
907928
}

compiler/func.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,13 @@ func (c *compilerContext) getRawFuncType(typ *types.Signature) llvm.Type {
125125
// The receiver is not an interface, but a i8* type.
126126
recv = c.i8ptrType
127127
}
128-
paramTypes = append(paramTypes, expandFormalParamType(recv)...)
128+
recvFragments, _ := expandFormalParamType(recv, nil)
129+
paramTypes = append(paramTypes, recvFragments...)
129130
}
130131
for i := 0; i < typ.Params().Len(); i++ {
131132
subType := c.getLLVMType(typ.Params().At(i).Type())
132-
paramTypes = append(paramTypes, expandFormalParamType(subType)...)
133+
paramTypeFragments, _ := expandFormalParamType(subType, nil)
134+
paramTypes = append(paramTypes, paramTypeFragments...)
133135
}
134136
// All functions take these parameters at the end.
135137
paramTypes = append(paramTypes, c.i8ptrType) // context

compiler/interface.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ func (c *compilerContext) getInterfaceInvokeWrapper(f *ir.Function) llvm.Value {
437437

438438
// Get the expanded receiver type.
439439
receiverType := c.getLLVMType(f.Params[0].Type())
440-
expandedReceiverType := expandFormalParamType(receiverType)
440+
expandedReceiverType, _ := expandFormalParamType(receiverType, nil)
441441

442442
// Does this method even need any wrapping?
443443
if len(expandedReceiverType) == 1 && receiverType.TypeKind() == llvm.PointerTypeKind {

0 commit comments

Comments
 (0)