Skip to content

Commit fcc51d5

Browse files
committed
dynaml-core: Additions to tf utilities
- Added gaussian and min-max scalers for TF data sets - Added `dtfpipe`, a workflows/pipes library for working with TF data
1 parent 5185f7c commit fcc51d5

File tree

6 files changed

+170
-3
lines changed

6 files changed

+170
-3
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ packageSummary := "Scala Library/REPL for Machine Learning Research"
99

1010
packageDescription := "DynaML is a Scala environment for conducting research and education in Machine Learning. DynaML comes packaged with a powerful library of classes for various predictive models and a Scala REPL where one can not only build custom models but also play around with data work-flows. It can also be used as an educational/research tool for data analysis."
1111

12-
val mainVersion = "v1.5.2"
12+
val mainVersion = "v1.5.3-beta.1"
1313

1414
val dataDirectory = settingKey[File]("The directory holding the data files for running example scripts")
1515

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/package.scala

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ package io.github.mandar2812.dynaml
2121
import java.nio.ByteBuffer
2222

2323
import io.github.mandar2812.dynaml.probability._
24+
import io.github.mandar2812.dynaml.pipes._
25+
import io.github.mandar2812.dynaml.tensorflow.utils._
2426
import org.platanios.tensorflow.api._
2527
import org.platanios.tensorflow.api.core.Shape
2628
import org.platanios.tensorflow.api.ops.NN.SamePadding
@@ -287,8 +289,11 @@ package object tensorflow {
287289
}
288290

289291
/**
292+
* <h4>DynaML Neural Net Building Blocks</h4>
290293
*
291-
*
294+
* The [[dtflearn]] object contains components
295+
* that can be used to create custom neural architectures,
296+
* from basic building blocks.
292297
* */
293298
object dtflearn {
294299

@@ -397,4 +402,62 @@ package object tensorflow {
397402

398403
}
399404

405+
/**
406+
* <h4>DynaML Tensorflow Pipes</h4>
407+
*
408+
* The [[dtfpipe]] object contains workflows/pipelines to simplify working
409+
* with tensorflow data sets and models.
410+
* */
411+
object dtfpipe {
412+
413+
val gaussian_standardization: DataPipe2[Tensor, Tensor, ((Tensor, Tensor), (GaussianScalerTF, GaussianScalerTF))] =
414+
DataPipe2((features: Tensor, labels: Tensor) => {
415+
416+
val (features_mean, labels_mean) = (features.mean(axes = 0), labels.mean(axes = 0))
417+
418+
val n_data = features.shape(0).scalar.asInstanceOf[Int].toDouble
419+
420+
val (features_sd, labels_sd) = (
421+
features.subtract(features_mean).square.mean(axes = 0).multiply(n_data/(n_data - 1d)).sqrt,
422+
labels.subtract(labels_mean).square.mean(axes = 0).multiply(n_data/(n_data - 1d)).sqrt
423+
)
424+
425+
val (features_scaler, labels_scaler) = (
426+
GaussianScalerTF(features_mean, features_sd),
427+
GaussianScalerTF(labels_mean, labels_sd)
428+
)
429+
430+
val (features_scaled, labels_scaled) = (
431+
features_scaler(features),
432+
labels_scaler(labels)
433+
)
434+
435+
((features_scaled, labels_scaled), (features_scaler, labels_scaler))
436+
})
437+
438+
val minmax_standardization: DataPipe2[Tensor, Tensor, ((Tensor, Tensor), (MinMaxScalerTF, MinMaxScalerTF))] =
439+
DataPipe2((features: Tensor, labels: Tensor) => {
440+
441+
val (features_min, labels_min) = (features.min(axes = 0), labels.min(axes = 0))
442+
443+
val (features_max, labels_max) = (
444+
features.max(axes = 0),
445+
labels.max(axes = 0)
446+
)
447+
448+
val (features_scaler, labels_scaler) = (
449+
MinMaxScalerTF(features_min, features_max),
450+
MinMaxScalerTF(labels_min, labels_max)
451+
)
452+
453+
val (features_scaled, labels_scaled) = (
454+
features_scaler(features),
455+
labels_scaler(labels)
456+
)
457+
458+
((features_scaled, labels_scaled), (features_scaler, labels_scaler))
459+
})
460+
461+
}
462+
400463
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.tensorflow.utils
20+
21+
import org.platanios.tensorflow.api._
22+
import _root_.io.github.mandar2812.dynaml.pipes._
23+
24+
/**
25+
* Scales attributes of a vector pattern using the sample mean and standard deviation of
26+
* each dimension. This assumes that there is no covariance between the data
27+
* dimensions.
28+
*
29+
* @param mean Sample mean of the data
30+
* @param sigma Sample standard deviation of each data dimension
31+
* @author mandar2812 date: 07/03/2018.
32+
*
33+
* */
34+
case class GaussianScalerTF(mean: Tensor, sigma: Tensor) extends TFScaler {
35+
36+
override val i: Scaler[Tensor] = Scaler((xc: Tensor) => xc.multiply(sigma).add(mean))
37+
38+
override def run(data: Tensor): Tensor = data.subtract(mean).divide(sigma)
39+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.tensorflow.utils
20+
21+
import org.platanios.tensorflow.api._
22+
import _root_.io.github.mandar2812.dynaml.pipes._
23+
24+
/**
25+
* Scales attributes of a vector pattern using the sample minimum and maximum of
26+
* each dimension.
27+
*
28+
* @param min Sample minimum of each data dimension
29+
* @param max Sample maximum of each data dimension
30+
* @author mandar2812 date: 07/03/2018.
31+
*
32+
* */
33+
case class MinMaxScalerTF(min: Tensor, max: Tensor) extends TFScaler {
34+
35+
val delta: Tensor = max.subtract(min)
36+
37+
override val i: Scaler[Tensor] = Scaler((xc: Tensor) => xc.multiply(delta).add(min))
38+
39+
override def run(data: Tensor): Tensor = data.subtract(min).divide(delta)
40+
41+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one
3+
or more contributor license agreements. See the NOTICE file
4+
distributed with this work for additional information
5+
regarding copyright ownership. The ASF licenses this file
6+
to you under the Apache License, Version 2.0 (the
7+
"License"); you may not use this file except in compliance
8+
with the License. You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing,
13+
software distributed under the License is distributed on an
14+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
KIND, either express or implied. See the License for the
16+
specific language governing permissions and limitations
17+
under the License.
18+
* */
19+
package io.github.mandar2812.dynaml.tensorflow.utils
20+
21+
import io.github.mandar2812.dynaml.pipes.ReversibleScaler
22+
import org.platanios.tensorflow.api.Tensor
23+
24+
abstract class TFScaler extends ReversibleScaler[Tensor]

dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/TensorBasis.scala renamed to dynaml-core/src/main/scala-2.11/io/github/mandar2812/dynaml/tensorflow/utils/TensorBasis.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ KIND, either express or implied. See the License for the
1616
specific language governing permissions and limitations
1717
under the License.
1818
* */
19-
package io.github.mandar2812.dynaml.tensorflow
19+
package io.github.mandar2812.dynaml.tensorflow.utils
2020

2121
import io.github.mandar2812.dynaml.pipes.DataPipe
2222
import org.apache.spark.annotation.Experimental

0 commit comments

Comments
 (0)