|
1 | 1 | from abc import ABCMeta, abstractmethod |
2 | | -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, no_type_check |
| 2 | +from typing import ( |
| 3 | + TYPE_CHECKING, |
| 4 | + Any, |
| 5 | + Dict, |
| 6 | + Optional, |
| 7 | + Type, |
| 8 | + TypeVar, |
| 9 | + Union, |
| 10 | + no_type_check, |
| 11 | +) |
3 | 12 |
|
4 | 13 | import numpy as np |
5 | 14 | from optuna.trial import Trial |
|
16 | 25 | UserIndexArray, |
17 | 26 | ) |
18 | 27 |
|
| 28 | +R = TypeVar("R", bound="BaseRecommender") |
| 29 | + |
19 | 30 |
|
20 | 31 | def _sparse_to_array(U: Any) -> np.ndarray: |
21 | 32 | if sps.issparse(U): |
@@ -78,15 +89,17 @@ def __init__(self, X_train_all: InteractionMatrix, **kwargs: Any) -> None: |
78 | 89 |
|
79 | 90 | @classmethod |
80 | 91 | def from_config( |
81 | | - cls, X_train_all: InteractionMatrix, config: RecommenderConfig |
82 | | - ) -> "BaseRecommender": |
| 92 | + cls: Type[R], |
| 93 | + X_train_all: InteractionMatrix, |
| 94 | + config: RecommenderConfig, |
| 95 | + ) -> R: |
83 | 96 | if not isinstance(config, cls.config_class): |
84 | 97 | raise ValueError( |
85 | 98 | f"Different config has been given. config must be {cls.config_class}" |
86 | 99 | ) |
87 | 100 | return cls(X_train_all, **config.dict()) |
88 | 101 |
|
89 | | - def learn(self) -> "BaseRecommender": |
| 102 | + def learn(self: R) -> R: |
90 | 103 | """Learns and returns itself. |
91 | 104 |
|
92 | 105 | Returns: |
@@ -245,7 +258,7 @@ def get_score_block(self, begin: int, end: int) -> DenseScoreArray: |
245 | 258 | return _sparse_to_array(self.U[begin:end].dot(self._X_csc)) |
246 | 259 |
|
247 | 260 |
|
248 | | -class BaseRecommenderWithUserEmbedding(BaseRecommender): |
| 261 | +class BaseRecommenderWithUserEmbedding: |
249 | 262 | """Defines a recommender with user embedding (e.g., matrix factorization.). |
250 | 263 | These class can be a base CF estimator for CB2CF (with user profile -> user embedding NN). |
251 | 264 | """ |
@@ -276,7 +289,7 @@ def get_score_from_user_embedding( |
276 | 289 | raise NotImplementedError("get_score_from_item_embedding must be implemtented.") |
277 | 290 |
|
278 | 291 |
|
279 | | -class BaseRecommenderWithItemEmbedding(BaseRecommender): |
| 292 | +class BaseRecommenderWithItemEmbedding: |
280 | 293 | """Defines a recommender with item embedding (e.g., matrix factorization.). |
281 | 294 | These class can be a base CF estimator for CB2CF (with item profile -> item embedding NN). |
282 | 295 | """ |
|
0 commit comments