Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Documentation = "https://docs.temporal.io/docs/python"

[dependency-groups]
dev = [
"basedpyright==1.34.0",
"cibuildwheel>=2.22.0,<3",
"grpcio-tools>=1.48.2,<2",
"mypy==1.18.2",
Expand Down Expand Up @@ -96,6 +97,7 @@ lint-docs = "uv run pydocstyle --ignore-decorators=overload"
lint-types = [
{ cmd = "uv run pyright" },
{ cmd = "uv run mypy --namespace-packages --check-untyped-defs ." },
{ cmd = "uv run basedpyright" }
]
run-bench = "uv run python scripts/run_bench.py"
test = "uv run pytest"
Expand Down Expand Up @@ -198,10 +200,12 @@ reportUnknownVariableType = "none"
reportUnnecessaryIsInstance = "none"
reportUnnecessaryTypeIgnoreComment = "none"
reportUnusedCallResult = "none"
reportUnknownLambdaType = "none"
include = ["temporalio", "tests"]
exclude = [
"temporalio/api",
"temporalio/bridge/proto",
"temporalio/bridge/_visitor.py",
"tests/worker/workflow_sandbox/testmodules/proto",
"temporalio/bridge/worker.py",
"temporalio/worker/_replayer.py",
Expand Down
1 change: 0 additions & 1 deletion scripts/gen_protos.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from collections.abc import Mapping
from functools import partial
from pathlib import Path
from typing import List

base_dir = Path(__file__).parent.parent
proto_dir = (
Expand Down
7 changes: 1 addition & 6 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@
from typing import (
TYPE_CHECKING,
Any,
List,
NoReturn,
Optional,
Tuple,
Type,
Union,
overload,
)

Expand Down Expand Up @@ -593,7 +588,7 @@ def _apply_to_callable(
if hasattr(fn, "__temporal_activity_definition"):
raise ValueError("Function already contains activity definition")
elif not callable(fn):
raise TypeError("Activity is not callable")
raise TypeError("Activity is not callable") # type:ignore[reportUnreachable]
# We do not allow keyword only arguments in activities
sig = inspect.signature(fn)
for param in sig.parameters.values():
Expand Down
29 changes: 12 additions & 17 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,7 @@
from typing import (
Any,
Concatenate,
Dict,
FrozenSet,
Generic,
Optional,
Text,
Tuple,
Type,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -205,7 +198,9 @@ async def connect(
http_connect_proxy_config=http_connect_proxy_config,
)

def make_lambda(plugin, next):
def make_lambda(
plugin: Plugin, next: Callable[[ConnectConfig], Awaitable[ServiceClient]]
):
return lambda config: plugin.connect_service_client(config, next)

next_function = ServiceClient.connect
Expand Down Expand Up @@ -1335,8 +1330,8 @@ async def create_schedule(
| (
temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes
) = None,
static_summary: str | None = None,
static_details: str | None = None,
static_summary: str | None = None, # type:ignore[reportUnusedParameter] # https://github.com/temporalio/sdk-python/issues/1238
static_details: str | None = None, # type:ignore[reportUnusedParameter]
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> ScheduleHandle:
Expand Down Expand Up @@ -3405,7 +3400,7 @@ def next_page_token(self) -> bytes | None:
"""Token for the next page request if any."""
return self._next_page_token

async def fetch_next_page(self, *, page_size: int | None = None) -> None:
async def fetch_next_page(self, *, page_size: int | None = None) -> None: # type:ignore[reportUnusedParameter] # https://github.com/temporalio/sdk-python/issues/1239
"""Fetch the next page if any.

Args:
Expand Down Expand Up @@ -4156,7 +4151,7 @@ def __init__(
raise ValueError("Cannot schedule dynamic workflow explicitly")
workflow = defn.name
elif not isinstance(workflow, str):
raise TypeError("Workflow must be a string or callable")
raise TypeError("Workflow must be a string or callable") # type:ignore[reportUnreachable]
self.workflow = workflow
self.args = temporalio.common._arg_or_args(arg, args)
self.id = id
Expand Down Expand Up @@ -6040,9 +6035,9 @@ async def _populate_start_workflow_execution_request(
req.user_metadata.CopyFrom(metadata)
if input.start_delay is not None:
req.workflow_start_delay.FromTimedelta(input.start_delay)
if input.headers is not None:
if input.headers is not None: # type:ignore[reportUnnecessaryComparison]
await self._apply_headers(input.headers, req.header.fields)
if input.priority is not None:
if input.priority is not None: # type:ignore[reportUnnecessaryComparison]
req.priority.CopyFrom(input.priority._to_proto())
if input.versioning_override is not None:
req.versioning_override.CopyFrom(input.versioning_override._to_proto())
Expand Down Expand Up @@ -6138,7 +6133,7 @@ async def query_workflow(self, input: QueryWorkflowInput) -> Any:
req.query.query_args.payloads.extend(
await data_converter.encode(input.args)
)
if input.headers is not None:
if input.headers is not None: # type:ignore[reportUnnecessaryComparison]
await self._apply_headers(input.headers, req.query.header.fields)
try:
resp = await self._client.workflow_service.query_workflow(
Expand Down Expand Up @@ -6186,7 +6181,7 @@ async def signal_workflow(self, input: SignalWorkflowInput) -> None:
)
if input.args:
req.input.payloads.extend(await data_converter.encode(input.args))
if input.headers is not None:
if input.headers is not None: # type:ignore[reportUnnecessaryComparison]
await self._apply_headers(input.headers, req.header.fields)
await self._client.workflow_service.signal_workflow_execution(
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
Expand Down Expand Up @@ -6307,7 +6302,7 @@ async def _build_update_workflow_execution_request(
req.request.input.args.payloads.extend(
await data_converter.encode(input.args)
)
if input.headers is not None:
if input.headers is not None: # type:ignore[reportUnnecessaryComparison]
await self._apply_headers(input.headers, req.request.input.header.fields)
return req

Expand Down
26 changes: 10 additions & 16 deletions temporalio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,8 @@
Any,
ClassVar,
Generic,
List,
Optional,
Text,
Tuple,
Type,
TypeAlias,
TypeVar,
Union,
get_origin,
get_type_hints,
overload,
Expand Down Expand Up @@ -198,13 +192,13 @@ def __setstate__(self, state: object) -> None:
# We choose to make this a list instead of an sequence so we can catch if people
# are not sending lists each time but maybe accidentally sending a string (which
# is a sequence)
SearchAttributeValues: TypeAlias = Union[
list[str], list[int], list[float], list[bool], list[datetime]
]
SearchAttributeValues: TypeAlias = (
list[str] | list[int] | list[float] | list[bool] | list[datetime]
)

SearchAttributes: TypeAlias = Mapping[str, SearchAttributeValues]

SearchAttributeValue: TypeAlias = Union[str, int, float, bool, datetime, Sequence[str]]
SearchAttributeValue: TypeAlias = str | int | float | bool | datetime | Sequence[str]

SearchAttributeValueType = TypeVar(
"SearchAttributeValueType", str, int, float, bool, datetime, Sequence[str]
Expand Down Expand Up @@ -492,7 +486,7 @@ def __contains__(self, key: object) -> bool:

This uses key equality so the key must be the same name and type.
"""
return any(k == key for k, v in self)
return any(k == key for k, _v in self)

@overload
def get(
Expand Down Expand Up @@ -544,7 +538,7 @@ def updated(self, *search_attributes: SearchAttributePair) -> TypedSearchAttribu
TypedSearchAttributes.empty = TypedSearchAttributes(search_attributes=[])


def _warn_on_deprecated_search_attributes(
def _warn_on_deprecated_search_attributes( # type:ignore[reportUnusedFunction]
attributes: SearchAttributes | Any | None,
stack_level: int = 2,
) -> None:
Expand All @@ -556,7 +550,7 @@ def _warn_on_deprecated_search_attributes(
)


MetricAttributes: TypeAlias = Mapping[str, Union[str, int, float, bool]]
MetricAttributes: TypeAlias = Mapping[str, str | int | float | bool]


class MetricMeter(ABC):
Expand Down Expand Up @@ -1157,15 +1151,15 @@ def _to_proto(self) -> temporalio.api.workflow.v1.VersioningOverride:
_arg_unset = object()


def _arg_or_args(arg: Any, args: Sequence[Any]) -> Sequence[Any]:
def _arg_or_args(arg: Any, args: Sequence[Any]) -> Sequence[Any]: # type:ignore[reportUnusedFunction]
if arg is not _arg_unset:
if args:
raise ValueError("Cannot have arg and args")
args = [arg]
return args


def _apply_headers(
def _apply_headers( # type:ignore[reportUnusedFunction]
source: Mapping[str, temporalio.api.common.v1.Payload] | None,
dest: google.protobuf.internal.containers.MessageMap[
str, temporalio.api.common.v1.Payload
Expand All @@ -1192,7 +1186,7 @@ def _apply_headers(
)


def _type_hints_from_func(
def _type_hints_from_func( # type:ignore[reportUnusedFunction]
func: Callable,
) -> tuple[list[type] | None, type | None]:
"""Extracts the type hints from the function.
Expand Down
3 changes: 0 additions & 3 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
OpenAIAgentsPlugin,
OpenAIPayloadConverter,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError

from . import testing, workflow
Expand Down
2 changes: 1 addition & 1 deletion temporalio/contrib/openai_agents/_heartbeat_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])


def _auto_heartbeater(fn: F) -> F:
def _auto_heartbeater(fn: F) -> F: # type:ignore[reportUnusedClass]
# Propagate type hints from the original callable.
@wraps(fn)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
Expand Down
30 changes: 15 additions & 15 deletions temporalio/contrib/openai_agents/_invoke_model_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
"""

import enum
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any

from agents import (
AgentOutputSchemaBase,
Expand All @@ -33,10 +32,9 @@
AsyncOpenAI,
)
from openai.types.responses.tool_param import Mcp
from pydantic_core import to_json
from typing_extensions import Required, TypedDict

from temporalio import activity, workflow
from temporalio import activity
from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater
from temporalio.exceptions import ApplicationError

Expand Down Expand Up @@ -75,14 +73,14 @@ class HostedMCPToolInput:
tool_config: Mcp


ToolInput = Union[
FunctionToolInput,
FileSearchTool,
WebSearchTool,
ImageGenerationTool,
CodeInterpreterTool,
HostedMCPToolInput,
]
ToolInput = (
FunctionToolInput
| FileSearchTool
| WebSearchTool
| ImageGenerationTool
| CodeInterpreterTool
| HostedMCPToolInput
)


@dataclass
Expand Down Expand Up @@ -165,11 +163,13 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons
"""Activity that invokes a model with the given input."""
model = self._model_provider.get_model(input.get("model_name"))

async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
async def empty_on_invoke_tool(
_ctx: RunContextWrapper[Any], _input: str
) -> str:
return ""

async def empty_on_invoke_handoff(
ctx: RunContextWrapper[Any], input: str
_ctx: RunContextWrapper[Any], _input: str
) -> Any:
return None

Expand Down Expand Up @@ -197,7 +197,7 @@ def make_tool(tool: ToolInput) -> Tool:
strict_json_schema=tool.strict_json_schema,
)
else:
raise UserError(f"Unknown tool type: {tool.name}")
raise UserError(f"Unknown tool type: {tool.name}") # type:ignore[reportUnreachable]

tools = [make_tool(x) for x in input.get("tools", [])]
handoffs: list[Handoff[Any, Any]] = [
Expand Down
Loading
Loading