55from __future__ import annotations
66
77import inspect
8- from typing import TYPE_CHECKING , Union
8+ from typing import TYPE_CHECKING , Union , overload
99
1010from .._types import NOT_GIVEN , NotGiven
1111from ._models import ModelVersionIdentifier
@@ -60,7 +60,15 @@ def _parse_model_args(
6060 "Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
6161 )
6262
63- return model_owner , model_name
63+ return model_owner , model_name # type: ignore[return-value]
64+
65+
66+ @overload
67+ def patch_models_resource (models_resource : "ModelsResource" ) -> "ModelsResource" : ...
68+
69+
70+ @overload
71+ def patch_models_resource (models_resource : "AsyncModelsResource" ) -> "AsyncModelsResource" : ...
6472
6573
6674def patch_models_resource (
@@ -72,7 +80,7 @@ def patch_models_resource(
7280
7381 if is_async :
7482
75- async def get_wrapper (
83+ async def async_get_wrapper (
7684 model_or_owner : str | NotGiven = NOT_GIVEN ,
7785 * ,
7886 model_owner : str | NotGiven = NOT_GIVEN ,
@@ -83,17 +91,19 @@ async def get_wrapper(
8391 timeout : "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN ,
8492 ) -> "ModelGetResponse" :
8593 owner , name = _parse_model_args (model_or_owner , model_owner , model_name )
86- return await original_get (
94+ return await models_resource . _original_get ( # type: ignore[misc,no-any-return,attr-defined]
8795 model_owner = owner ,
8896 model_name = name ,
8997 extra_headers = extra_headers ,
9098 extra_query = extra_query ,
9199 extra_body = extra_body ,
92100 timeout = timeout ,
93101 )
102+
103+ wrapper = async_get_wrapper
94104 else :
95105
96- def get_wrapper (
106+ def sync_get_wrapper (
97107 model_or_owner : str | NotGiven = NOT_GIVEN ,
98108 * ,
99109 model_owner : str | NotGiven = NOT_GIVEN ,
@@ -104,7 +114,7 @@ def get_wrapper(
104114 timeout : "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN ,
105115 ) -> "ModelGetResponse" :
106116 owner , name = _parse_model_args (model_or_owner , model_owner , model_name )
107- return original_get (
117+ return models_resource . _original_get ( # type: ignore[misc,return-value,attr-defined]
108118 model_owner = owner ,
109119 model_name = name ,
110120 extra_headers = extra_headers ,
@@ -113,7 +123,9 @@ def get_wrapper(
113123 timeout = timeout ,
114124 )
115125
126+ wrapper = sync_get_wrapper # type: ignore[assignment]
127+
116128 # Store original method for tests and replace with wrapper
117- models_resource ._original_get = original_get
118- models_resource .get = get_wrapper
129+ models_resource ._original_get = original_get # type: ignore[attr-defined,union-attr]
130+ models_resource .get = wrapper # type: ignore[method-assign]
119131 return models_resource
0 commit comments