diff --git a/taskiq/cli/scheduler/args.py b/taskiq/cli/scheduler/args.py index d1f6d82..b110cb3 100644 --- a/taskiq/cli/scheduler/args.py +++ b/taskiq/cli/scheduler/args.py @@ -12,6 +12,7 @@ class SchedulerArgs: scheduler: Union[str, TaskiqScheduler] modules: List[str] + app_dir: Optional[str] = None log_level: str = LogLevel.INFO.name configure_logging: bool = True fs_discover: bool = False @@ -42,6 +43,16 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs": help="List of modules where to look for tasks.", nargs=ZERO_OR_MORE, ) + parser.add_argument( + "--app-dir", + "-d", + default=None, + help=( + "Path to application directory. " + "This path will be used to import tasks modules. " + "If not specified, current working directory will be used." + ), + ) parser.add_argument( "--fs-discover", "-fsd", diff --git a/taskiq/cli/scheduler/run.py b/taskiq/cli/scheduler/run.py index 7a7d9f5..cefeb1c 100644 --- a/taskiq/cli/scheduler/run.py +++ b/taskiq/cli/scheduler/run.py @@ -249,7 +249,7 @@ async def run_scheduler(args: SchedulerArgs) -> None: getLogger("taskiq").setLevel(level=getLevelName(args.log_level)) if isinstance(args.scheduler, str): - scheduler = import_object(args.scheduler) + scheduler = import_object(args.scheduler, app_dir=args.app_dir) if inspect.isfunction(scheduler): scheduler = scheduler() else: diff --git a/taskiq/cli/utils.py b/taskiq/cli/utils.py index 22c15ae..3502c96 100644 --- a/taskiq/cli/utils.py +++ b/taskiq/cli/utils.py @@ -4,7 +4,7 @@ from importlib import import_module from logging import getLogger from pathlib import Path -from typing import Any, Generator, List, Sequence, Union +from typing import Any, Generator, List, Sequence, Union, Optional logger = getLogger("taskiq.worker") @@ -35,11 +35,12 @@ def add_cwd_in_path() -> Generator[None, None, None]: logger.warning(f"Cannot remove '{cwd}' from sys.path") -def import_object(object_spec: str) -> Any: +def import_object(object_spec: str, app_dir: Optional[str] = None) -> Any: """ It parses python object spec and imports it. :param object_spec: string in format like `package.module:variable` + :param app_dir: directory to add in sys.path for importing. :raises ValueError: if spec has unknown format. :returns: imported broker. """ @@ -47,6 +48,8 @@ def import_object(object_spec: str) -> Any: if len(import_spec) != 2: raise ValueError("You should provide object path in `module:variable` format.") with add_cwd_in_path(): + if app_dir: + sys.path.insert(0, app_dir) module = import_module(import_spec[0]) return getattr(module, import_spec[1]) diff --git a/taskiq/cli/worker/args.py b/taskiq/cli/worker/args.py index ef187ab..8bdc54c 100644 --- a/taskiq/cli/worker/args.py +++ b/taskiq/cli/worker/args.py @@ -26,6 +26,7 @@ class WorkerArgs: broker: str modules: List[str] + app_dir: Optional[str] = None tasks_pattern: Sequence[str] = ("**/tasks.py",) fs_discover: bool = False configure_logging: bool = True @@ -73,6 +74,16 @@ def from_cli( "'module.module:variable' format." ), ) + parser.add_argument( + "--app-dir", + "-d", + default=None, + help=( + "Path to application directory. " + "This path will be used to import tasks modules. " + "If not specified, current working directory will be used." + ), + ) parser.add_argument( "--receiver", default="taskiq.receiver:Receiver", diff --git a/taskiq/cli/worker/run.py b/taskiq/cli/worker/run.py index 3d58fed..7a44d33 100644 --- a/taskiq/cli/worker/run.py +++ b/taskiq/cli/worker/run.py @@ -64,7 +64,7 @@ def get_receiver_type(args: WorkerArgs) -> Type[Receiver]: :raises ValueError: if receiver is not a Receiver type. :return: Receiver type. """ - receiver_type = import_object(args.receiver) + receiver_type = import_object(args.receiver, app_dir=args.app_dir) if not (isinstance(receiver_type, type) and issubclass(receiver_type, Receiver)): raise ValueError("Unknown receiver type. Please use Receiver class.") return receiver_type @@ -133,7 +133,7 @@ def interrupt_handler(signum: int, _frame: Any) -> None: # We must set this field before importing tasks, # so broker will remember all tasks it's related to. - broker = import_object(args.broker) + broker = import_object(args.broker, app_dir=args.app_dir) if inspect.isfunction(broker): broker = broker() if not isinstance(broker, AsyncBroker):