Skip to content

Commit ccbb541

Browse files
authored
Fix #1131 A class with async "__call__" method fails to work as a middleware (#1132)
1 parent 40f6d1e commit ccbb541

File tree

4 files changed

+32
-7
lines changed

4 files changed

+32
-7
lines changed

slack_bolt/app/async_app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
AsyncMessageListenerMatches,
2525
)
2626
from slack_bolt.oauth.async_internals import select_consistent_installation_store
27-
from slack_bolt.util.utils import get_name_for_callable
27+
from slack_bolt.util.utils import get_name_for_callable, is_coroutine_function
2828
from slack_bolt.workflows.step.async_step import (
2929
AsyncWorkflowStep,
3030
AsyncWorkflowStepBuilder,
@@ -778,7 +778,7 @@ async def custom_error_handler(error, body, logger):
778778
func: The function that is supposed to be executed
779779
when getting an unhandled error in Bolt app.
780780
"""
781-
if not inspect.iscoroutinefunction(func):
781+
if not is_coroutine_function(func):
782782
name = get_name_for_callable(func)
783783
raise BoltError(error_listener_function_must_be_coro_func(name))
784784
self._async_listener_runner.listener_error_handler = AsyncCustomListenerErrorHandler(
@@ -1410,7 +1410,7 @@ def _register_listener(
14101410
value_to_return = functions[0]
14111411

14121412
for func in functions:
1413-
if not inspect.iscoroutinefunction(func):
1413+
if not is_coroutine_function(func):
14141414
name = get_name_for_callable(func)
14151415
raise BoltError(error_listener_function_must_be_coro_func(name))
14161416

@@ -1422,7 +1422,7 @@ def _register_listener(
14221422
for m in middleware or []:
14231423
if isinstance(m, AsyncMiddleware):
14241424
listener_middleware.append(m)
1425-
elif isinstance(m, Callable) and inspect.iscoroutinefunction(m):
1425+
elif isinstance(m, Callable) and is_coroutine_function(m):
14261426
listener_middleware.append(AsyncCustomMiddleware(app_name=self.name, func=m, base_logger=self._base_logger))
14271427
else:
14281428
raise ValueError(error_unexpected_listener_middleware(type(m)))

slack_bolt/middleware/async_custom_middleware.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import inspect
21
from logging import Logger
32
from typing import Callable, Awaitable, Any, Sequence, Optional
43

@@ -7,7 +6,7 @@
76
from slack_bolt.request.async_request import AsyncBoltRequest
87
from slack_bolt.response import BoltResponse
98
from .async_middleware import AsyncMiddleware
10-
from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable
9+
from slack_bolt.util.utils import get_name_for_callable, get_arg_names_of_callable, is_coroutine_function
1110

1211

1312
class AsyncCustomMiddleware(AsyncMiddleware):
@@ -24,7 +23,7 @@ def __init__(
2423
base_logger: Optional[Logger] = None,
2524
):
2625
self.app_name = app_name
27-
if inspect.iscoroutinefunction(func):
26+
if is_coroutine_function(func):
2827
self.func = func
2928
else:
3029
raise ValueError("Async middleware function must be an async function")

slack_bolt/util/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,9 @@ def get_name_for_callable(func: Callable) -> str:
8888

8989
def get_arg_names_of_callable(func: Callable) -> List[str]:
9090
return inspect.getfullargspec(inspect.unwrap(func)).args
91+
92+
93+
def is_coroutine_function(func: Optional[Any]) -> bool:
94+
return func is not None and (
95+
inspect.iscoroutinefunction(func) or (hasattr(func, "__call__") and inspect.iscoroutinefunction(func.__call__))
96+
)

tests/scenario_tests_async/test_app_using_methods_in_class.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,14 @@ async def test_instance_methods(self):
149149
app.shortcut("test-shortcut")(awesome.instance_method)
150150
await self.run_app_and_verify(app)
151151

152+
@pytest.mark.asyncio
153+
async def test_callable_class(self):
154+
app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret)
155+
instance = CallableClass("Slackbot")
156+
app.use(instance)
157+
app.shortcut("test-shortcut")(instance.event_handler)
158+
await self.run_app_and_verify(app)
159+
152160
@pytest.mark.asyncio
153161
async def test_instance_methods_uncommon_name_1(self):
154162
app = AsyncApp(client=self.web_client, signing_secret=self.signing_secret)
@@ -225,6 +233,18 @@ async def static_method(context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck)
225233
await say(f"Hello <@{context.user_id}>!")
226234

227235

236+
class CallableClass:
237+
def __init__(self, name: str):
238+
self.name = name
239+
240+
async def __call__(self, next: Callable):
241+
await next()
242+
243+
async def event_handler(self, context: AsyncBoltContext, say: AsyncSay, ack: AsyncAck):
244+
await ack()
245+
await say(f"Hello <@{context.user_id}>! My name is {self.name}")
246+
247+
228248
async def top_level_function(invalid_arg, ack, say):
229249
assert invalid_arg is None
230250
await ack()

0 commit comments

Comments
 (0)