1414import sys
1515import threading
1616import types
17+ import warnings
1718from contextlib import ExitStack , contextmanager
1819from typing import (
1920 Any ,
2930 Set ,
3031 Tuple ,
3132 TypeVar ,
33+ no_type_check ,
3234)
3335
3436from typing_extensions import ParamSpec
@@ -107,6 +109,33 @@ def restrict_built_in(name: str, orig: Any, *args, **kwargs):
107109 )
108110 )
109111
112+ # Need to unwrap params for isinstance and issubclass. We have
113+ # chosen to do it this way instead of customize __instancecheck__
114+ # and __subclasscheck__ because we may have proxied the second
115+ # parameter which does not have a way to override. It is unfortunate
116+ # we have to change these globals for everybody.
117+ def unwrap_second_param (orig : Any , a : Any , b : Any ) -> Any :
118+ a = RestrictionContext .unwrap_if_proxied (a )
119+ b = RestrictionContext .unwrap_if_proxied (b )
120+ return orig (a , b )
121+
122+ thread_local_is_inst = _get_thread_local_builtin ("isinstance" )
123+ self .restricted_builtins .append (
124+ (
125+ "isinstance" ,
126+ thread_local_is_inst ,
127+ functools .partial (unwrap_second_param , thread_local_is_inst .orig ),
128+ )
129+ )
130+ thread_local_is_sub = _get_thread_local_builtin ("issubclass" )
131+ self .restricted_builtins .append (
132+ (
133+ "issubclass" ,
134+ thread_local_is_sub ,
135+ functools .partial (unwrap_second_param , thread_local_is_sub .orig ),
136+ )
137+ )
138+
110139 @contextmanager
111140 def applied (self ) -> Iterator [None ]:
112141 """Context manager to apply this restrictive import.
@@ -153,17 +182,21 @@ def _import(
153182 fromlist : Sequence [str ] = (),
154183 level : int = 0 ,
155184 ) -> types .ModuleType :
185+ # We have to resolve the full name, it can be relative at different
186+ # levels
187+ full_name = _resolve_module_name (name , globals , level )
188+
156189 # Check module restrictions and passthrough modules
157- if name not in sys .modules :
190+ if full_name not in sys .modules :
158191 # Make sure not an entirely invalid module
159- self ._assert_valid_module (name )
192+ self ._assert_valid_module (full_name )
160193
161194 # Check if passthrough
162- passthrough_mod = self ._maybe_passthrough_module (name )
195+ passthrough_mod = self ._maybe_passthrough_module (full_name )
163196 if passthrough_mod :
164197 # Load all parents. Usually Python does this for us, but not on
165198 # passthrough.
166- parent , _ , child = name .rpartition ("." )
199+ parent , _ , child = full_name .rpartition ("." )
167200 if parent and parent not in sys .modules :
168201 _trace (
169202 "Importing parent module %s before passing through %s" ,
@@ -174,17 +207,17 @@ def _import(
174207 # Set the passthrough on the parent
175208 setattr (sys .modules [parent ], child , passthrough_mod )
176209 # Set the passthrough on sys.modules and on the parent
177- sys .modules [name ] = passthrough_mod
210+ sys .modules [full_name ] = passthrough_mod
178211 # Put it on the parent
179212 if parent :
180- setattr (sys .modules [parent ], child , sys .modules [name ])
213+ setattr (sys .modules [parent ], child , sys .modules [full_name ])
181214
182215 # If the module is __temporal_main__ and not already in sys.modules,
183216 # we load it from whatever file __main__ was originally in
184- if name == "__temporal_main__" :
217+ if full_name == "__temporal_main__" :
185218 orig_mod = _thread_local_sys_modules .orig ["__main__" ]
186219 new_spec = importlib .util .spec_from_file_location (
187- name , orig_mod .__file__
220+ full_name , orig_mod .__file__
188221 )
189222 if not new_spec :
190223 raise ImportError (
@@ -195,7 +228,7 @@ def _import(
195228 f"Spec for __main__ file at { orig_mod .__file__ } has no loader"
196229 )
197230 new_mod = importlib .util .module_from_spec (new_spec )
198- sys .modules [name ] = new_mod
231+ sys .modules [full_name ] = new_mod
199232 new_spec .loader .exec_module (new_mod )
200233
201234 mod = importlib .__import__ (name , globals , locals , fromlist , level )
@@ -219,10 +252,20 @@ def _assert_valid_module(self, name: str) -> None:
219252 raise RestrictedWorkflowAccessError (name )
220253
221254 def _maybe_passthrough_module (self , name : str ) -> Optional [types .ModuleType ]:
222- if not self .restrictions .passthrough_modules .match_access (
223- self .restriction_context , * name .split ("." )
255+ # If imports not passed through and name not in passthrough modules,
256+ # check parents
257+ if (
258+ not temporalio .workflow .unsafe .is_imports_passed_through ()
259+ and name not in self .restrictions .passthrough_modules
224260 ):
225- return None
261+ end_dot = - 1
262+ while True :
263+ end_dot = name .find ("." , end_dot + 1 )
264+ if end_dot == - 1 :
265+ return None
266+ elif name [:end_dot ] in self .restrictions .passthrough_modules :
267+ break
268+ # Do the pass through
226269 with self ._unapplied ():
227270 _trace ("Passing module %s through from host" , name )
228271 global _trace_depth
@@ -409,3 +452,50 @@ def _get_thread_local_builtin(name: str) -> _ThreadLocalCallable:
409452 ret = _ThreadLocalCallable (getattr (builtins , name ))
410453 _thread_local_builtins [name ] = ret
411454 return ret
455+
456+
457+ def _resolve_module_name (
458+ name : str , globals : Optional [Mapping [str , object ]], level : int
459+ ) -> str :
460+ if level == 0 :
461+ return name
462+ # Calc the package from globals
463+ package = _calc___package__ (globals or {})
464+ # Logic taken from importlib._resolve_name
465+ bits = package .rsplit ("." , level - 1 )
466+ if len (bits ) < level :
467+ raise ImportError ("Attempted relative import beyond top-level package" )
468+ base = bits [0 ]
469+ return f"{ base } .{ name } " if name else base
470+
471+
472+ # Copied from importlib._calc__package__
473+ @no_type_check
474+ def _calc___package__ (globals : Mapping [str , object ]) -> str :
475+ """Calculate what __package__ should be.
476+ __package__ is not guaranteed to be defined or could be set to None
477+ to represent that its proper value is unknown.
478+ """
479+ package = globals .get ("__package__" )
480+ spec = globals .get ("__spec__" )
481+ if package is not None :
482+ if spec is not None and package != spec .parent :
483+ warnings .warn (
484+ "__package__ != __spec__.parent " f"({ package !r} != { spec .parent !r} )" ,
485+ DeprecationWarning ,
486+ stacklevel = 3 ,
487+ )
488+ return package
489+ elif spec is not None :
490+ return spec .parent
491+ else :
492+ warnings .warn (
493+ "can't resolve package from __spec__ or __package__, "
494+ "falling back on __name__ and __path__" ,
495+ ImportWarning ,
496+ stacklevel = 3 ,
497+ )
498+ package = globals ["__name__" ]
499+ if "__path__" not in globals :
500+ package = package .rpartition ("." )[0 ]
501+ return package
0 commit comments