|
6 | 6 |
|
7 | 7 | "github.com/open-policy-agent/opa/v1/ast" |
8 | 8 | verr "github.com/vhavlena/verirego/pkg/err" |
| 9 | + "github.com/vhavlena/verirego/pkg/types" |
9 | 10 | ) |
10 | 11 |
|
11 | 12 | // ExprTranslator handles the translation of Rego expressions to SMT-LIB format. |
@@ -95,6 +96,226 @@ func (et *ExprTranslator) ExprToSmt(expr *ast.Expr) (string, error) { |
95 | 96 | return smtStr, nil |
96 | 97 | } |
97 | 98 |
|
| 99 | +type varDef struct { string; SmtValue } |
| 100 | + |
| 101 | +func (et *ExprTranslator) BodyToSmt(ruleBody *ast.Body) (*SmtProposition,[]varDef,error) { |
| 102 | + bodySmts := make([]SmtProposition, 0, len(*ruleBody)) |
| 103 | + definedVars := make(map[string]bool, 0) |
| 104 | + localVarDefs := make([]varDef, 0) |
| 105 | + for _, expr := range *ruleBody { |
| 106 | + // single term |
| 107 | + if term, ok := expr.Terms.(*ast.Term); ok { |
| 108 | + smtVal,err := et.termToSmtValue(term) |
| 109 | + if err != nil { |
| 110 | + return nil, localVarDefs, err |
| 111 | + } |
| 112 | + bodySmts = append(bodySmts, *smtVal.Holds()) |
| 113 | + continue |
| 114 | + } |
| 115 | + |
| 116 | + // call |
| 117 | + terms, ok := expr.Terms.([]*ast.Term) |
| 118 | + if !ok || len(terms) == 0 { |
| 119 | + return nil, localVarDefs, verr.ErrInvalidEmptyTerm |
| 120 | + } |
| 121 | + |
| 122 | + opStr := removeQuotes(terms[0].String()) |
| 123 | + op,err := getOperation(opStr) |
| 124 | + if err != nil { |
| 125 | + return nil, localVarDefs, err |
| 126 | + } |
| 127 | + |
| 128 | + arity := op.Decl.Arity() |
| 129 | + params := len(terms)-1 |
| 130 | + if arity < params { // the return is a part of the call |
| 131 | + def, err := et.handleAssigningFunction(opStr, terms) |
| 132 | + if err != nil { |
| 133 | + return nil, localVarDefs, err |
| 134 | + } |
| 135 | + localVarDefs = append(localVarDefs, *def) |
| 136 | + definedVars[def.string] = true |
| 137 | + continue |
| 138 | + } |
| 139 | + |
| 140 | + // we handle ast.Equality separately, because it can be both assignment and comparison, based on the context |
| 141 | + if op == ast.Equality { |
| 142 | + if variable,ok := terms[1].Value.(ast.Var); ok { |
| 143 | + // create variable |
| 144 | + rhs := terms[2] |
| 145 | + val,err := et.termToSmtValue(rhs) |
| 146 | + if err != nil { |
| 147 | + return nil, localVarDefs, err |
| 148 | + } |
| 149 | + |
| 150 | + name := removeQuotes(variable.String()) |
| 151 | + if definedVars[name] != true { |
| 152 | + localVarDefs = append(localVarDefs, varDef{name, *val}) |
| 153 | + definedVars[name] = true |
| 154 | + } else { |
| 155 | + varSmt, err := et.GetVarValue(variable) |
| 156 | + if err != nil { |
| 157 | + return nil, nil, err |
| 158 | + } |
| 159 | + bodySmts = append(bodySmts, *varSmt.Equals(val)) |
| 160 | + } |
| 161 | + continue |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + // Convert all arguments |
| 166 | + args, err := et.getOperationArgSmts(terms) |
| 167 | + if err != nil { |
| 168 | + return nil, localVarDefs, err |
| 169 | + } |
| 170 | + |
| 171 | + // Use regoFuncToSmt to get the SMT-LIB string for the operator |
| 172 | + bodySmt, err := et.getOperationValue(opStr, args, terms) |
| 173 | + if err != nil { |
| 174 | + return nil, localVarDefs, err |
| 175 | + } |
| 176 | + bodySmts = append(bodySmts, *bodySmt.Holds()) |
| 177 | + } |
| 178 | + |
| 179 | + bodySmt := And(bodySmts) |
| 180 | + return bodySmt, localVarDefs, nil |
| 181 | +} |
| 182 | + |
| 183 | +func (et *ExprTranslator) handleAssigningFunction(op string, terms []*ast.Term) (*varDef, error) { |
| 184 | + if name,ok := terms[len(terms)-1].Value.(ast.Var); ok { // creating variable |
| 185 | + // remove assigned variable from call |
| 186 | + rhs := terms[0:len(terms)-1] |
| 187 | + args, err := et.getOperationArgSmts(rhs) |
| 188 | + if err != nil { |
| 189 | + return nil, err |
| 190 | + } |
| 191 | + val, err := et.getOperationValue(op,args,rhs) |
| 192 | + if err != nil { |
| 193 | + return nil, err |
| 194 | + } |
| 195 | + |
| 196 | + return &varDef{name.String(), *val}, nil |
| 197 | + } |
| 198 | + return nil, verr.ErrUnsupportedFunction // this should be unreachable |
| 199 | +} |
| 200 | + |
| 201 | +func (et *ExprTranslator) getOperationArgSmts(terms []*ast.Term) ([]string, error) { |
| 202 | + args := make([]string, len(terms)-1) |
| 203 | + for i := 1; i < len(terms); i++ { |
| 204 | + s, err := et.termToSmt(terms[i]) |
| 205 | + if err != nil { |
| 206 | + return nil, err |
| 207 | + } |
| 208 | + args[i-1] = s |
| 209 | + } |
| 210 | + return args, nil |
| 211 | +} |
| 212 | + |
| 213 | +func (et *ExprTranslator) getOperationValue(op string, args []string, rhs []*ast.Term) (*SmtValue, error) { |
| 214 | + val,err := et.regoFuncToSmt(op,args,rhs) |
| 215 | + if err != nil { |
| 216 | + return nil, err |
| 217 | + } |
| 218 | + construct, opType, err := getAtomConstructorForOperation(op) |
| 219 | + if err != nil { |
| 220 | + return nil, err |
| 221 | + } |
| 222 | + if construct != "" { |
| 223 | + val = fmt.Sprintf("(%s %s)", construct, val) |
| 224 | + } |
| 225 | + return &SmtValue{value: val, depth: 0, atomics: []types.AtomicType{opType,types.AtomicUndef}}, nil // TODO: user-defined functions |
| 226 | +} |
| 227 | + |
| 228 | +func getOperation(op string) (*ast.Builtin,error) { |
| 229 | + switch op { |
| 230 | + case ast.Plus.Name: |
| 231 | + return ast.Plus,nil |
| 232 | + case ast.Minus.Name: |
| 233 | + return ast.Minus,nil |
| 234 | + case ast.Multiply.Name: |
| 235 | + return ast.Multiply,nil |
| 236 | + case ast.Divide.Name: |
| 237 | + return ast.Divide,nil |
| 238 | + case ast.Equal.Name: |
| 239 | + return ast.Equal,nil |
| 240 | + case ast.Equality.Name: |
| 241 | + return ast.Equality,nil |
| 242 | + case ast.Assign.Name: |
| 243 | + return ast.Assign,nil |
| 244 | + case ast.GreaterThan.Name: |
| 245 | + return ast.GreaterThan,nil |
| 246 | + case ast.GreaterThanEq.Name: |
| 247 | + return ast.GreaterThanEq,nil |
| 248 | + case ast.LessThan.Name: |
| 249 | + return ast.LessThan,nil |
| 250 | + case ast.LessThanEq.Name: |
| 251 | + return ast.LessThanEq,nil |
| 252 | + case ast.Concat.Name: |
| 253 | + return ast.Concat,nil |
| 254 | + case ast.Contains.Name: |
| 255 | + return ast.Contains,nil |
| 256 | + case ast.StartsWith.Name: |
| 257 | + return ast.StartsWith,nil |
| 258 | + case ast.EndsWith.Name: |
| 259 | + return ast.EndsWith,nil |
| 260 | + case ast.IndexOf.Name: |
| 261 | + return ast.IndexOf,nil |
| 262 | + case ast.Substring.Name: |
| 263 | + return ast.Substring,nil |
| 264 | + default: |
| 265 | + return nil,verr.ErrUnsupportedFunction |
| 266 | + } |
| 267 | +} |
| 268 | + |
| 269 | +func /*(et *ExprTranslator)*/ getOperationReturnType(opName string) (types.AtomicType,error) { |
| 270 | + funcMap := map[string]types.AtomicType{ |
| 271 | + ast.Plus.Name: types.AtomicInt, // + |
| 272 | + ast.Minus.Name: types.AtomicInt, // - |
| 273 | + ast.Multiply.Name: types.AtomicInt, // * |
| 274 | + ast.Divide.Name: types.AtomicInt, // / |
| 275 | + ast.Equal.Name: types.AtomicBoolean, // == |
| 276 | + ast.Equality.Name: types.AtomicBoolean, // = |
| 277 | + ast.Assign.Name: types.AtomicBoolean, // := |
| 278 | + ast.GreaterThan.Name: types.AtomicBoolean, // > |
| 279 | + ast.GreaterThanEq.Name: types.AtomicBoolean, // >= |
| 280 | + ast.LessThan.Name: types.AtomicBoolean, // < |
| 281 | + ast.LessThanEq.Name: types.AtomicBoolean, // <= |
| 282 | + ast.Concat.Name: types.AtomicString, // concat |
| 283 | + ast.Contains.Name: types.AtomicBoolean, // contains |
| 284 | + ast.StartsWith.Name: types.AtomicBoolean, // startswith |
| 285 | + ast.EndsWith.Name: types.AtomicBoolean, // endswith |
| 286 | + ast.IndexOf.Name: types.AtomicInt, // indexof |
| 287 | + ast.Substring.Name: types.AtomicString, // substring |
| 288 | + // "length" does not exist |
| 289 | + } |
| 290 | + |
| 291 | + // TODO: user defined functions |
| 292 | + if atomicType,found := funcMap[opName]; found { |
| 293 | + return atomicType,nil |
| 294 | + } |
| 295 | + return types.AtomicUndef,verr.ErrUnsupportedFunction |
| 296 | +} |
| 297 | + |
| 298 | +func getAtomConstructorForOperation(op string) (string,types.AtomicType,error) { |
| 299 | + opType,err := getOperationReturnType(op) |
| 300 | + if err != nil { |
| 301 | + return "", "", verr.ErrUnsupportedFunction |
| 302 | + } |
| 303 | + return getAtomConstructorFromType(opType),opType,nil |
| 304 | +} |
| 305 | + |
| 306 | +func getAtomConstructorFromType(t types.AtomicType) string { |
| 307 | + switch t { |
| 308 | + case types.AtomicString: |
| 309 | + return "OString" |
| 310 | + case types.AtomicInt: |
| 311 | + return "ONumber" |
| 312 | + case types.AtomicBoolean: |
| 313 | + return "OBoolean" |
| 314 | + default: |
| 315 | + return "" |
| 316 | + } |
| 317 | +} |
| 318 | + |
98 | 319 | // termToSmt converts a Rego AST term to its SMT-LIB string representation. |
99 | 320 | // |
100 | 321 | // Parameters: |
@@ -148,6 +369,112 @@ func (et *ExprTranslator) termToSmt(term *ast.Term) (string, error) { |
148 | 369 | } |
149 | 370 | } |
150 | 371 |
|
| 372 | +func (et *ExprTranslator) termToSmtValue(term *ast.Term) (*SmtValue, error) { |
| 373 | + switch v := term.Value.(type) { |
| 374 | + case ast.String: |
| 375 | + return NewSmtValueFromString(string(v)), nil |
| 376 | + case ast.Number: |
| 377 | + if val,ok := v.Int(); ok { |
| 378 | + return NewSmtValueFromInt(val), nil |
| 379 | + } |
| 380 | + return nil,verr.ErrUnsupportedAtomic |
| 381 | + case ast.Boolean: |
| 382 | + return NewSmtValueFromBoolean(bool(v)), nil |
| 383 | + case *ast.Array: |
| 384 | + return et.arrayToSmt(v) |
| 385 | + case ast.Object: |
| 386 | + return et.objectToSmt(v) |
| 387 | + case ast.Set: |
| 388 | + // Not directly supported in SMT-LIB, return error |
| 389 | + return nil, verr.ErrSetConversionNotSupported |
| 390 | + case ast.Var: |
| 391 | + return et.GetVarValue(v) |
| 392 | + case ast.Ref: |
| 393 | + name := removeQuotes(v[len(v)-1].String()) |
| 394 | + tp, ok := et.TypeTrans.TypeInfo.Types[name] |
| 395 | + if !ok { |
| 396 | + return nil, verr.ErrTypeNotFound |
| 397 | + } |
| 398 | + return NewSmtValue(name, tp.TypeDepth()), nil |
| 399 | + case ast.Call: |
| 400 | + // Handle string functions and other builtins |
| 401 | + op := removeQuotes(v[0].String()) |
| 402 | + args := make([]string, len(v)-1) |
| 403 | + for i := 1; i < len(v); i++ { |
| 404 | + s, err := et.termToSmt(v[i]) |
| 405 | + if err != nil { |
| 406 | + return nil, err |
| 407 | + } |
| 408 | + args[i-1] = s |
| 409 | + } |
| 410 | + return et.getOperationValue(op, args, v) |
| 411 | + default: |
| 412 | + return nil, fmt.Errorf("%w: %T", verr.ErrUnsupportedTermType, v) |
| 413 | + } |
| 414 | +} |
| 415 | + |
| 416 | +func (et *ExprTranslator) GetVarValue(v ast.Var) (*SmtValue, error) { |
| 417 | + return NewSmtValueFromVar(v, et) |
| 418 | +} |
| 419 | + |
| 420 | +func (et *ExprTranslator) arrayToSmt(arr *ast.Array) (*SmtValue, error) { |
| 421 | + tp, ok := et.TypeTrans.TypeInfo.Types[arr.String()] |
| 422 | + if !ok { |
| 423 | + return nil, verr.ErrTypeNotFound |
| 424 | + } |
| 425 | + |
| 426 | + depth := tp.TypeDepth() |
| 427 | + arrSmt := createConstArray("Int", depth) |
| 428 | + |
| 429 | + for index := range arr.Len() { |
| 430 | + val := arr.Elem(index) |
| 431 | + valSmt,err := et.termToSmtValue(val) |
| 432 | + if err != nil { |
| 433 | + return nil,err |
| 434 | + } |
| 435 | + valSmt = valSmt.WrapToDepth(depth-1) |
| 436 | + arrSmt = fmt.Sprintf("(store %s %d %s)", arrSmt, index, valSmt.String()) |
| 437 | + } |
| 438 | + |
| 439 | + return &SmtValue{ |
| 440 | + value: fmt.Sprintf("(OArray%d %s)",depth,arrSmt), |
| 441 | + depth: depth, |
| 442 | + }, nil |
| 443 | +} |
| 444 | + |
| 445 | +func (et *ExprTranslator) objectToSmt(obj ast.Object) (*SmtValue, error) { |
| 446 | + tp, ok := et.TypeTrans.TypeInfo.Types[obj.String()] |
| 447 | + if !ok { |
| 448 | + return nil, verr.ErrTypeNotFound |
| 449 | + } |
| 450 | + |
| 451 | + depth := tp.TypeDepth() |
| 452 | + objSmt := createConstArray("String", depth) |
| 453 | + |
| 454 | + for _, key := range obj.Keys() { |
| 455 | + val := obj.Get(key) |
| 456 | + valSmt,err := et.termToSmtValue(val) |
| 457 | + if err != nil { |
| 458 | + return nil,err |
| 459 | + } |
| 460 | + valSmt = valSmt.WrapToDepth(depth-1) |
| 461 | + objSmt = fmt.Sprintf("(store %s %s %s)", objSmt, key.String(), valSmt.String()) |
| 462 | + } |
| 463 | + |
| 464 | + return &SmtValue{ |
| 465 | + value: fmt.Sprintf("(OObj%d %s)",depth,objSmt), |
| 466 | + depth: depth, |
| 467 | + }, nil |
| 468 | +} |
| 469 | + |
| 470 | +func createConstArray(keyType string, depth int) string { |
| 471 | + undefChild := "OUndef" |
| 472 | + for d := range depth-1 { |
| 473 | + undefChild = fmt.Sprintf("(Atom%d %s)",d+1,undefChild) |
| 474 | + } |
| 475 | + return fmt.Sprintf("((as const (Array %s OTypeD%d)) %s)",keyType ,depth-1, undefChild) |
| 476 | +} |
| 477 | + |
151 | 478 | // regoFuncToSmt converts a Rego function/operator name and its arguments to an SMT-LIB function application string. |
152 | 479 | // If the operator is a known built-in, it maps to the corresponding SMT-LIB function. Otherwise, it declares an uninterpreted function |
153 | 480 | // with the appropriate parameter and return types (using declareUnintFunc) and returns its application. |
|
0 commit comments