Skip to content

Commit 7220d28

Browse files
authored
Features/kshape (#26)
* Adding khiva and updating copyright year to 2019 * Adding Clustering module. * Change Test clustering
1 parent 7c9c454 commit 7220d28

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

src/main/java/io/shapelets/khiva/Clustering.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public class Clustering extends Library{
2828
* @param k The number of centroids.
2929
* @param tolerance The maximum error tolerance.
3030
* @param maxIterations The maximum number of iterations.
31+
*
3132
* @return An Array of arrays with the resulting centroids and labels.
3233
*/
3334
public static Array[] kMeans(Array tss, int k, float tolerance, int maxIterations) {
@@ -47,6 +48,7 @@ public static Array[] kMeans(Array tss, int k, float tolerance, int maxIteration
4748
* @param k The number of centroids.
4849
* @param tolerance The maximum error tolerance.
4950
* @param maxIterations The maximum number of iterations.
51+
*
5052
* @return An Array of arrays with the resulting centroids and labels.
5153
*/
5254
public static Array[] kShape(Array tss, int k, float tolerance, int maxIterations) {

src/test/java/io/shapelets/khiva/ClusteringTest.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,25 @@ public void testKMeans() throws Exception {
2626
8.0f, 5.0f, 3.0f, 1.0f, 15.0f, 10.0f, 5.0f, 0.0f, 7.0f, -7.0f, 1.0f, -1.0f};
2727
long[] dims = {4, 6, 1, 1};
2828

29-
float[] expected = {0.0f, 0.1667f, 0.3333f, 0.5f,
30-
1.5f, -1.5f, 0.8333f, -0.8333f,
31-
4.8333f, 3.6667f, 2.6667f, 1.6667f};
32-
3329
float tolerance = 1e-10f;
3430
int maxIterations = 100;
3531
int k = 3;
32+
3633
try (
3734
Array a = new Array(tss, dims)
3835
){
3936
Array[] result = Clustering.kMeans(a, k, tolerance, maxIterations);
37+
float[] expected = {0.0f, 0.1667f, 0.3333f, 0.5f, 1.5f, -1.5f, 0.8333f, -0.8333f, 4.8333f, 3.6667f,
38+
2.6667f, 1.6667f};
4039
float[] centroids = result[0].getData();
4140

4241
for (int i = 0; i < 4; i++){
4342
Assert.assertEquals(expected[i] + expected[i + 4] + expected[i + 8],
4443
centroids[i] + centroids[i + 4] + centroids[i + 8], 1e-4f);
4544
}
45+
46+
result[0].close();
47+
result[1].close();
4648
}
4749
}
4850

@@ -53,24 +55,25 @@ public void testKShape() throws Exception {
5355
-6.0f, -1.0f, 2.0f, 9.0f, -5.0f, -5.0f, -6.0f, 7.0f, 9.0f, 9.0f, 0.0f};
5456
long[] dims = {7, 5, 1, 1};
5557

56-
float[] expected_c = {-0.5234f, 0.1560f, -0.3627f, -1.2764f, -0.7781f, 0.9135f, 1.8711f,
57-
-0.7825f, 1.5990f, 0.1701f, 0.4082f, 0.8845f, -1.4969f, -0.7825f,
58-
-0.6278f, 1.3812f, -2.0090f, 0.5022f, 0.6278f, -0.0000f, 0.1256f};
59-
60-
6158
float tolerance = 1e-10f;
6259
int maxIterations = 100;
6360
int k = 3;
61+
6462
try (
6563
Array a = new Array(tss, dims)
6664
){
6765
Array[] result = Clustering.kShape(a, k, tolerance, maxIterations);
66+
float[] expected_c = {-0.5234f, 0.1560f, -0.3627f, -1.2764f, -0.7781f, 0.9135f, 1.8711f,
67+
-0.7825f, 1.5990f, 0.1701f, 0.4082f, 0.8845f, -1.4969f, -0.7825f,
68+
-0.6278f, 1.3812f, -2.0090f, 0.5022f, 0.6278f, -0.0000f, 0.1256f};
6869
float[] centroids = result[0].getData();
6970

7071
for (int i = 0; i < centroids.length; i++){
7172
Assert.assertEquals(expected_c[i], centroids[i],1e-4f);
7273
}
73-
}
7474

75+
result[0].close();
76+
result[1].close();
77+
}
7578
}
7679
}

0 commit comments

Comments
 (0)