|
| 1 | +import |
| 2 | + typetraits, strutils, |
| 3 | + shims/macros, results |
| 4 | + |
| 5 | +const |
| 6 | + enforce_error_handling {.strdefine.}: string = "yes" |
| 7 | + errorHandlingEnforced = parseBool(enforce_error_handling) |
| 8 | + |
| 9 | +type |
| 10 | + VoidResult = object |
| 11 | + Raising*[ErrorList: tuple, ResultType] = distinct ResultType |
| 12 | + |
| 13 | +let |
| 14 | + raisesPragmaId {.compileTime.} = ident"raises" |
| 15 | + |
| 16 | +proc mergeTupleTypeSets(lhs, rhs: NimNode): NimNode = |
| 17 | + result = newTree(nnkPar) |
| 18 | + |
| 19 | + for i in 1 ..< lhs.len: |
| 20 | + result.add lhs[i] |
| 21 | + |
| 22 | + for i in 1 ..< rhs.len: |
| 23 | + block findMatch: |
| 24 | + for j in 1 ..< lhs.len: |
| 25 | + if sameType(rhs[i], lhs[i]): |
| 26 | + break findMatch |
| 27 | + |
| 28 | + result.add rhs[i] |
| 29 | + |
| 30 | +macro `++`*(lhs: type[tuple], rhs: type[tuple]): type = |
| 31 | + result = mergeTupleTypeSets(getType(lhs)[1], getType(rhs)[1]) |
| 32 | + |
| 33 | +proc genForwardingCall(procDef: NimNode): NimNode = |
| 34 | + result = newCall(procDef.name) |
| 35 | + for param, _ in procDef.typedParams: |
| 36 | + result.add param |
| 37 | + |
| 38 | +macro noerrors*(procDef: untyped) = |
| 39 | + let raisesPragma = procDef.pragma.findPragma(raisesPragmaId) |
| 40 | + if raisesPragma != nil: |
| 41 | + error "You should not specify `noerrors` and `raises` at the same time", |
| 42 | + raisesPragma |
| 43 | + var raisesList = newTree(nnkBracket, bindSym"Defect") |
| 44 | + procDef.addPragma newColonExpr(ident"raises", raisesList) |
| 45 | + return procDef |
| 46 | + |
| 47 | +macro errors*(ErrorsTuple: typed, procDef: untyped) = |
| 48 | + let raisesPragma = procDef.pragma.findPragma(raisesPragmaId) |
| 49 | + if raisesPragma != nil: |
| 50 | + error "You should not specify `errors` and `raises` at the same time", |
| 51 | + raisesPragma |
| 52 | + |
| 53 | + var raisesList = newTree(nnkBracket, bindSym"Defect") |
| 54 | + |
| 55 | + for i in 1 ..< ErrorsTuple.len: |
| 56 | + raisesList.add ErrorsTuple[i] |
| 57 | + |
| 58 | + procDef.addPragma newColonExpr(ident"raises", raisesList) |
| 59 | + |
| 60 | + when errorHandlingEnforced: |
| 61 | + # We are going to create a wrapper proc or a template |
| 62 | + # that calls the original one and wraps the returned |
| 63 | + # value in a Raising type. To achieve this, we must |
| 64 | + # generate a new name for the original proc: |
| 65 | + |
| 66 | + let |
| 67 | + generateTemplate = true |
| 68 | + OrigResultType = procDef.params[0] |
| 69 | + |
| 70 | + # Create the wrapper |
| 71 | + var |
| 72 | + wrapperDef: NimNode |
| 73 | + RaisingType: NimNode |
| 74 | + |
| 75 | + if generateTemplate: |
| 76 | + wrapperDef = newNimNode(nnkTemplateDef, procDef) |
| 77 | + procDef.copyChildrenTo wrapperDef |
| 78 | + # We must remove the raises list from the original proc |
| 79 | + wrapperDef.pragma = newEmptyNode() |
| 80 | + else: |
| 81 | + wrapperDef = copy procDef |
| 82 | + |
| 83 | + # Change the original proc name |
| 84 | + procDef.name = genSym(nskProc, $procDef.name) |
| 85 | + |
| 86 | + var wrapperBody = newNimNode(nnkStmtList, procDef.body) |
| 87 | + if OrigResultType.kind == nnkEmpty or eqIdent(OrigResultType, "void"): |
| 88 | + RaisingType = newTree(nnkBracketExpr, ident"Raising", |
| 89 | + ErrorsTuple, bindSym"VoidResult") |
| 90 | + wrapperBody.add( |
| 91 | + genForwardingCall(procDef), |
| 92 | + newCall(RaisingType, newTree(nnkObjConstr, bindSym"VoidResult"))) |
| 93 | + else: |
| 94 | + RaisingType = newTree(nnkBracketExpr, ident"Raising", |
| 95 | + ErrorsTuple, OrigResultType) |
| 96 | + wrapperBody.add newCall(RaisingType, genForwardingCall(procDef)) |
| 97 | + |
| 98 | + wrapperDef.params[0] = if generateTemplate: ident"untyped" |
| 99 | + else: RaisingType |
| 100 | + wrapperDef.body = wrapperBody |
| 101 | + |
| 102 | + result = newStmtList(procDef, wrapperDef) |
| 103 | + else: |
| 104 | + result = procDef |
| 105 | + |
| 106 | + storeMacroResult result |
| 107 | + |
| 108 | +macro checkForUnhandledErrors(origHandledErrors, raisedErrors: typed) = |
| 109 | + # This macro is executed with two tuples: |
| 110 | + # |
| 111 | + # 1. The list of errors handled at the call-site which will |
| 112 | + # have a line info matching the call-site. |
| 113 | + # 2. The list of errors that the called function is raising. |
| 114 | + # The lineinfo here points to the definition of the function. |
| 115 | + |
| 116 | + # For accidental reasons, the first tuple will be recognized as a |
| 117 | + # typedesc, while the second won't be (beware because this can be |
| 118 | + # considered a bug in Nim): |
| 119 | + var handledErrors = getTypeInst(origHandledErrors) |
| 120 | + if handledErrors.kind == nnkBracketExpr: |
| 121 | + handledErrors = handledErrors[1] |
| 122 | + |
| 123 | + assert handledErrors.kind == nnkTupleConstr and |
| 124 | + raisedErrors.kind == nnkTupleConstr |
| 125 | + |
| 126 | + # Here, we'll store the list of errors that the user missed: |
| 127 | + var unhandledErrors = newTree(nnkPar) |
| 128 | + |
| 129 | + # We loop through the raised errors and check whether they have |
| 130 | + # an appropriate handler: |
| 131 | + for raised in raisedErrors: |
| 132 | + block findHandler: |
| 133 | + template tryFindingHandler(raisedType) = |
| 134 | + for handled in handledErrors: |
| 135 | + if sameType(raisedType, handled): |
| 136 | + break findHandler |
| 137 | + |
| 138 | + tryFindingHandler raised |
| 139 | + # A base type of the raised exception may be handled instead |
| 140 | + for baseType in raised.baseTypes: |
| 141 | + tryFindingHandler baseType |
| 142 | + |
| 143 | + unhandledErrors.add raised |
| 144 | + |
| 145 | + if unhandledErrors.len > 0: |
| 146 | + let errMsg = "The following errors are not handled: $1" % [unhandledErrors.repr] |
| 147 | + error errMsg, origHandledErrors |
| 148 | + |
| 149 | +template raising*[E, R](x: Raising[E, R]): R = |
| 150 | + ## `raising` is used to mark locations in the code that might |
| 151 | + ## raise exceptions. It disarms the type-safety checks imposed |
| 152 | + ## by the `errors` pragma. |
| 153 | + distinctBase(x) |
| 154 | + |
| 155 | +macro chk*[R, E](x: Raising[R, E], handlers: untyped): untyped = |
| 156 | + ## The `chk` macro can be used in 2 different ways |
| 157 | + ## |
| 158 | + ## 1) Try to get the result of an expression. In case of any |
| 159 | + ## errors, substitute the result with a default value: |
| 160 | + ## |
| 161 | + ## ``` |
| 162 | + ## let x = chk(foo(), defaultValue) |
| 163 | + ## ``` |
| 164 | + ## |
| 165 | + ## We'll handle this case with a simple rewrite to |
| 166 | + ## |
| 167 | + ## ``` |
| 168 | + ## let x = try: distinctBase(foo()) |
| 169 | + ## except CatchableError: defaultValue |
| 170 | + ## ``` |
| 171 | + ## |
| 172 | + ## 2) Try to get the result of an expression while providing exception |
| 173 | + ## handlers that must cover all possible recoverable errors. |
| 174 | + ## |
| 175 | + ## ``` |
| 176 | + ## let x = chk foo(): |
| 177 | + ## KeyError as err: defaultValue |
| 178 | + ## ValueError: return |
| 179 | + ## _: raise |
| 180 | + ## ``` |
| 181 | + ## |
| 182 | + ## The above example will be rewritten to: |
| 183 | + ## |
| 184 | + ## ``` |
| 185 | + ## let x = try: |
| 186 | + ## foo() |
| 187 | + ## except KeyError as err: |
| 188 | + ## defaultValue |
| 189 | + ## except ValueError: |
| 190 | + ## return |
| 191 | + ## except CatchableError: |
| 192 | + ## raise |
| 193 | + ## ``` |
| 194 | + ## |
| 195 | + ## Please note that the special case `_` is considered equivalent to |
| 196 | + ## `CatchableError`. |
| 197 | + ## |
| 198 | + ## If the `chk` block lacks a default handler and there are unlisted |
| 199 | + ## recoverable errors, the compiler will fail to compile the code with |
| 200 | + ## a message indicating the missing ones. |
| 201 | + let |
| 202 | + RaisingType = getTypeInst(x) |
| 203 | + ErrorsSetTuple = RaisingType[1] |
| 204 | + ResultType = RaisingType[2] |
| 205 | + |
| 206 | + # The `try` branch is the same in all scenarios. We generate it here. |
| 207 | + # The target AST looks roughly like this: |
| 208 | + # |
| 209 | + # TryStmt |
| 210 | + # StmtList |
| 211 | + # Call |
| 212 | + # Ident "distinctBase" |
| 213 | + # Call |
| 214 | + # Ident "foo" |
| 215 | + # ExceptBranch |
| 216 | + # Ident "CatchableError" |
| 217 | + # StmtList |
| 218 | + # Ident "defaultValue" |
| 219 | + result = newTree(nnkTryStmt, newStmtList( |
| 220 | + newCall(bindSym"distinctBase", x))) |
| 221 | + |
| 222 | + # Check how the API was used: |
| 223 | + if handlers.kind != nnkStmtList: |
| 224 | + # This is usage type 1: chk(foo(), defaultValue) |
| 225 | + result.add newTree(nnkExceptBranch, |
| 226 | + bindSym("CatchableError"), |
| 227 | + newStmtList(handlers)) |
| 228 | + else: |
| 229 | + var |
| 230 | + # This will be a tuple of all the errors handled by the `chk` block. |
| 231 | + # In the end, we'll compare it to the Raising list. |
| 232 | + HandledErrorsTuple = newNimNode(nnkPar, x) |
| 233 | + # Has the user provided a default `_: value` handler? |
| 234 | + defaultCatchProvided = false |
| 235 | + |
| 236 | + for handler in handlers: |
| 237 | + template err(msg: string) = error msg, handler |
| 238 | + template unexpectedSyntax = err( |
| 239 | + "The `chk` handlers block should consist of `ExceptionType: Value/Block` pairs") |
| 240 | + |
| 241 | + case handler.kind |
| 242 | + of nnkCommentStmt: |
| 243 | + continue |
| 244 | + of nnkInfix: |
| 245 | + if eqIdent(handler[0], "as"): |
| 246 | + if handler.len != 4: |
| 247 | + err "The expected syntax is `ExceptionType as exceptionVar: Value/Block`" |
| 248 | + let |
| 249 | + ExceptionType = handler[1] |
| 250 | + exceptionVar = handler[2] |
| 251 | + valueBlock = handler[3] |
| 252 | + |
| 253 | + HandledErrorsTuple.add ExceptionType |
| 254 | + result.add newTree(nnkExceptBranch, infix(ExceptionType, "as", exceptionVar), |
| 255 | + valueBlock) |
| 256 | + else: |
| 257 | + err "The '$1' operator is not expected in a `chk` block" % [$handler[0]] |
| 258 | + of nnkCall: |
| 259 | + if handler.len != 2: |
| 260 | + unexpectedSyntax |
| 261 | + let ExceptionType = handler[0] |
| 262 | + if eqIdent(ExceptionType, "_"): |
| 263 | + if defaultCatchProvided: |
| 264 | + err "Only a single default handler is expected" |
| 265 | + handler[0] = bindSym"CatchableError" |
| 266 | + defaultCatchProvided = true |
| 267 | + |
| 268 | + result.add newTree(nnkExceptBranch, handler[0], handler[1]) |
| 269 | + HandledErrorsTuple.add handler[0] |
| 270 | + else: |
| 271 | + unexpectedSyntax |
| 272 | + |
| 273 | + result = newTree(nnkStmtListExpr, |
| 274 | + newCall(bindSym"checkForUnhandledErrors", HandledErrorsTuple, ErrorsSetTuple), |
| 275 | + result) |
| 276 | + |
| 277 | + storeMacroResult result |
0 commit comments