Skip to content

Commit b92d7cd

Browse files
authored
Merge pull request #137 from stanford-centaur/version/0.3.6
feat: Catch concrete exceptions
2 parents fc6bd09 + a5dccf2 commit b92d7cd

1 file changed

Lines changed: 68 additions & 42 deletions

File tree

pantograph/server.py

Lines changed: 68 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,11 @@ def __del__(self):
130130
pass #self._close()
131131

132132
def _close(self):
133-
if self.proc:
134-
self.proc.terminate()
135-
self.proc = None
133+
if not self.proc:
134+
return
135+
136+
self.proc.terminate()
137+
self.proc = None
136138

137139
def is_automatic(self):
138140
"""
@@ -141,6 +143,9 @@ def is_automatic(self):
141143
return self.options.get("automaticMode", True)
142144

143145
async def restart_async(self):
146+
"""
147+
Restart the server
148+
"""
144149
self._close()
145150
env = os.environ
146151
if self.lean_path:
@@ -181,13 +186,22 @@ async def run_async(self, cmd, payload):
181186
command = f"{cmd} {s}\n"
182187
self.proc.stdin.write(command.encode())
183188
await self.proc.stdin.drain()
189+
line = ""
184190
try:
185191
line = await asyncio.wait_for(self.proc.stdout.readline(), self.timeout)
192+
except asyncio.TimeoutError as e:
193+
self._close()
194+
raise ServerError("Server reached timeout limit") from e
195+
196+
try:
186197
line = line.decode().strip()
187198
return json.loads(line)
188-
except Exception as e:
199+
except UnicodeDecodeError as e:
200+
self._close()
201+
raise ServerError(f"Could not decode process output: {line}") from e
202+
except json.JSONDecodeError as e:
189203
self._close()
190-
raise ServerError("Cannot decode Json object. A server error may have occurred.") from e
204+
raise ServerError(f"Cannot decode Json object from: {line}") from e
191205

192206
run = to_sync(run_async)
193207

@@ -253,26 +267,27 @@ async def goal_tactic_async(self, state: GoalState, tactic: Tactic, site: Site =
253267
Execute a tactic on `goal_id` of `state`
254268
"""
255269
args = {"stateId": state.state_id, **site.serial()}
256-
if isinstance(tactic, str):
257-
args["tactic"] = tactic
258-
elif isinstance(tactic, TacticHave):
259-
args["have"] = tactic.branch
260-
if tactic.binder_name:
261-
args["binderName"] = tactic.binder_name
262-
elif isinstance(tactic, TacticLet):
263-
args["let"] = tactic.branch
264-
if tactic.binder_name:
265-
args["binderName"] = tactic.binder_name
266-
elif isinstance(tactic, TacticExpr):
267-
args["expr"] = tactic.expr
268-
elif isinstance(tactic, TacticDraft):
269-
args["draft"] = tactic.expr
270-
elif isinstance(tactic, TacticMode):
271-
args["mode"] = tactic.serial()
272-
else:
273-
raise RuntimeError(f"Invalid tactic type: {tactic}")
270+
match tactic:
271+
case str():
272+
args["tactic"] = tactic
273+
case TacticHave():
274+
args["have"] = tactic.branch
275+
if tactic.binder_name:
276+
args["binderName"] = tactic.binder_name
277+
case TacticLet():
278+
args["let"] = tactic.branch
279+
if tactic.binder_name:
280+
args["binderName"] = tactic.binder_name
281+
case TacticExpr():
282+
args["expr"] = tactic.expr
283+
case TacticDraft():
284+
args["draft"] = tactic.expr
285+
case TacticMode():
286+
args["mode"] = tactic.serial()
287+
case _:
288+
raise RuntimeError(f"Invalid tactic type: {type(tactic)}")
274289
result = await self.run_async('goal.tactic', args)
275-
nextStateId = result.get("nextStateId")
290+
next_state_id = result.get("nextStateId")
276291
if "error" in result:
277292
raise ServerError(result)
278293
if "parseError" in result:
@@ -283,10 +298,10 @@ async def goal_tactic_async(self, state: GoalState, tactic: Tactic, site: Site =
283298
raise TacticFailure([Message.parse(m) for m in messages])
284299

285300
if result["hasSorry"]:
286-
await self.run_async('goal.delete', {'stateIds': [nextStateId]})
301+
await self.run_async('goal.delete', {'stateIds': [next_state_id]})
287302
raise TacticFailure("Tactic generated sorry", messages)
288303
if result["hasUnsafe"]:
289-
await self.run_async('goal.delete', {'stateIds': [nextStateId]})
304+
await self.run_async('goal.delete', {'stateIds': [next_state_id]})
290305
raise TacticFailure("Tactic generated unsafe", messages)
291306

292307
return GoalState.parse(result, messages, self.to_remove_goal_states)
@@ -307,7 +322,8 @@ async def goal_continue_async(self, target: GoalState, branch: GoalState) -> Goa
307322
raise ServerError(result)
308323
if "parseError" in result:
309324
raise ServerError(result)
310-
return GoalState.parse(result, self.to_remove_goal_states)
325+
return GoalState.parse(result, [], self.to_remove_goal_states)
326+
311327
goal_continue = to_sync(goal_continue_async)
312328

313329
async def goal_resume_async(self, state: GoalState, goals: list[Goal]) -> GoalState:
@@ -324,10 +340,12 @@ async def goal_resume_async(self, state: GoalState, goals: list[Goal]) -> GoalSt
324340
raise ServerError(result)
325341
if "parseError" in result:
326342
raise ServerError(result)
327-
return GoalState.parse(result, self.to_remove_goal_states)
343+
return GoalState.parse(result, [], self.to_remove_goal_states)
328344
goal_resume = to_sync(goal_resume_async)
329345

330-
async def env_add_async(self, name: str, levels: list[str], t: Expr, v: Expr, is_theorem: bool = True):
346+
async def env_add_async(
347+
self, name: str, levels: list[str],
348+
t: Expr, v: Expr, is_theorem: bool = True):
331349
"""
332350
Adds a definition to the environment.
333351
@@ -379,17 +397,21 @@ async def env_module_read_async(self, module: str) -> dict:
379397
return result
380398
env_module_read = to_sync(env_module_read_async)
381399

382-
async def env_parse_async(self, input: str, category: str="tactic") -> Tuple[str, str]:
400+
async def env_parse_async(self, src: str, category: str = "tactic") -> Tuple[str, str]:
401+
"""
402+
Parse an input using a syntax category's parser. Returns the parsed
403+
component and the tail.
404+
"""
383405
result = await self.run_async('env.parse', {
384-
"input": input,
406+
"input": src,
385407
"category": category,
386408
})
387409
if "error" in result:
388410
if result['error'] == 'parse':
389411
raise ParseError(result["desc"])
390412
raise ServerError(result["desc"])
391413
pos = result["pos"]
392-
s = input.encode()
414+
s = src.encode()
393415
return s[:pos].decode(), s[pos:].decode()
394416

395417
env_parse = to_sync(env_parse_async)
@@ -448,7 +470,11 @@ async def goal_load_async(self, path: str) -> GoalState:
448470
})
449471
if "error" in result:
450472
raise ServerError(result["desc"])
451-
return GoalState.parse_inner(state_id, result['goals'], [], self.to_remove_goal_states)
473+
return GoalState.parse_inner(
474+
state_id,
475+
result['goals'], [],
476+
self.to_remove_goal_states,
477+
)
452478

453479
goal_load = to_sync(goal_load_async)
454480

@@ -520,8 +546,8 @@ async def load_definitions_async(self, snippet: str):
520546
async def check_compile_async(
521547
self,
522548
code: str,
523-
new_constants: bool=False,
524-
read_header: bool=False):
549+
new_constants: bool = False,
550+
read_header: bool = False):
525551
"""
526552
Check if some Lean code compiles
527553
"""
@@ -546,12 +572,12 @@ async def check_compile_async(
546572
async def load_sorry_async(
547573
self,
548574
src: str,
549-
binder_name: Optional[str]=None,
550-
ignore_values: bool=True) -> list[SearchTarget]:
575+
binder_name: Optional[str] = None,
576+
ignore_values: bool = True) -> list[SearchTarget]:
551577
"""
552578
Condense search target into goals
553579
"""
554-
args = { "file": src, "ignoreValues": ignore_values }
580+
args = {"file": src, "ignoreValues": ignore_values}
555581
if binder_name is not None:
556582
args["binderName"] = binder_name
557583
result = await self.run_async('frontend.distil', args)
@@ -569,7 +595,7 @@ async def check_track_async(self, src: str, dst: str) -> CheckTrackResult:
569595
"""
570596
Checks if `dst` file conforms to the specifications in `src`
571597
"""
572-
result = await self.run_async('frontend.track', { "src": src, "dst": dst })
598+
result = await self.run_async('frontend.track', {"src": src, "dst": dst})
573599
if "error" in result:
574600
raise ServerError(result)
575601
src_messages = [Message.parse(d) for d in result["srcMessages"]]
@@ -627,13 +653,13 @@ def test_server_init_del(self):
627653
with warnings.catch_warnings():
628654
warnings.simplefilter("error", ResourceWarning)
629655
server = Server()
630-
t = server.expr_type("forall (n m: Nat), n + m = m + n")
656+
server.expr_type("forall (n m: Nat), n + m = m + n")
631657
del server
632658
server = Server()
633-
t = server.expr_type("forall (n m: Nat), n + m = m + n")
659+
server.expr_type("forall (n m: Nat), n + m = m + n")
634660
del server
635661
server = Server()
636-
t = server.expr_type("forall (n m: Nat), n + m = m + n")
662+
server.expr_type("forall (n m: Nat), n + m = m + n")
637663
del server
638664

639665
def test_expr_type(self):

0 commit comments

Comments
 (0)