22package types
33
44import (
5+ "strings"
6+
57 "github.com/open-policy-agent/opa/v1/ast"
68)
79
@@ -94,6 +96,12 @@ func (ta *TypeAnalyzer) GetType(val ast.Value) RegoTypeDef {
9496//
9597// val ast.Value: The AST value to set the type for.
9698// typ RegoTypeDef: The type to assign to the value.
99+ //
100+ // Promotion across kinds is intentionally not supported: IsMorePrecise returns
101+ // false when the incoming and existing types have different, non-unknown kinds.
102+ // This is by design — valid Rego never uses the same symbol as both a plain rule
103+ // and a user-defined function, so cross-kind name collisions cannot occur in
104+ // practice. If they did, the later setType call would be silently ignored.
97105func (ta * TypeAnalyzer ) setType (val ast.Value , typ RegoTypeDef ) {
98106 key := ta .getValueKey (val )
99107 if existingType , exists := ta .Types [key ]; exists {
@@ -199,8 +207,8 @@ func (ta *TypeAnalyzer) InferExprType(expr *ast.Expr) RegoTypeDef {
199207
200208 return NewAtomicType (AtomicBoolean )
201209 } else {
202- // Handle function calls
203- funcType , funcParams := funcParamsType (operator .String (), len (terms )- 1 )
210+ // Handle function calls (user-defined functions are checked first)
211+ funcType , funcParams := ta . resolveFunctionType (operator .String (), len (terms )- 1 )
204212 for i := 1 ; i < len (terms ); i ++ {
205213 ta .InferTermType (terms [i ], & funcParams [i - 1 ])
206214 }
@@ -280,7 +288,7 @@ func (ta *TypeAnalyzer) inferAstType(val ast.Value, inherType *RegoTypeDef) Rego
280288 }
281289 case ast.Call :
282290 operator := v [0 ]
283- funcType , funcParams := funcParamsType (operator .String (), len (v )- 1 )
291+ funcType , funcParams := ta . resolveFunctionType (operator .String (), len (v )- 1 )
284292 for i := 1 ; i < len (v ); i ++ {
285293 ta .InferTermType (v [i ], & funcParams [i - 1 ])
286294 }
@@ -359,9 +367,11 @@ func (ta *TypeAnalyzer) inferRefType(ref ast.Ref) RegoTypeDef {
359367
360368// AnalyzeRule analyzes the given Rego rule and records the inferred type for the rule head.
361369//
362- // AnalyzeRule constructs a union type to collect possible return types produced by the
363- // rule's body (including any else branches). It delegates the body analysis to
364- // AnalyzeRuleBody which appends discovered return types into the union. After analysis
370+ // For parametric rules (functions) — those whose head carries at least one argument —
371+ // analysis is delegated to analyzeParametricRule which produces a FunctionTypeDef.
372+ // For plain rules AnalyzeRule constructs a union type to collect possible return types
373+ // produced by the rule's body (including any else branches). It delegates the body analysis
374+ // to AnalyzeRuleBody which appends discovered return types into the union. After analysis
365375// the union is canonicalized and stored in the analyzer's type map under the rule head's
366376// name via ta.setType.
367377//
@@ -373,12 +383,121 @@ func (ta *TypeAnalyzer) inferRefType(ref ast.Ref) RegoTypeDef {
373383// - rule *ast.Rule: the Rego rule to analyze. The function expects a valid rule with a
374384// head; behavior for malformed rules follows the underlying setType logic.
375385func (ta * TypeAnalyzer ) AnalyzeRule (rule * ast.Rule ) {
386+ if len (rule .Head .Args ) > 0 {
387+ ta .analyzeParametricRule (rule )
388+ return
389+ }
376390 tp := NewUnionType ([]RegoTypeDef {})
377391 ta .AnalyzeRuleBody (rule , & tp )
378392 tp .CanonizeUnion ()
379393 ta .setType (rule .Head .Name , tp )
380394}
381395
396+ // analyzeParametricRule analyzes a Rego function / parametric rule and stores a
397+ // FunctionTypeDef under the rule's name in the type map.
398+ //
399+ // The rule body is analyzed exactly as for plain rules; any type information inferred
400+ // for the parameter variables (e.g. via equality constraints in the body or head value)
401+ // is then collected and combined with the inferred return type to produce a
402+ // KindFunction type definition.
403+ //
404+ // Parameters:
405+ // - rule *ast.Rule: a parametric rule (rule.Head.Args must be non-empty).
406+ func (ta * TypeAnalyzer ) analyzeParametricRule (rule * ast.Rule ) {
407+ // Analyze the body and head value, collecting return types into tp.
408+ tp := NewUnionType ([]RegoTypeDef {})
409+ ta .AnalyzeRuleBody (rule , & tp )
410+ tp .CanonizeUnion ()
411+
412+ // Collect inferred types for each parameter variable.
413+ paramTypes := make ([]RegoTypeDef , len (rule .Head .Args ))
414+ for i , arg := range rule .Head .Args {
415+ if v , ok := arg .Value .(ast.Var ); ok {
416+ if t , exists := ta .Types [string (v )]; exists {
417+ paramTypes [i ] = t
418+ } else {
419+ paramTypes [i ] = NewUnknownType ()
420+ }
421+ } else {
422+ paramTypes [i ] = NewUnknownType ()
423+ }
424+ }
425+
426+ funcName := string (rule .Head .Name )
427+ funcType := NewFunctionType (funcName , paramTypes , tp )
428+ ta .setType (rule .Head .Name , funcType )
429+ }
430+
431+ // lookupFunction searches the type map for a user-defined function by name.
432+ //
433+ // Two lookups are attempted in order:
434+ // 1. Exact match on `name` (works for uncompiled modules where the operator is
435+ // the bare function name, e.g. "add_one").
436+ // 2. Prefix-stripped match: if the analyzer has a package path and `name` starts
437+ // with "<packagePath>." (e.g. "data.test.add_one"), the prefix is removed and
438+ // the short name is looked up (handles fully-qualified references produced by
439+ // OPA's compiler).
440+ //
441+ // Returns the FunctionTypeDef if a KindFunction entry is found, or nil otherwise.
442+ func (ta * TypeAnalyzer ) lookupFunction (name string ) * FunctionTypeDef {
443+ if ft , exists := ta .Types [name ]; exists && ft .IsFunction () && ft .FunctionDef != nil {
444+ return ft .FunctionDef
445+ }
446+ if ta .packagePath != nil {
447+ prefix := ta .packagePath .String () + "."
448+ if strings .HasPrefix (name , prefix ) {
449+ shortName := name [len (prefix ):]
450+ if ft , exists := ta .Types [shortName ]; exists && ft .IsFunction () && ft .FunctionDef != nil {
451+ return ft .FunctionDef
452+ }
453+ }
454+ }
455+ return nil
456+ }
457+
458+ // resolveFunctionType returns the expected return type and parameter types for a
459+ // function call. User-defined functions (stored as KindFunction entries in the type
460+ // map) are checked first; the call falls back to the predefined-function registry.
461+ //
462+ // Two arities are accepted for user-defined functions:
463+ // - Exact match (len(ParamTypes) == arity): direct call, works for predicates and
464+ // uncompiled value-returning functions.
465+ // - Arity +1 (len(ParamTypes)+1 == arity): compiled value-returning function where
466+ // OPA appends the output variable as the last argument. In this case the last
467+ // parameter slot is given the function's return type so that the output variable
468+ // is correctly typed by the caller.
469+ //
470+ // Parameters:
471+ //
472+ // name string: The function name as it appears in the call expression.
473+ // arity int: The number of arguments supplied at the call site.
474+ //
475+ // Returns:
476+ //
477+ // RegoTypeDef: The expected return type.
478+ // []RegoTypeDef: A fresh slice with the expected type for each argument position.
479+ func (ta * TypeAnalyzer ) resolveFunctionType (name string , arity int ) (RegoTypeDef , []RegoTypeDef ) {
480+ if fd := ta .lookupFunction (name ); fd != nil {
481+ nParams := len (fd .ParamTypes )
482+ if nParams == arity {
483+ // Exact arity: predicate or uncompiled value-returning call.
484+ paramsCopy := make ([]RegoTypeDef , nParams )
485+ copy (paramsCopy , fd .ParamTypes )
486+ return fd .ReturnType , paramsCopy
487+ }
488+ if nParams + 1 == arity {
489+ // Compiled value-returning call: OPA appended the output variable as
490+ // the last argument. Propagate the return type to that slot so the
491+ // output variable inherits the correct type.
492+ paramsCopy := make ([]RegoTypeDef , arity )
493+ copy (paramsCopy [:nParams ], fd .ParamTypes )
494+ paramsCopy [nParams ] = fd .ReturnType
495+ return fd .ReturnType , paramsCopy
496+ }
497+ }
498+ return funcParamsType (name , arity )
499+ }
500+
382501// AnalyzeRuleBody analyzes the body (and else branches) of a rule and appends discovered
383502// return types to the provided union type `tp`.
384503//
0 commit comments