Skip to content

Commit 62f22df

Browse files
authored
Merge pull request #24 from vhavlena/types-func
Types: Inference function types
2 parents bedffd8 + 3304cf4 commit 62f22df

File tree

5 files changed

+671
-17
lines changed

5 files changed

+671
-17
lines changed

pkg/smt/translator.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ func (t *Translator) getSmtVarsDeclare() map[string]any {
166166
if t.TypeTrans.TypeInfo != nil {
167167
for name := range t.TypeTrans.TypeInfo.Types {
168168
if _, isParam := inputParamSet[name]; !isParam {
169+
tp := t.TypeTrans.TypeInfo.Types[name]
170+
// Skip function definitions: they describe call signatures, not
171+
// runtime variables that need SMT declarations.
172+
if tp.IsFunction() {
173+
continue
174+
}
169175
_, okVar := t.TypeTrans.TypeInfo.Refs[name].(ast.Var)
170176
if okVar {
171177
globalVars[name] = struct{}{}

pkg/smt/translator_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,30 @@ func TestTranslateModuleToSmt_Basic(t *testing.T) {
117117
}
118118
}
119119
}
120+
121+
// TestGetSmtVarsDeclare_SkipsFunctionTypes verifies that KindFunction entries in
122+
// the type map are excluded from the set of global SMT variables to declare.
123+
// User-defined functions contribute a KindFunction entry to the TypeAnalyzer —
124+
// including them would cause GenerateVarDecl to fail with ErrUnsupportedType.
125+
func TestGetSmtVarsDeclare_SkipsFunctionTypes(t *testing.T) {
126+
t.Parallel()
127+
rego := `package test
128+
add_one(x) := y if { y := x + 1 }
129+
p := z if { z := 1 }
130+
`
131+
mod, err := ast.ParseModule("test.rego", rego)
132+
if err != nil {
133+
t.Fatalf("failed to parse rego: %v", err)
134+
}
135+
136+
ta := types.NewTypeAnalyzerWithParams(mod.Package.Path, types.NewInputSchema(), nil)
137+
ta.AnalyzeModule(mod)
138+
139+
tr := NewTranslator(ta, mod)
140+
globalVars := tr.getSmtVarsDeclare()
141+
142+
// "add_one" must be absent – it is a KindFunction entry, not a variable.
143+
if _, found := globalVars["add_one"]; found {
144+
t.Error("KindFunction entry 'add_one' should not appear in globalVars for SMT declaration")
145+
}
146+
}

pkg/types/type_analysis.go

Lines changed: 125 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
package types
33

44
import (
5+
"strings"
6+
57
"github.com/open-policy-agent/opa/v1/ast"
68
)
79

@@ -94,6 +96,12 @@ func (ta *TypeAnalyzer) GetType(val ast.Value) RegoTypeDef {
9496
//
9597
// val ast.Value: The AST value to set the type for.
9698
// typ RegoTypeDef: The type to assign to the value.
99+
//
100+
// Promotion across kinds is intentionally not supported: IsMorePrecise returns
101+
// false when the incoming and existing types have different, non-unknown kinds.
102+
// This is by design — valid Rego never uses the same symbol as both a plain rule
103+
// and a user-defined function, so cross-kind name collisions cannot occur in
104+
// practice. If they did, the later setType call would be silently ignored.
97105
func (ta *TypeAnalyzer) setType(val ast.Value, typ RegoTypeDef) {
98106
key := ta.getValueKey(val)
99107
if existingType, exists := ta.Types[key]; exists {
@@ -199,8 +207,8 @@ func (ta *TypeAnalyzer) InferExprType(expr *ast.Expr) RegoTypeDef {
199207

200208
return NewAtomicType(AtomicBoolean)
201209
} else {
202-
// Handle function calls
203-
funcType, funcParams := funcParamsType(operator.String(), len(terms)-1)
210+
// Handle function calls (user-defined functions are checked first)
211+
funcType, funcParams := ta.resolveFunctionType(operator.String(), len(terms)-1)
204212
for i := 1; i < len(terms); i++ {
205213
ta.InferTermType(terms[i], &funcParams[i-1])
206214
}
@@ -280,7 +288,7 @@ func (ta *TypeAnalyzer) inferAstType(val ast.Value, inherType *RegoTypeDef) Rego
280288
}
281289
case ast.Call:
282290
operator := v[0]
283-
funcType, funcParams := funcParamsType(operator.String(), len(v)-1)
291+
funcType, funcParams := ta.resolveFunctionType(operator.String(), len(v)-1)
284292
for i := 1; i < len(v); i++ {
285293
ta.InferTermType(v[i], &funcParams[i-1])
286294
}
@@ -359,9 +367,11 @@ func (ta *TypeAnalyzer) inferRefType(ref ast.Ref) RegoTypeDef {
359367

360368
// AnalyzeRule analyzes the given Rego rule and records the inferred type for the rule head.
361369
//
362-
// AnalyzeRule constructs a union type to collect possible return types produced by the
363-
// rule's body (including any else branches). It delegates the body analysis to
364-
// AnalyzeRuleBody which appends discovered return types into the union. After analysis
370+
// For parametric rules (functions) — those whose head carries at least one argument —
371+
// analysis is delegated to analyzeParametricRule which produces a FunctionTypeDef.
372+
// For plain rules AnalyzeRule constructs a union type to collect possible return types
373+
// produced by the rule's body (including any else branches). It delegates the body analysis
374+
// to AnalyzeRuleBody which appends discovered return types into the union. After analysis
365375
// the union is canonicalized and stored in the analyzer's type map under the rule head's
366376
// name via ta.setType.
367377
//
@@ -373,12 +383,121 @@ func (ta *TypeAnalyzer) inferRefType(ref ast.Ref) RegoTypeDef {
373383
// - rule *ast.Rule: the Rego rule to analyze. The function expects a valid rule with a
374384
// head; behavior for malformed rules follows the underlying setType logic.
375385
func (ta *TypeAnalyzer) AnalyzeRule(rule *ast.Rule) {
386+
if len(rule.Head.Args) > 0 {
387+
ta.analyzeParametricRule(rule)
388+
return
389+
}
376390
tp := NewUnionType([]RegoTypeDef{})
377391
ta.AnalyzeRuleBody(rule, &tp)
378392
tp.CanonizeUnion()
379393
ta.setType(rule.Head.Name, tp)
380394
}
381395

396+
// analyzeParametricRule analyzes a Rego function / parametric rule and stores a
397+
// FunctionTypeDef under the rule's name in the type map.
398+
//
399+
// The rule body is analyzed exactly as for plain rules; any type information inferred
400+
// for the parameter variables (e.g. via equality constraints in the body or head value)
401+
// is then collected and combined with the inferred return type to produce a
402+
// KindFunction type definition.
403+
//
404+
// Parameters:
405+
// - rule *ast.Rule: a parametric rule (rule.Head.Args must be non-empty).
406+
func (ta *TypeAnalyzer) analyzeParametricRule(rule *ast.Rule) {
407+
// Analyze the body and head value, collecting return types into tp.
408+
tp := NewUnionType([]RegoTypeDef{})
409+
ta.AnalyzeRuleBody(rule, &tp)
410+
tp.CanonizeUnion()
411+
412+
// Collect inferred types for each parameter variable.
413+
paramTypes := make([]RegoTypeDef, len(rule.Head.Args))
414+
for i, arg := range rule.Head.Args {
415+
if v, ok := arg.Value.(ast.Var); ok {
416+
if t, exists := ta.Types[string(v)]; exists {
417+
paramTypes[i] = t
418+
} else {
419+
paramTypes[i] = NewUnknownType()
420+
}
421+
} else {
422+
paramTypes[i] = NewUnknownType()
423+
}
424+
}
425+
426+
funcName := string(rule.Head.Name)
427+
funcType := NewFunctionType(funcName, paramTypes, tp)
428+
ta.setType(rule.Head.Name, funcType)
429+
}
430+
431+
// lookupFunction searches the type map for a user-defined function by name.
432+
//
433+
// Two lookups are attempted in order:
434+
// 1. Exact match on `name` (works for uncompiled modules where the operator is
435+
// the bare function name, e.g. "add_one").
436+
// 2. Prefix-stripped match: if the analyzer has a package path and `name` starts
437+
// with "<packagePath>." (e.g. "data.test.add_one"), the prefix is removed and
438+
// the short name is looked up (handles fully-qualified references produced by
439+
// OPA's compiler).
440+
//
441+
// Returns the FunctionTypeDef if a KindFunction entry is found, or nil otherwise.
442+
func (ta *TypeAnalyzer) lookupFunction(name string) *FunctionTypeDef {
443+
if ft, exists := ta.Types[name]; exists && ft.IsFunction() && ft.FunctionDef != nil {
444+
return ft.FunctionDef
445+
}
446+
if ta.packagePath != nil {
447+
prefix := ta.packagePath.String() + "."
448+
if strings.HasPrefix(name, prefix) {
449+
shortName := name[len(prefix):]
450+
if ft, exists := ta.Types[shortName]; exists && ft.IsFunction() && ft.FunctionDef != nil {
451+
return ft.FunctionDef
452+
}
453+
}
454+
}
455+
return nil
456+
}
457+
458+
// resolveFunctionType returns the expected return type and parameter types for a
459+
// function call. User-defined functions (stored as KindFunction entries in the type
460+
// map) are checked first; the call falls back to the predefined-function registry.
461+
//
462+
// Two arities are accepted for user-defined functions:
463+
// - Exact match (len(ParamTypes) == arity): direct call, works for predicates and
464+
// uncompiled value-returning functions.
465+
// - Arity +1 (len(ParamTypes)+1 == arity): compiled value-returning function where
466+
// OPA appends the output variable as the last argument. In this case the last
467+
// parameter slot is given the function's return type so that the output variable
468+
// is correctly typed by the caller.
469+
//
470+
// Parameters:
471+
//
472+
// name string: The function name as it appears in the call expression.
473+
// arity int: The number of arguments supplied at the call site.
474+
//
475+
// Returns:
476+
//
477+
// RegoTypeDef: The expected return type.
478+
// []RegoTypeDef: A fresh slice with the expected type for each argument position.
479+
func (ta *TypeAnalyzer) resolveFunctionType(name string, arity int) (RegoTypeDef, []RegoTypeDef) {
480+
if fd := ta.lookupFunction(name); fd != nil {
481+
nParams := len(fd.ParamTypes)
482+
if nParams == arity {
483+
// Exact arity: predicate or uncompiled value-returning call.
484+
paramsCopy := make([]RegoTypeDef, nParams)
485+
copy(paramsCopy, fd.ParamTypes)
486+
return fd.ReturnType, paramsCopy
487+
}
488+
if nParams+1 == arity {
489+
// Compiled value-returning call: OPA appended the output variable as
490+
// the last argument. Propagate the return type to that slot so the
491+
// output variable inherits the correct type.
492+
paramsCopy := make([]RegoTypeDef, arity)
493+
copy(paramsCopy[:nParams], fd.ParamTypes)
494+
paramsCopy[nParams] = fd.ReturnType
495+
return fd.ReturnType, paramsCopy
496+
}
497+
}
498+
return funcParamsType(name, arity)
499+
}
500+
382501
// AnalyzeRuleBody analyzes the body (and else branches) of a rule and appends discovered
383502
// return types to the provided union type `tp`.
384503
//

0 commit comments

Comments
 (0)