99from datetime import timedelta
1010from itertools import zip_longest
1111from pprint import pformat
12- from typing import Any , List , Literal , Never , Optional , Sequence , Type
12+ from typing import Any , List , Literal , Never , Optional , Sequence , Type , cast
1313from warnings import warn
1414
1515import pytest
@@ -983,9 +983,12 @@ def to_failure(
983983 else :
984984 raise TypeError (f"self.context is { type (self .context )} " )
985985
986+ assert isinstance (
987+ self .context , (WorkflowSerializationContext , ActivitySerializationContext )
988+ )
986989 test_traces [self .context .workflow_id ].append (
987990 TraceItem (
988- context_type = context_type ,
991+ context_type = cast ( Literal [ "workflow" , "activity" ], context_type ) ,
989992 in_workflow = workflow .in_workflow (),
990993 method = "to_failure" ,
991994 context = dataclasses .asdict (self .context ),
@@ -1004,9 +1007,12 @@ def from_failure(
10041007 else :
10051008 raise TypeError (f"self.context is { type (self .context )} " )
10061009
1010+ assert isinstance (
1011+ self .context , (WorkflowSerializationContext , ActivitySerializationContext )
1012+ )
10071013 test_traces [self .context .workflow_id ].append (
10081014 TraceItem (
1009- context_type = context_type ,
1015+ context_type = cast ( Literal [ "workflow" , "activity" ], context_type ) ,
10101016 in_workflow = workflow .in_workflow (),
10111017 method = "from_failure" ,
10121018 context = dataclasses .asdict (self .context ),
0 commit comments