Skip to content
Merged
11 changes: 6 additions & 5 deletions plugin/code_actions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from ..protocol import CodeAction
from ..protocol import CodeActionKind
from ..protocol import CodeActionParams
from ..protocol import Command
from ..protocol import Diagnostic
from .core.promise import Promise
Expand Down Expand Up @@ -82,7 +83,7 @@ def request_for_region_async(
self.refactor_actions_cache.clear()
self.source_actions_cache.clear()

def request_factory(sb: SessionBufferProtocol) -> Request | None:
def request_factory(sb: SessionBufferProtocol) -> Request[CodeActionParams, list[CodeActionOrCommand] | None]:
diagnostics: list[Diagnostic] = []
for diag_sb, diags in session_buffer_diagnostics:
if diag_sb == sb:
Expand Down Expand Up @@ -125,15 +126,15 @@ def response_filter(sb: SessionBufferProtocol, actions: list[CodeActionOrCommand
def _collect_code_actions_async(
self,
listener: AbstractViewListener,
request_factory: Callable[[SessionBufferProtocol], Request | None],
request_factory: Callable[[SessionBufferProtocol], Request[CodeActionParams, list[CodeActionOrCommand] | None] | None], # noqa: E501
response_filter: Callable[[SessionBufferProtocol, list[CodeActionOrCommand]], list[CodeActionOrCommand]],
) -> Promise[list[CodeActionsByConfigName]]:

def on_response(
sb: SessionBufferProtocol, response: Error | list[CodeActionOrCommand] | None
) -> CodeActionsByConfigName:
actions = []
if response and not isinstance(response, Error) and response_filter:
if response and not isinstance(response, Error):
actions = response_filter(sb, response)
return (sb.session.config.name, actions)

Expand All @@ -145,7 +146,7 @@ def on_response(
listener.purge_changes_async()
sb.do_document_diagnostic_async(listener.view, listener.view.change_count())
response_handler = partial(on_response, sb)
task: Promise[list[CodeActionOrCommand] | None | Error] = session.send_request_task(request)
task = session.send_request_task(request)
tasks.append(task.then(response_handler))
# Return only results for non-empty lists.
return Promise.all(tasks) \
Expand Down Expand Up @@ -259,7 +260,7 @@ def _handle_response_async(
if self._cancelled:
return
view = self._task_runner.view
tasks: list[Promise] = []
tasks: list[Promise[None]] = []
config_name, code_actions = response
session = self._task_runner.session_by_name(config_name, 'codeActionProvider')
if session and code_actions:
Expand Down
6 changes: 3 additions & 3 deletions plugin/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
import webbrowser

SessionName: TypeAlias = str
CompletionResponse: TypeAlias = Union[List[CompletionItem], CompletionList, None]
ResolvedCompletions: TypeAlias = Tuple[Union[CompletionResponse, Error], 'weakref.ref[Session]']
CompletionResponse: TypeAlias = Union[List[CompletionItem], CompletionList, None, Error]
ResolvedCompletions: TypeAlias = Tuple[CompletionResponse, 'weakref.ref[Session]']
CompletionsStore: TypeAlias = Tuple[List[CompletionItem], CompletionItemDefaults]


Expand Down Expand Up @@ -203,7 +203,7 @@ def _create_completion_request_async(self, session: Session) -> Promise[Resolved
return promise.then(lambda response: self._on_completion_response_async(response, request_id, weak_session))

def _on_completion_response_async(
self, response: CompletionResponse | Error, request_id: int, weak_session: weakref.ref[Session]
self, response: CompletionResponse, request_id: int, weak_session: weakref.ref[Session]
) -> ResolvedCompletions:
self._pending_completion_requests.pop(request_id, None)
return (response, weak_session)
Expand Down
2 changes: 1 addition & 1 deletion plugin/core/active_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ActiveRequest:
Holds state per request.
"""

def __init__(self, sv: SessionViewProtocol, request_id: int, request: Request) -> None:
def __init__(self, sv: SessionViewProtocol, request_id: int, request: Request[Any, Any]) -> None:
# sv is the parent object; there is no need to keep it alive explicitly.
self.weaksv = ref(sv)
self.request_id = request_id
Expand Down
Loading