|
| 1 | +""" |
| 2 | + Note: Do not remove base_flow and locator import. |
| 3 | +""" |
| 4 | + |
| 5 | + |
| 6 | +from typing import Any, Dict, List, Optional, Tuple |
| 7 | + |
| 8 | +import copy |
| 9 | + |
| 10 | +from freak.engine import Engine |
| 11 | +from freak.flows.base_flow import base_flow as choice_flow |
| 12 | +from freak.flows.base_flow import locator |
| 13 | +from freak.models.request import RequestContext |
| 14 | +from freak.models.response import EngineResponse, Response |
| 15 | + |
| 16 | + |
| 17 | +class ChoiceFlowEngine(Engine): |
| 18 | + def __init__( |
| 19 | + self, module_name: str, decorator_name: str = "choice_flow" |
| 20 | + ) -> None: |
| 21 | + super().__init__(module_name=module_name, decorator_name=decorator_name) |
| 22 | + |
| 23 | + def get_following_steps( |
| 24 | + self, from_step: Optional[str], path_traversed: Dict[str, Any] |
| 25 | + ) -> List[str]: |
| 26 | + step_graph = self.flow.predecessor |
| 27 | + if not from_step: |
| 28 | + # pick up root of flow. |
| 29 | + from_step = step_graph[from_step][0] |
| 30 | + |
| 31 | + next_steps = step_graph.get(from_step, []) |
| 32 | + |
| 33 | + last_step = path_traversed.get("last_step", "") |
| 34 | + if from_step not in step_graph.get(last_step, []) and last_step: |
| 35 | + raise Exception("CannotExecuteError") |
| 36 | + |
| 37 | + return next_steps |
| 38 | + |
| 39 | + def get_next_step_uid( |
| 40 | + self, resp_ctx: Response, next_steps: List[str] |
| 41 | + ) -> str: |
| 42 | + if resp_ctx.choice: |
| 43 | + if resp_ctx.choice not in next_steps: |
| 44 | + raise Exception("InvalidChoice") |
| 45 | + |
| 46 | + return resp_ctx.choice |
| 47 | + |
| 48 | + assert len(next_steps) == 1 |
| 49 | + |
| 50 | + return next_steps[0] |
| 51 | + |
| 52 | + def execute( |
| 53 | + self, |
| 54 | + from_step: Optional[str], |
| 55 | + data: Dict[str, Any], |
| 56 | + executed_steps: Dict[str, Any] = {"traversed": {}, "last_step": ""}, |
| 57 | + ) -> Tuple[EngineResponse, Dict[str, Any]]: |
| 58 | + path_traversed = copy.deepcopy(executed_steps) |
| 59 | + |
| 60 | + step = self.get_step(from_step=from_step) |
| 61 | + next_steps = self.get_following_steps( |
| 62 | + from_step=from_step, |
| 63 | + path_traversed=path_traversed, |
| 64 | + ) |
| 65 | + |
| 66 | + responses = [] |
| 67 | + last_successful_step, to_step, from_ = ( |
| 68 | + step.order, |
| 69 | + step.order, |
| 70 | + step.order, |
| 71 | + ) |
| 72 | + |
| 73 | + while True: |
| 74 | + ctx = RequestContext(input=data, name=step.name, order=step.order) |
| 75 | + |
| 76 | + resp_ctx = step.function(ctx=ctx) # type: ignore |
| 77 | + responses.append(resp_ctx) |
| 78 | + to_step = step.order # this will refer to last performed step. |
| 79 | + if not resp_ctx.success: |
| 80 | + break |
| 81 | + |
| 82 | + path = next_steps.copy() |
| 83 | + if next_steps: |
| 84 | + next_step_uid = self.get_next_step_uid( |
| 85 | + resp_ctx=resp_ctx, next_steps=next_steps |
| 86 | + ) |
| 87 | + path = [next_step_uid] |
| 88 | + |
| 89 | + path_traversed["traversed"][step.uid] = path |
| 90 | + path_traversed["last_step"] = step.uid |
| 91 | + |
| 92 | + data = resp_ctx.input |
| 93 | + |
| 94 | + # this will refer last successfully performed action. |
| 95 | + last_successful_step = step.order |
| 96 | + |
| 97 | + if not next_steps: |
| 98 | + break |
| 99 | + |
| 100 | + next_step_uid = self.get_next_step_uid( |
| 101 | + resp_ctx=resp_ctx, next_steps=next_steps |
| 102 | + ) |
| 103 | + |
| 104 | + next_steps = self.get_following_steps( |
| 105 | + from_step=next_step_uid, |
| 106 | + path_traversed=path_traversed, |
| 107 | + ) |
| 108 | + step = self.get_step(from_step=next_step_uid) |
| 109 | + |
| 110 | + return ( |
| 111 | + EngineResponse( |
| 112 | + responses=responses, |
| 113 | + from_step=from_, |
| 114 | + to_step=to_step, |
| 115 | + last_successful_step=last_successful_step, |
| 116 | + ), |
| 117 | + path_traversed, |
| 118 | + ) |
0 commit comments