
Commit 59ab8a1

docs: correct loss function naming and MF constraint formula (#5)
- Correct MF constraint formulas in tutorial page and example notebook
- Fix incorrect naming for MSE and squared hinge losses
- Add new loss functions to tutorial
1 parent b2cf099 commit 59ab8a1

File tree

6 files changed: +74 −132 lines


doc/source/examples/MF.ipynb

Lines changed: 39 additions & 45 deletions
Large diffs are not rendered by default.

doc/source/tutorials/ReHLine_MF.rst

Lines changed: 8 additions & 71 deletions
@@ -10,9 +10,9 @@ Considering a User-Item-Rating triplet dataset :math:`(u, i, r_{ui})` derived fr
 
 .. math::
     \min_{\substack{
-    \mathbf{P} \in \mathbb{R}^{n \times r}\
+    \mathbf{P} \in \mathbb{R}^{n \times k}\
     \pmb{\alpha} \in \mathbb{R}^n \\
-    \mathbf{Q} \in \mathbb{R}^{m \times r}\
+    \mathbf{Q} \in \mathbb{R}^{m \times k}\
     \pmb{\beta} \in \mathbb{R}^m
     }}
     \left[
@@ -26,31 +26,24 @@ Considering a User-Item-Rating triplet dataset :math:`(u, i, r_{ui})` derived fr
 
 .. math::
     \ \text{ s.t. } \
-    \mathbf{A} \begin{bmatrix}
-    \pmb{\alpha} & \mathbf{P}
-    \end{bmatrix}^T +
-    \mathbf{b}\mathbf{1}_{n}^T \geq \mathbf{0}
-    \ \text{ and } \
-    \mathbf{A} \begin{bmatrix}
-    \pmb{\beta} & \mathbf{Q}
-    \end{bmatrix}^T +
-    \mathbf{b}\mathbf{1}_{m}^T \geq \mathbf{0}
-
+    \mathbf{A} \begin{pmatrix} \alpha_u \\ \mathbf{p}_u \end{pmatrix} + \mathbf{b} \geq \mathbf{0},\ \forall u \in [n]
+    \quad \text{and} \quad
+    \mathbf{A} \begin{pmatrix} \beta_i \\ \mathbf{q}_i \end{pmatrix} + \mathbf{b} \geq \mathbf{0},\ \forall i \in [m]
 
 where
 
 - :math:`\text{PLQ}(\cdot , \cdot)`
   is a convex piecewise linear-quadratic loss function. You can find built-in loss functions in the `Loss <./loss.rst>`_ section.
 
-- :math:`\mathbf{A}` is a :math:`K \times r` matrix and :math:`\mathbf{b}` is a :math:`K`-dimensional vector
-  representing :math:`K` linear constraints. See `Constraints <./constraint.rst>`_ for more details.
+- :math:`\mathbf{A}` is a :math:`d \times (k+1)` matrix and :math:`\mathbf{b}` is a :math:`d`-dimensional vector
+  representing :math:`d` linear constraints. See `Constraints <./constraint.rst>`_ for more details.
 
 - :math:`\Omega`
   is a user-item collection that records all training data
 
 - :math:`n` is number of users, :math:`m` is number of items
 
-- :math:`r` is length of latent factors (rank of MF)
+- :math:`k` is length of latent factors (rank of MF)
 
 - :math:`C` is regularization parameter, :math:`\rho` balances regularization strength between user and item
 
@@ -214,69 +207,13 @@ The model complexity is mainly controlled by :code:`C` and :code:`rank`.
         mae = mean_absolute_error(y_test, y_pred)
         print(f"rank={rank_value}: MAE = {mae:.3f}")
 
-Convergence Tracking
-^^^^^^^^^^^^^^^^^^^^
-
-You can customize the optimization process by setting your preferred iteration counts and tolerance levels.
-Training progress can be monitored either by enabling :code:`verbose` output during fitting or by examining the :code:`history` attribute after fitting.
-
-.. code-block:: python
-
-    clf = plqMF_Ridge(
-        C=0.001,
-        rank=6,
-        loss={'name': 'mae'},
-        n_users=user_num,
-        n_items=item_num,
-        max_iter_CD=15,  ## Outer CD iterations
-        tol_CD=1e-5,     ## Outer CD tolerance
-        max_iter=8000,   ## ReHLine solver iterations
-        tol=1e-2,        ## ReHLine solver tolerance
-        verbose=1,       ## Enable progress output
-    )
-    clf.fit(X_train, y_train)
-
-    print(clf.history)   ## Check training trace of cumulative loss and objection value
-
-Different Gaussian initial conditions can be manually set by :code:`init_mean` and :code:`init_sd`:
-
-.. code-block:: python
-
-    # Initialize model with positive shifted normal
-    clf = plqMF_Ridge(
-        C=0.001,
-        rank=6,
-        loss={'name': 'mae'},
-        n_users=user_num,
-        n_items=item_num,
-        init_mean=1.0,   ## Manually set mean of normal distribution
-        init_sd=0.5      ## Manually set sd of normal distribution
-    )
-
 Practical Guidance
 ^^^^^^^^^^^^^^^^^^
 
 - The first column of :code:`X` corresponds to **users**, and the second column corresponds to **items**. Please ensure this aligns with your :code:`n_users` and :code:`n_items` parameters.
 - The default penalty strength is relatively weak; it is recommended to set a relatively small :code:`C` value initially.
 - When using larger :code:`C` values, consider increasing :code:`max_iter` to avoid ConvergenceWarning.
 
-
-Regularization Conversion
--------------------------
-The regularization in this algorithm is tuned via :math:`C` and :math:`\rho`. For users who prefer to set the penalty strength directly, you may achieve conversion through the following formula:
-
-.. math::
-    \lambda_{\text{user}} = \frac{\rho}{Cn}
-    \quad\text{and}\quad
-    \lambda_{\text{item}} = \frac{(1 - \rho)}{Cm}
-
-
-.. math::
-    C = \frac{1}{m \cdot \lambda_{\text{item}} + n \cdot \lambda_{\text{user}}}
-    \quad\text{and}\quad
-    \rho = \frac{1}{\frac{m \cdot \lambda_{\text{item}}}{n \cdot \lambda_{\text{user}}} + 1}
-
-
 Example
 -------
 
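Note: the corrected constraint acts per user (and per item) on the stacked vector of the bias and the latent factors, so :math:`\mathbf{A}` must have :math:`k+1` columns. A minimal numpy sketch of the shape check; the identity constraint matrix and the sample values below are illustrative, not taken from the library:

```python
import numpy as np

k = 6                                    # rank of the MF model (illustrative)
d = k + 1                                # one non-negativity constraint per coordinate (illustrative)

# Illustrative constraint: alpha_u >= 0 and p_u >= 0, i.e. A = I_{k+1}, b = 0
A = np.eye(d)                            # shape (d, k+1)
b = np.zeros(d)

rng = np.random.default_rng(0)
alpha_u = 0.3                            # user bias (placeholder value)
p_u = np.abs(rng.normal(size=k))         # user latent factors (placeholder values)

v = np.concatenate(([alpha_u], p_u))     # stacked (alpha_u, p_u), length k+1
print(np.all(A @ v + b >= 0))            # corrected constraint A (alpha_u, p_u)^T + b >= 0
```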

doc/source/tutorials/loss.rst

Lines changed: 11 additions & 0 deletions
@@ -34,6 +34,9 @@ Classification loss
     - | ``name``: 'sSVM' / 'smooth SVM' / 'smooth hinge'
     - | ``loss={'name': 'sSVM'}``
 
+  * - **Squared SVM**
+    - | ``name``: 'squared SVM' / 'squared svm' / 'squared hinge'
+    - | ``loss={'name': 'squared SVM'}``
 
 Regression loss
 ~~~~~~~~~~~~~~~
@@ -61,6 +64,14 @@ Regression loss
       | ``epsilon`` (*float*): 0.1
     - | ``loss={'name': 'svr', 'epsilon': 0.1}``
 
+  * - **MAE**
+    - | ``name``: 'MAE' / 'mae' / 'mean absolute error'
+    - | ``loss={'name': 'mae'}``
+
+  * - **MSE**
+    - | ``name``: 'MSE' / 'mse' / 'mean squared error'
+    - | ``loss={'name': 'mse'}``
+
 Related Examples
 ----------------
 
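Note: the newly documented losses are selected the same way as the existing ones, by passing a ``loss`` dict to the estimator. A short sketch mirroring the ``plqMF_Ridge`` call shown in the MF tutorial; the import path is assumed from the package layout, and the dataset sizes are placeholders:

```python
from rehline import plqMF_Ridge   # assumed import path; the class lives in rehline/_mf_class.py

user_num, item_num = 1000, 500    # placeholder dataset sizes

# 'mse' is one of the newly documented aliases; 'MSE' / 'mean squared error' also work
clf = plqMF_Ridge(
    C=0.001,
    rank=6,
    loss={'name': 'mse'},
    n_users=user_num,
    n_items=item_num,
)
```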

rehline/_base.py

Lines changed: 6 additions & 6 deletions
@@ -278,8 +278,8 @@ def _make_loss_rehline_param(loss, X, y):
     """The `_make_loss_rehline_param` function generates parameters for the ReHLine solver, based on the provided training data.
 
     The function supports various loss functions, including:
-        - 'hinge'
-        - 'svm' or 'SVM'
+        - 'hinge' or 'svm' or 'SVM'
+        - 'squared hinge' or 'squared svm' or 'squared SVM'
         - 'mae' or 'MAE' or 'mean absolute error'
         - 'check' or 'quantile' or 'quantile regression' or 'QR'
         - 'sSVM' or 'smooth SVM' or 'smooth hinge'
@@ -393,16 +393,16 @@ def _make_loss_rehline_param(loss, X, y):
         U = np.array([[1.0] * n, [-1.0] * n])
         V = np.array([-y , y])
 
-    elif (loss['name'] == 'SVM square') \
-            or (loss['name'] == 'svm square') \
-            or (loss['name'] == 'hinge square'):
+    elif (loss['name'] == 'squared SVM') \
+            or (loss['name'] == 'squared svm') \
+            or (loss['name'] == 'squared hinge'):
         Tau = np.inf * np.ones((1, n))
         S = - np.sqrt(2) * y.reshape(1,-1)
         T = np.sqrt(2) * np.ones((1, n))
 
     elif (loss['name'] == 'MSE') \
            or (loss['name'] == 'mse') \
-           or (loss['name'] == 'mean square error'):
+           or (loss['name'] == 'mean squared error'):
         Tau = np.inf * np.ones((2, n))
         S = np.array([[np.sqrt(2)] * n, [-np.sqrt(2)] * n])
         T = np.array([-np.sqrt(2) * y , np.sqrt(2) * y])
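Note: for intuition on why these S/T/Tau rows encode the renamed losses, assume the ReHU definition from the ReHLine paper, ReHU_tau(z) = 0 for z <= 0, z^2 / 2 for 0 < z <= tau, tau * (z - tau / 2) for z > tau, with each ReHU applied to s * f(x) + t (mirroring the ReLU terms u * f(x) + v above). Under that assumption the parameters reproduce the squared hinge and MSE exactly; a quick numpy check, not part of the library:

```python
import numpy as np

def rehu(z, tau=np.inf):
    # ReHU_tau(z): 0 for z <= 0, z**2 / 2 for 0 < z <= tau, tau * (z - tau / 2) for z > tau
    z = np.maximum(z, 0.0)
    return np.where(z <= tau, 0.5 * z ** 2, tau * (z - 0.5 * tau))

rng = np.random.default_rng(0)
f = rng.normal(size=5)                     # decision values f(x_i)
y_cls = rng.choice([-1.0, 1.0], size=5)    # +/-1 labels for the squared hinge
y_reg = rng.normal(size=5)                 # real-valued targets for MSE

# Squared hinge: one ReHU row with S = -sqrt(2)*y, T = sqrt(2), Tau = inf
lhs = rehu(-np.sqrt(2) * y_cls * f + np.sqrt(2))
assert np.allclose(lhs, np.maximum(0, 1 - y_cls * f) ** 2)

# MSE: two ReHU rows with S = [sqrt(2), -sqrt(2)], T = [-sqrt(2)*y, sqrt(2)*y], Tau = inf
lhs = rehu(np.sqrt(2) * f - np.sqrt(2) * y_reg) + rehu(-np.sqrt(2) * f + np.sqrt(2) * y_reg)
assert np.allclose(lhs, (f - y_reg) ** 2)
```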

rehline/_mf_class.py

Lines changed: 7 additions & 7 deletions
@@ -45,8 +45,8 @@ class plqMF_Ridge(_BaseReHLine, BaseEstimator):
         The function supports various loss functions, including:
             - 'hinge', 'svm' or 'SVM'
             - 'MAE' or 'mae' or 'mean absolute error'
-            - 'hinge square' or 'svm square' or 'SVM square'
-            - 'MSE' or 'mse' or 'mean square error'
+            - 'squared hinge' or 'squared svm' or 'squared SVM'
+            - 'MSE' or 'mse' or 'mean squared error'
 
         The following constraint types are supported:
             * 'nonnegative' or '>=0': A non-negativity constraint.
@@ -454,21 +454,21 @@ def obj(self, X, y, loss):
 
         elif (loss['name'] == 'MSE') \
                 or (loss['name'] == 'mse') \
-                or (loss['name'] == 'mean square error'):
+                or (loss['name'] == 'mean squared error'):
             loss_term = np.sum( (self.decision_function(X) - y) ** 2 )
 
         elif (loss['name'] == 'hinge') \
                 or (loss['name'] == 'svm') \
                 or (loss['name'] == 'SVM'):
             loss_term = np.sum( np.maximum(0, 1 - y * self.decision_function(X)) )
 
-        elif (loss['name'] == 'hinge square') \
-                or (loss['name'] == 'svm square') \
-                or (loss['name'] == 'SVM square'):
+        elif (loss['name'] == 'squared hinge') \
+                or (loss['name'] == 'squared svm') \
+                or (loss['name'] == 'squared SVM'):
             loss_term = np.sum( np.maximum(0, 1 - y * self.decision_function(X)) ** 2 )
 
         else:
             raise ValueError(f"Unsupported loss function: {loss['name']}. "
-                             f"Supported losses are: 'mae', 'mse', 'hinge', 'hinge square'")
+                             f"Supported losses are: 'mae', 'mse', 'hinge', 'squared hinge'")
 
         return loss_term, self.C * loss_term + penalty
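Note: the name-based dispatch above can be mirrored outside the class for sanity-checking objective values. A hypothetical standalone helper, not part of rehline; the MAE branch is inferred from the supported-loss list rather than shown in this hunk:

```python
import numpy as np

def plq_loss_term(name, scores, y):
    """Hypothetical standalone mirror of the loss branch in plqMF_Ridge.obj (illustrative only)."""
    if name in ('MAE', 'mae', 'mean absolute error'):
        return np.sum(np.abs(scores - y))            # inferred from the supported-loss list
    if name in ('MSE', 'mse', 'mean squared error'):
        return np.sum((scores - y) ** 2)
    if name in ('hinge', 'svm', 'SVM'):
        return np.sum(np.maximum(0, 1 - y * scores))
    if name in ('squared hinge', 'squared svm', 'squared SVM'):
        return np.sum(np.maximum(0, 1 - y * scores) ** 2)
    raise ValueError(f"Unsupported loss function: {name}. "
                     f"Supported losses are: 'mae', 'mse', 'hinge', 'squared hinge'")

# Example: the renamed 'squared hinge' on a few +/-1 labels
scores = np.array([0.8, -0.2, 1.5])
labels = np.array([1.0, 1.0, -1.0])
print(plq_loss_term('squared hinge', scores, labels))
```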

tests/_test_mf.py

Lines changed: 3 additions & 3 deletions
@@ -178,9 +178,9 @@ def evaluate_single_params(params):
     print("="*50)
 
 
-## Choose Hinge Square Loss
+## Choose Squared Hinge Loss
 fixed_params = {
-    'loss': {'name': 'svm square'},
+    'loss': {'name': 'squared svm'},
     'n_users': user_num,
     'n_items': item_num,
     'max_iter': 100000,
@@ -191,7 +191,7 @@ def evaluate_single_params(params):
 
 best_params, best_score, best_acc, all_results = parallel_grid_search(plqMF_Ridge, param_grid, fixed_params, X_train, y_train_bin, X_test, y_test_bin, n_jobs)
 print("\n" + "="*50)
-print("BEST RESULTS(Using Hinge Square Loss)")
+print("BEST RESULTS(Using Squared Hinge Loss)")
 print("="*50)
 print("Optimal Parameters:")
 for param, value in best_params.items():
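Note: the squared hinge loss expects +/-1 labels, so the ratings are binarized into y_train_bin / y_test_bin before the grid search. The actual construction is not shown in this hunk; a plausible sketch, where the threshold and the placeholder ratings are assumptions:

```python
import numpy as np

# Hypothetical binarization: ratings >= 4 treated as positive, the rest as negative.
# The threshold and the placeholder arrays below are assumptions, not from this diff.
y_train = np.array([5, 3, 4, 1, 2])   # placeholder ratings
y_test = np.array([4, 2, 5, 3, 1])    # placeholder ratings

y_train_bin = np.where(y_train >= 4, 1.0, -1.0)
y_test_bin = np.where(y_test >= 4, 1.0, -1.0)
```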
