@@ -59,6 +59,8 @@ package vta
5959import (
6060 "go/types"
6161
62+ "maps"
63+
6264 "golang.org/x/tools/go/callgraph"
6365 "golang.org/x/tools/go/ssa"
6466)
@@ -190,3 +192,37 @@ func (mc methodCache) methods(t types.Type, name string, prog *ssa.Program) []*s
190192 mc [t ] = ms
191193 return ms [name ]
192194}
195+
196+ // typeAssertTypes returns a mapping from each type assertion instruction in `f` to the possible types of its input variable.
197+ func typeAssertTypes (f * ssa.Function , typesMap * propTypeMap , cache methodCache ) map [* ssa.TypeAssert ][]types.Type {
198+ asserts := typeAsserts (f )
199+ result := make (map [* ssa.TypeAssert ][]types.Type )
200+
201+ for _ , ta := range asserts {
202+ inputVal := ta .X
203+ n := local {val : inputVal }
204+
205+ var possTypes []types.Type
206+ typesMap .propTypes (n )(func (p propType ) bool {
207+ possTypes = append (possTypes , p .typ )
208+ return true
209+ })
210+
211+ result [ta ] = possTypes
212+ }
213+
214+ return result
215+ }
216+
217+ func GetTypeAsserts (funcs map [* ssa.Function ]bool , initial * callgraph.Graph ) map [* ssa.TypeAssert ][]types.Type {
218+ callees := makeCalleesFunc (funcs , initial )
219+ vtaG , canon := typePropGraph (funcs , callees )
220+ typesMap := propagate (vtaG , canon )
221+ result := make (map [* ssa.TypeAssert ][]types.Type )
222+ for f , in := range funcs {
223+ if in {
224+ maps .Copy (result , typeAssertTypes (f , & typesMap , methodCache {}))
225+ }
226+ }
227+ return result
228+ }
0 commit comments