Skip to content

Commit ff4aa66

Browse files
committed
feat: add ClassicalMDS
1 parent 7b6d5ae commit ff4aa66

File tree

3 files changed

+137
-0
lines changed

3 files changed

+137
-0
lines changed

rumale-manifold/lib/rumale/manifold.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
require_relative 'manifold/locally_linear_embedding'
77
require_relative 'manifold/hessian_eigenmaps'
88
require_relative 'manifold/local_tangent_space_alignment'
9+
require_relative 'manifold/classical_mds'
910
require_relative 'manifold/mds'
1011
require_relative 'manifold/tsne'
1112
require_relative 'manifold/version'
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# frozen_string_literal: true
2+
3+
require 'rumale/base/estimator'
4+
require 'rumale/base/transformer'
5+
require 'rumale/utils'
6+
require 'rumale/validation'
7+
require 'rumale/pairwise_metric'
8+
9+
module Rumale
10+
module Manifold
11+
# ClassicalMDS is a class that implements classical multi-dimensional scaling.
12+
#
13+
# @example
14+
# require 'rumale/manifold/classical_mds'
15+
#
16+
# mds = Rumale::Manifold::ClassicalMDS.new(n_components: 2)
17+
# representations = mds.fit_transform(data)
18+
#
19+
class ClassicalMDS < Rumale::Base::Estimator
20+
include Rumale::Base::Transformer
21+
22+
# Return the data in representation space.
23+
# @return [Numo::DFloat] (shape: [n_samples, n_components])
24+
attr_reader :embedding
25+
26+
# Create a new transformer with Classical MDS.
27+
#
28+
# @param n_components [Integer] The number of dimensions on representation space.
29+
# @param metric [String] The metric to calculate the distances in original space.
30+
# If metric is 'euclidean', Euclidean distance is calculated for distance in original space.
31+
# If metric is 'precomputed', the fit and fit_transform methods expect to be given a distance matrix.
32+
def initialize(n_components: 2, metric: 'euclidean')
33+
super()
34+
@params = {
35+
n_components: n_components,
36+
metric: metric
37+
}
38+
end
39+
40+
# Fit the model with given training data.
41+
#
42+
# @overload fit(x) -> ClassicalMDS
43+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
44+
# If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
45+
# @return [ClassicalMDS] The learned transformer itself.
46+
def fit(x, _not_used = nil)
47+
raise 'ClassicalMDS#fit requires Numo::Linalg but that is not loaded' unless enable_linalg?(warning: false)
48+
49+
x = ::Rumale::Validation.check_convert_sample_array(x)
50+
if @params[:metric] == 'precomputed' && x.shape[0] != x.shape[1]
51+
raise ArgumentError, 'Expect the input distance matrix to be square.'
52+
end
53+
54+
n_samples = x.shape[0]
55+
distance_mat = @params[:metric] == 'precomputed' ? x : ::Rumale::PairwiseMetric.euclidean_distance(x)
56+
57+
centering_mat = Numo::DFloat.eye(n_samples) - Numo::DFloat.new(n_samples, n_samples).fill(1.fdiv(n_samples))
58+
kernel_mat = -0.5 * centering_mat.dot(distance_mat * distance_mat).dot(centering_mat)
59+
eig_vals, eig_vecs = Numo::Linalg.eigh(kernel_mat, vals_range: (n_samples - @params[:n_components])...n_samples)
60+
eig_vals = eig_vals.reverse
61+
eig_vecs = eig_vecs.reverse(1)
62+
@embedding = eig_vecs.dot(Numo::NMath.sqrt(eig_vals.abs).diag)
63+
64+
self
65+
end
66+
67+
# Fit the model with training data, and then transform them with the learned model.
68+
#
69+
# @overload fit_transform(x) -> Numo::DFloat
70+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
71+
# If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
72+
# @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
73+
def fit_transform(x, _not_used = nil)
74+
raise 'ClassicalMDS#fit_transform requires Numo::Linalg but that is not loaded' unless enable_linalg?(warning: false)
75+
76+
x = ::Rumale::Validation.check_convert_sample_array(x)
77+
fit(x)
78+
@embedding.dup
79+
end
80+
end
81+
end
82+
end
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# frozen_string_literal: true
2+
3+
require 'spec_helper'
4+
5+
RSpec.describe Rumale::Manifold::ClassicalMDS do
6+
let(:base_samples) { two_clusters_dataset[0] }
7+
let(:samples) { Rumale::KernelApproximation::RBF.new(n_components: 32, random_seed: 1).fit_transform(base_samples) }
8+
let(:n_samples) { samples.shape[0] }
9+
let(:n_features) { samples.shape[1] }
10+
let(:n_components) { 2 }
11+
let(:metric) { 'euclidean' }
12+
let(:mds) { described_class.new(n_components: n_components, metric: metric) }
13+
let(:low_samples) { mds.fit_transform(x) }
14+
15+
context 'when metric is "euclidean"' do
16+
let(:metric) { 'euclidean' }
17+
let(:x) { samples }
18+
19+
it 'maps high-dimensional data into low-dimensional data', :aggregate_failures do
20+
expect(low_samples).to be_a(Numo::DFloat)
21+
expect(low_samples).to be_contiguous
22+
expect(low_samples.ndim).to eq(2)
23+
expect(low_samples.shape[0]).to eq(n_samples)
24+
expect(low_samples.shape[1]).to eq(n_components)
25+
expect(mds.embedding).to be_a(Numo::DFloat)
26+
expect(mds.embedding).to be_contiguous
27+
expect(mds.embedding.ndim).to eq(2)
28+
expect(mds.embedding.shape[0]).to eq(n_samples)
29+
expect(mds.embedding.shape[1]).to eq(n_components)
30+
end
31+
end
32+
33+
context 'when metric is "precomputed"' do
34+
let(:metric) { 'precomputed' }
35+
let(:x) { Rumale::PairwiseMetric.euclidean_distance(samples) }
36+
37+
it 'maps high-dimensional data represented by distance matrix', :aggregate_failures do
38+
expect(low_samples).to be_a(Numo::DFloat)
39+
expect(low_samples).to be_contiguous
40+
expect(low_samples.ndim).to eq(2)
41+
expect(low_samples.shape[0]).to eq(n_samples)
42+
expect(low_samples.shape[1]).to eq(n_components)
43+
expect(mds.embedding).to be_a(Numo::DFloat)
44+
expect(mds.embedding).to be_contiguous
45+
expect(mds.embedding.ndim).to eq(2)
46+
expect(mds.embedding.shape[0]).to eq(n_samples)
47+
expect(mds.embedding.shape[1]).to eq(n_components)
48+
end
49+
50+
it 'raises ArgumentError when given a non-square matrix', :aggregate_failures do
51+
expect { mds.fit(Numo::DFloat.new(5, 3).rand) }.to raise_error(ArgumentError)
52+
end
53+
end
54+
end

0 commit comments

Comments
 (0)