@@ -592,14 +592,122 @@ points, while constrains the sum of distances between dissimilar points:
592592 -with-side-information.pdf> `_. NIPS 2002
593593 .. [2] Adapted from Matlab code http://www.cs.cmu.edu/%7Eepxing/papers/Old_papers/code_Metric_online.tar.gz
594594
595+ .. _learning_on_triplets:
596+
597+ Learning on triplets
598+ ====================
599+
600+ Some metric learning algorithms learn on triplets of samples. In this case,
601+ one should provide the algorithm with `n_samples` triplets of points. The
602+ semantics of each triplet is that the first point should be closer to the
603+ second point than to the third one.
604+
605+ Fitting
606+ -------
607+ Here is an example for fitting on triplets (see :ref:`fit_ws` for more
608+ details on the input data format and how to fit, in the general case of
609+ learning on tuples).
610+
611+ >>> from metric_learn import SCML
612+ >>> triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]],
613+ ...                      [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]])
614+ >>> scml = SCML(random_state=42)
615+ >>> scml.fit(triplets)
616+ SCML(beta=1e-5, B=None, max_iter=100000, verbose=False,
617+      preprocessor=None, random_state=42)
618+
619+ Or alternatively (using a preprocessor):
620+
621+ >>> X = np.array([[1.2, 3.2],
622+ ...               [2.3, 5.5],
623+ ...               [2.1, 0.6],
624+ ...               [4.5, 2.3],
625+ ...               [2.1, 2.3],
626+ ...               [7.3, 3.4]])
627+ >>> triplets_indices = np.array([[0, 1, 2], [3, 4, 5]])
628+ >>> scml = SCML(preprocessor=X, random_state=42)
629+ >>> scml.fit(triplets_indices)
630+ SCML(beta=1e-5, B=None, max_iter=100000, verbose=False,
631+      preprocessor=array([[1.2, 3.2],
632+                          [2.3, 5.5],
633+                          [2.1, 0.6],
634+                          [4.5, 2.3],
635+                          [2.1, 2.3],
636+                          [7.3, 3.4]]),
637+      random_state=42)
638+
639+
640+
641+
642+ Here, we want to learn a metric that, for each of the two
643+ `triplets`, will make the first point closer to the
644+ second point than to the third one.
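
Conceptually, the two input forms above describe the same data: the index form looks each point up in `X`. The following is a minimal NumPy sketch of that equivalence, not the library's internal preprocessing code.

```python
import numpy as np

# Points and index triplets, as in the preprocessor example above.
X = np.array([[1.2, 3.2],
              [2.3, 5.5],
              [2.1, 0.6],
              [4.5, 2.3],
              [2.1, 2.3],
              [7.3, 3.4]])
triplets_indices = np.array([[0, 1, 2], [3, 4, 5]])

# Integer-array indexing turns (n_triplets, 3) indices into
# (n_triplets, 3, n_features) points.
triplets_from_indices = X[triplets_indices]

# The explicit triplets array from the first example.
triplets = np.array([[[1.2, 3.2], [2.3, 5.5], [2.1, 0.6]],
                     [[4.5, 2.3], [2.1, 2.3], [7.3, 3.4]]])

same = np.array_equal(triplets_from_indices, triplets)
```

Here `same` is `True`, and both arrays have shape `(2, 3, 2)`.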
645+
646+ .. _triplets_predicting:
647+
648+ Prediction
649+ ----------
650+
651+ Once a triplets learner is fitted, it can also predict, for a new
652+ triplet, whether the first point is closer to the second point
653+ than to the third one (+1), or not (-1).
654+
655+ >>> triplets_test = np.array(
656+ ...     [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]],
657+ ...      [[6.0, 4.2], [4.3, 1.2], [0.1, 7.8]]])
658+ >>> scml.predict(triplets_test)
659+ array([-1., 1.])
660+
661+ .. _triplets_scoring:
662+
663+ Scoring
664+ -------
665+
666+ Triplet metric learners can also return a `decision_function` for a set of triplets,
667+ which corresponds to the distance between the first and last points minus the distance
668+ between the first two points of the triplet (the higher the value, the more
669+ similar the first point is to the second point compared to the last one). This "score"
670+ can be interpreted as a measure of the likelihood of a +1 prediction for this
671+ triplet.
672+
673+ >>> scml.decision_function(triplets_test)
674+ array([-1.75700306, 4.98982131])
675+
676+ In the above example, for the first triplet in `triplets_test`, the first
677+ point is predicted less similar to the second point than to the last point
678+ (the first two points are further apart in the transformed space).
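
To make the sign convention concrete, here is a minimal NumPy sketch. It assumes the score is the distance between the first and last points minus the distance between the first two points, so that a positive score gives a +1 prediction. The helper `triplet_scores` and its optional matrix `L` are hypothetical and not part of metric-learn. With no `L`, plain Euclidean distance is used, so the values differ from those of the fitted `scml` above.

```python
import numpy as np

def triplet_scores(triplets, L=None):
    """Hypothetical helper: d(first, last) - d(first, second) per triplet."""
    triplets = np.asarray(triplets, dtype=float)
    if L is not None:
        # Map every point into the transformed space (assumed linear map L).
        triplets = triplets @ L.T
    d_second = np.linalg.norm(triplets[:, 0] - triplets[:, 1], axis=1)
    d_last = np.linalg.norm(triplets[:, 0] - triplets[:, 2], axis=1)
    return d_last - d_second

triplets_test = np.array(
    [[[5.6, 5.3], [2.2, 2.1], [1.2, 3.4]],
     [[6.0, 4.2], [4.3, 1.2], [0.1, 7.8]]])

scores = triplet_scores(triplets_test)
# A positive score means the first point is closer to the second one: +1.
predictions = np.where(scores > 0, 1.0, -1.0)
```

Under the plain Euclidean distance both test triplets get a positive score. A learned metric can reorder them, as the `scml` output above shows for the first triplet.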
679+
680+ Unlike pairs learners, triplets learners do not accept a `y` argument when fitting: we
681+ assume that the ordering of points within triplets is such that the training triplets
682+ are all positive. Therefore, it is not possible to use scikit-learn scoring functions
683+ (such as 'f1_score') for triplets learners.
684+
685+ However, triplets learners do have a default scoring function, which
686+ returns the accuracy on a given test set, i.e. the proportion
687+ of triplets that have the right predicted ordering.
688+
689+ >>> scml.score(triplets_test)
690+ 0.5
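
As a sketch of what this accuracy means: since every triplet is assumed to be correctly ordered, it is the share of +1 predictions. The helper `triplet_accuracy` below is hypothetical, not a metric-learn function.

```python
import numpy as np

def triplet_accuracy(predictions):
    """Hypothetical helper: share of triplets predicted in the right order (+1)."""
    predictions = np.asarray(predictions)
    return float(np.mean(predictions == 1))

# The predictions returned by scml.predict(triplets_test) above.
accuracy = triplet_accuracy([-1.0, 1.0])
```

One of the two test triplets is predicted as correctly ordered, which gives the 0.5 shown above.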
691+
692+ .. note::
693+    See :ref:`fit_ws` for more details on metric learner functions that are
694+    not specific to learning on triplets, like `transform`, `score_pairs`,
695+    `get_metric` and `get_mahalanobis_matrix`.
696+
697+
698+
699+
700+ Algorithms
701+ ----------
702+
595703
596704.. _learning_on_quadruplets:
597705
598706Learning on quadruplets
599707=======================
600708
601709Some metric learning algorithms learn on quadruplets of samples. In this case,
602- one should provide the algorithm with `n_samples ` quadruplets of points. Th
710+ one should provide the algorithm with `n_samples` quadruplets of points. The
603711semantic of each quadruplet is that the first two points should be closer
604712together than the last two points.
605713
@@ -666,14 +774,12 @@ array([-1., 1.])
666774Scoring
667775-------
668776
669- Quadruplet metric learners can also
670- return a `decision_function ` for a set of pairs. This is basically the "score"
671- which sign will be taken to find the prediction for the pair, which
672- corresponds to the difference between the distance between the two last points,
673- and the distance between the two last points of the quadruplet (higher
674- score means the two last points are more likely to be more dissimilar than
675- the two first points (i.e. more likely to have a +1 prediction since it's
676- the right ordering)).
777+ Quadruplet metric learners can also return a `decision_function` for a set of
778+ quadruplets, which corresponds to the distance between the second pair of points minus
779+ the distance between the first pair of points of the quadruplet (the higher the value,
780+ the more similar the first pair is compared to the last pair).
781+ This "score" can be interpreted as a measure of the likelihood of a +1 prediction
782+ for this quadruplet.
677783
678784>>> lsml.decision_function(quadruplets_test)
679785array([-1.75700306, 4.98982131])
@@ -682,17 +788,10 @@ In the above example, for the first quadruplet in `quadruplets_test`, the
682788two first points are predicted less similar than the two last points (they
683789are further away in the transformed space).
684790
685- Unlike for pairs learners, quadruplets learners don't allow to give a `y `
686- when fitting, which does not allow to use scikit-learn scoring functions
687- like:
688-
689- >>> from sklearn.model_selection import cross_val_score
690- >>> cross_val_score(lsml, quadruplets, scoring = ' f1_score' ) # this won't work
691-
692- (This is actually intentional, for more details
693- about that, see
694- `this comment <https://github.com/scikit-learn-contrib/metric-learn/pull/168#pullrequestreview-203730742 >`_
695- on github.)
791+ Like triplet learners, quadruplets learners do not accept a `y` argument when fitting: we
792+ assume that the ordering of points within quadruplets is such that the training quadruplets
793+ are all positive. Therefore, it is not possible to use scikit-learn scoring functions
794+ (such as 'f1_score') for quadruplets learners.
696795
697796However, quadruplets learners do have a default scoring function, which will
698797basically return the accuracy score on a given test set, i.e. the proportion