Skip to content

Commit 156c73d

Browse files
authored
Merge pull request #22 from vhavlena/rule-transformation
Rule transformation
2 parents 450743b + f8b3d0c commit 156c73d

File tree

7 files changed

+663
-82
lines changed

7 files changed

+663
-82
lines changed

pkg/smt/exprs.go

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/open-policy-agent/opa/v1/ast"
88
verr "github.com/vhavlena/verirego/pkg/err"
9+
"github.com/vhavlena/verirego/pkg/types"
910
)
1011

1112
// ExprTranslator handles the translation of Rego expressions to SMT-LIB format.
@@ -95,6 +96,226 @@ func (et *ExprTranslator) ExprToSmt(expr *ast.Expr) (string, error) {
9596
return smtStr, nil
9697
}
9798

99+
type varDef struct { string; SmtValue }
100+
101+
func (et *ExprTranslator) BodyToSmt(ruleBody *ast.Body) (*SmtProposition,[]varDef,error) {
102+
bodySmts := make([]SmtProposition, 0, len(*ruleBody))
103+
definedVars := make(map[string]bool, 0)
104+
localVarDefs := make([]varDef, 0)
105+
for _, expr := range *ruleBody {
106+
// single term
107+
if term, ok := expr.Terms.(*ast.Term); ok {
108+
smtVal,err := et.termToSmtValue(term)
109+
if err != nil {
110+
return nil, localVarDefs, err
111+
}
112+
bodySmts = append(bodySmts, *smtVal.Holds())
113+
continue
114+
}
115+
116+
// call
117+
terms, ok := expr.Terms.([]*ast.Term)
118+
if !ok || len(terms) == 0 {
119+
return nil, localVarDefs, verr.ErrInvalidEmptyTerm
120+
}
121+
122+
opStr := removeQuotes(terms[0].String())
123+
op,err := getOperation(opStr)
124+
if err != nil {
125+
return nil, localVarDefs, err
126+
}
127+
128+
arity := op.Decl.Arity()
129+
params := len(terms)-1
130+
if arity < params { // the return is a part of the call
131+
def, err := et.handleAssigningFunction(opStr, terms)
132+
if err != nil {
133+
return nil, localVarDefs, err
134+
}
135+
localVarDefs = append(localVarDefs, *def)
136+
definedVars[def.string] = true
137+
continue
138+
}
139+
140+
// we handle ast.Equality separately, because it can be both assignment and comparison, based on the context
141+
if op == ast.Equality {
142+
if variable,ok := terms[1].Value.(ast.Var); ok {
143+
// create variable
144+
rhs := terms[2]
145+
val,err := et.termToSmtValue(rhs)
146+
if err != nil {
147+
return nil, localVarDefs, err
148+
}
149+
150+
name := removeQuotes(variable.String())
151+
if definedVars[name] != true {
152+
localVarDefs = append(localVarDefs, varDef{name, *val})
153+
definedVars[name] = true
154+
} else {
155+
varSmt, err := et.GetVarValue(variable)
156+
if err != nil {
157+
return nil, nil, err
158+
}
159+
bodySmts = append(bodySmts, *varSmt.Equals(val))
160+
}
161+
continue
162+
}
163+
}
164+
165+
// Convert all arguments
166+
args, err := et.getOperationArgSmts(terms)
167+
if err != nil {
168+
return nil, localVarDefs, err
169+
}
170+
171+
// Use regoFuncToSmt to get the SMT-LIB string for the operator
172+
bodySmt, err := et.getOperationValue(opStr, args, terms)
173+
if err != nil {
174+
return nil, localVarDefs, err
175+
}
176+
bodySmts = append(bodySmts, *bodySmt.Holds())
177+
}
178+
179+
bodySmt := And(bodySmts)
180+
return bodySmt, localVarDefs, nil
181+
}
182+
183+
func (et *ExprTranslator) handleAssigningFunction(op string, terms []*ast.Term) (*varDef, error) {
184+
if name,ok := terms[len(terms)-1].Value.(ast.Var); ok { // creating variable
185+
// remove assigned variable from call
186+
rhs := terms[0:len(terms)-1]
187+
args, err := et.getOperationArgSmts(rhs)
188+
if err != nil {
189+
return nil, err
190+
}
191+
val, err := et.getOperationValue(op,args,rhs)
192+
if err != nil {
193+
return nil, err
194+
}
195+
196+
return &varDef{name.String(), *val}, nil
197+
}
198+
return nil, verr.ErrUnsupportedFunction // this should be unreachable
199+
}
200+
201+
func (et *ExprTranslator) getOperationArgSmts(terms []*ast.Term) ([]string, error) {
202+
args := make([]string, len(terms)-1)
203+
for i := 1; i < len(terms); i++ {
204+
s, err := et.termToSmt(terms[i])
205+
if err != nil {
206+
return nil, err
207+
}
208+
args[i-1] = s
209+
}
210+
return args, nil
211+
}
212+
213+
func (et *ExprTranslator) getOperationValue(op string, args []string, rhs []*ast.Term) (*SmtValue, error) {
214+
val,err := et.regoFuncToSmt(op,args,rhs)
215+
if err != nil {
216+
return nil, err
217+
}
218+
construct, opType, err := getAtomConstructorForOperation(op)
219+
if err != nil {
220+
return nil, err
221+
}
222+
if construct != "" {
223+
val = fmt.Sprintf("(%s %s)", construct, val)
224+
}
225+
return &SmtValue{value: val, depth: 0, atomics: []types.AtomicType{opType,types.AtomicUndef}}, nil // TODO: user-defined functions
226+
}
227+
228+
func getOperation(op string) (*ast.Builtin,error) {
229+
switch op {
230+
case ast.Plus.Name:
231+
return ast.Plus,nil
232+
case ast.Minus.Name:
233+
return ast.Minus,nil
234+
case ast.Multiply.Name:
235+
return ast.Multiply,nil
236+
case ast.Divide.Name:
237+
return ast.Divide,nil
238+
case ast.Equal.Name:
239+
return ast.Equal,nil
240+
case ast.Equality.Name:
241+
return ast.Equality,nil
242+
case ast.Assign.Name:
243+
return ast.Assign,nil
244+
case ast.GreaterThan.Name:
245+
return ast.GreaterThan,nil
246+
case ast.GreaterThanEq.Name:
247+
return ast.GreaterThanEq,nil
248+
case ast.LessThan.Name:
249+
return ast.LessThan,nil
250+
case ast.LessThanEq.Name:
251+
return ast.LessThanEq,nil
252+
case ast.Concat.Name:
253+
return ast.Concat,nil
254+
case ast.Contains.Name:
255+
return ast.Contains,nil
256+
case ast.StartsWith.Name:
257+
return ast.StartsWith,nil
258+
case ast.EndsWith.Name:
259+
return ast.EndsWith,nil
260+
case ast.IndexOf.Name:
261+
return ast.IndexOf,nil
262+
case ast.Substring.Name:
263+
return ast.Substring,nil
264+
default:
265+
return nil,verr.ErrUnsupportedFunction
266+
}
267+
}
268+
269+
func /*(et *ExprTranslator)*/ getOperationReturnType(opName string) (types.AtomicType,error) {
270+
funcMap := map[string]types.AtomicType{
271+
ast.Plus.Name: types.AtomicInt, // +
272+
ast.Minus.Name: types.AtomicInt, // -
273+
ast.Multiply.Name: types.AtomicInt, // *
274+
ast.Divide.Name: types.AtomicInt, // /
275+
ast.Equal.Name: types.AtomicBoolean, // ==
276+
ast.Equality.Name: types.AtomicBoolean, // =
277+
ast.Assign.Name: types.AtomicBoolean, // :=
278+
ast.GreaterThan.Name: types.AtomicBoolean, // >
279+
ast.GreaterThanEq.Name: types.AtomicBoolean, // >=
280+
ast.LessThan.Name: types.AtomicBoolean, // <
281+
ast.LessThanEq.Name: types.AtomicBoolean, // <=
282+
ast.Concat.Name: types.AtomicString, // concat
283+
ast.Contains.Name: types.AtomicBoolean, // contains
284+
ast.StartsWith.Name: types.AtomicBoolean, // startswith
285+
ast.EndsWith.Name: types.AtomicBoolean, // endswith
286+
ast.IndexOf.Name: types.AtomicInt, // indexof
287+
ast.Substring.Name: types.AtomicString, // substring
288+
// "length" does not exist
289+
}
290+
291+
// TODO: user defined functions
292+
if atomicType,found := funcMap[opName]; found {
293+
return atomicType,nil
294+
}
295+
return types.AtomicUndef,verr.ErrUnsupportedFunction
296+
}
297+
298+
func getAtomConstructorForOperation(op string) (string,types.AtomicType,error) {
299+
opType,err := getOperationReturnType(op)
300+
if err != nil {
301+
return "", "", verr.ErrUnsupportedFunction
302+
}
303+
return getAtomConstructorFromType(opType),opType,nil
304+
}
305+
306+
func getAtomConstructorFromType(t types.AtomicType) string {
307+
switch t {
308+
case types.AtomicString:
309+
return "OString"
310+
case types.AtomicInt:
311+
return "ONumber"
312+
case types.AtomicBoolean:
313+
return "OBoolean"
314+
default:
315+
return ""
316+
}
317+
}
318+
98319
// termToSmt converts a Rego AST term to its SMT-LIB string representation.
99320
//
100321
// Parameters:
@@ -148,6 +369,112 @@ func (et *ExprTranslator) termToSmt(term *ast.Term) (string, error) {
148369
}
149370
}
150371

372+
func (et *ExprTranslator) termToSmtValue(term *ast.Term) (*SmtValue, error) {
373+
switch v := term.Value.(type) {
374+
case ast.String:
375+
return NewSmtValueFromString(string(v)), nil
376+
case ast.Number:
377+
if val,ok := v.Int(); ok {
378+
return NewSmtValueFromInt(val), nil
379+
}
380+
return nil,verr.ErrUnsupportedAtomic
381+
case ast.Boolean:
382+
return NewSmtValueFromBoolean(bool(v)), nil
383+
case *ast.Array:
384+
return et.arrayToSmt(v)
385+
case ast.Object:
386+
return et.objectToSmt(v)
387+
case ast.Set:
388+
// Not directly supported in SMT-LIB, return error
389+
return nil, verr.ErrSetConversionNotSupported
390+
case ast.Var:
391+
return et.GetVarValue(v)
392+
case ast.Ref:
393+
name := removeQuotes(v[len(v)-1].String())
394+
tp, ok := et.TypeTrans.TypeInfo.Types[name]
395+
if !ok {
396+
return nil, verr.ErrTypeNotFound
397+
}
398+
return NewSmtValue(name, tp.TypeDepth()), nil
399+
case ast.Call:
400+
// Handle string functions and other builtins
401+
op := removeQuotes(v[0].String())
402+
args := make([]string, len(v)-1)
403+
for i := 1; i < len(v); i++ {
404+
s, err := et.termToSmt(v[i])
405+
if err != nil {
406+
return nil, err
407+
}
408+
args[i-1] = s
409+
}
410+
return et.getOperationValue(op, args, v)
411+
default:
412+
return nil, fmt.Errorf("%w: %T", verr.ErrUnsupportedTermType, v)
413+
}
414+
}
415+
416+
func (et *ExprTranslator) GetVarValue(v ast.Var) (*SmtValue, error) {
417+
return NewSmtValueFromVar(v, et)
418+
}
419+
420+
func (et *ExprTranslator) arrayToSmt(arr *ast.Array) (*SmtValue, error) {
421+
tp, ok := et.TypeTrans.TypeInfo.Types[arr.String()]
422+
if !ok {
423+
return nil, verr.ErrTypeNotFound
424+
}
425+
426+
depth := tp.TypeDepth()
427+
arrSmt := createConstArray("Int", depth)
428+
429+
for index := range arr.Len() {
430+
val := arr.Elem(index)
431+
valSmt,err := et.termToSmtValue(val)
432+
if err != nil {
433+
return nil,err
434+
}
435+
valSmt = valSmt.WrapToDepth(depth-1)
436+
arrSmt = fmt.Sprintf("(store %s %d %s)", arrSmt, index, valSmt.String())
437+
}
438+
439+
return &SmtValue{
440+
value: fmt.Sprintf("(OArray%d %s)",depth,arrSmt),
441+
depth: depth,
442+
}, nil
443+
}
444+
445+
func (et *ExprTranslator) objectToSmt(obj ast.Object) (*SmtValue, error) {
446+
tp, ok := et.TypeTrans.TypeInfo.Types[obj.String()]
447+
if !ok {
448+
return nil, verr.ErrTypeNotFound
449+
}
450+
451+
depth := tp.TypeDepth()
452+
objSmt := createConstArray("String", depth)
453+
454+
for _, key := range obj.Keys() {
455+
val := obj.Get(key)
456+
valSmt,err := et.termToSmtValue(val)
457+
if err != nil {
458+
return nil,err
459+
}
460+
valSmt = valSmt.WrapToDepth(depth-1)
461+
objSmt = fmt.Sprintf("(store %s %s %s)", objSmt, key.String(), valSmt.String())
462+
}
463+
464+
return &SmtValue{
465+
value: fmt.Sprintf("(OObj%d %s)",depth,objSmt),
466+
depth: depth,
467+
}, nil
468+
}
469+
470+
func createConstArray(keyType string, depth int) string {
471+
undefChild := "OUndef"
472+
for d := range depth-1 {
473+
undefChild = fmt.Sprintf("(Atom%d %s)",d+1,undefChild)
474+
}
475+
return fmt.Sprintf("((as const (Array %s OTypeD%d)) %s)",keyType ,depth-1, undefChild)
476+
}
477+
151478
// regoFuncToSmt converts a Rego function/operator name and its arguments to an SMT-LIB function application string.
152479
// If the operator is a known built-in, it maps to the corresponding SMT-LIB function. Otherwise, it declares an uninterpreted function
153480
// with the appropriate parameter and return types (using declareUnintFunc) and returns its application.

0 commit comments

Comments
 (0)