Skip to content

Commit 2abad06

Browse files
committed
dynaml-tensorflow: Improvements to tensorflow wave pde example
Signed-off-by: mandar2812 <[email protected]>
1 parent 76bc708 commit 2abad06

File tree

1 file changed

+28
-4
lines changed

1 file changed

+28
-4
lines changed

scripts/tf_wave_pde.sc

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
import scala.util.Random
21
import org.platanios.tensorflow.api._
32
import _root_.io.github.mandar2812.dynaml.utils
3+
import _root_.io.github.mandar2812.dynaml.pipes.TupleIntegerEncoder
44
import _root_.io.github.mandar2812.dynaml.graphics.plot3d
55
import _root_.io.github.mandar2812.dynaml.graphics.plot3d.DelauneySurface
66
import _root_.io.github.mandar2812.dynaml.tensorflow.dtf
@@ -29,6 +29,15 @@ def laplace(x: Output): Output = {
2929
simple_conv(x, laplace_k)
3030
}
3131

32+
//The mexican hat wavelet function
33+
val mexican = (sigma: Double) => (x: Double, y: Double) => {
34+
(1.0/sigma*sigma*math.Pi)*(1.0 - 0.5*(x*x + y*y)/(sigma*sigma))*math.exp(-0.5*(x*x + y*y)/(sigma*sigma))
35+
}
36+
37+
val gaussian = (sigma: Double) => (x: Double, y: Double) => {
38+
(1.0/sigma*math.sqrt(2*math.Pi))*math.exp(-0.5*(x*x + y*y)/(sigma*sigma))
39+
}
40+
3241
//Plot a snapshot of the solution as a 3d plot.
3342
def plot_field_snapshot(
3443
t: Tensor,
@@ -65,7 +74,7 @@ def plot_field(
6574
xDomain: (Double, Double) = (-5.0, 5.0),
6675
yDomain: (Double, Double) = (-5.0, 5.0)): Seq[DelauneySurface] = {
6776

68-
val indices = utils.range(1.0, solution.length.toDouble, num_snapshots).map(_.toInt).filter(_ < 1000)
77+
val indices = utils.range(1.0, solution.length.toDouble, num_snapshots).map(_.toInt)
6978

7079
indices.map(i => plot_field_snapshot(
7180
if(quantity == "displacement") solution(i)._1 else solution(i)._2,
@@ -78,14 +87,29 @@ def main(
7887
size: Int = 500,
7988
num_iterations: Int = 1000,
8089
eps: Float = 0.001f,
81-
damping: Float = 0.04f): Seq[(Tensor, Tensor)] = {
90+
damping: Float = 0.04f,
91+
xDomain: (Double, Double) = (-5.0, 5.0),
92+
yDomain: (Double, Double) = (-5.0, 5.0),
93+
u_0: (Double, Double) => Double = mexican(1.0)): Seq[(Tensor, Tensor)] = {
8294

8395
//Start Tensorflow session
8496
val sess = Session()
8597

98+
val (x_grid, y_grid) = (
99+
utils.range(xDomain._1, xDomain._2, size),
100+
utils.range(yDomain._1, yDomain._2, size))
101+
102+
val encoder = TupleIntegerEncoder(List(size, size))
103+
86104
//Initial Conditions -- some rain drops hit a pond
87105
val (u_init, ut_init) = (
88-
Seq.tabulate[Double](size*size)(_ => if(Random.nextDouble() <= 0.95) 0d else Random.nextDouble()),
106+
Seq.tabulate[Double](size*size)(k => {
107+
val List(i, j) = encoder.i(k)
108+
109+
val (x, y) = (x_grid(i), y_grid(j))
110+
111+
u_0(x, y)
112+
}),
89113
Seq.fill[Double](size*size)(0d)
90114
)
91115

0 commit comments

Comments
 (0)