1- import scala .util .Random
21import org .platanios .tensorflow .api ._
32import _root_ .io .github .mandar2812 .dynaml .utils
3+ import _root_ .io .github .mandar2812 .dynaml .pipes .TupleIntegerEncoder
44import _root_ .io .github .mandar2812 .dynaml .graphics .plot3d
55import _root_ .io .github .mandar2812 .dynaml .graphics .plot3d .DelauneySurface
66import _root_ .io .github .mandar2812 .dynaml .tensorflow .dtf
@@ -29,6 +29,15 @@ def laplace(x: Output): Output = {
2929 simple_conv(x, laplace_k)
3030}
3131
32+ // The mexican hat wavelet function
33+ val mexican = (sigma : Double ) => (x : Double , y : Double ) => {
34+ (1.0 / sigma* sigma* math.Pi )* (1.0 - 0.5 * (x* x + y* y)/ (sigma* sigma))* math.exp(- 0.5 * (x* x + y* y)/ (sigma* sigma))
35+ }
36+
37+ val gaussian = (sigma : Double ) => (x : Double , y : Double ) => {
38+ (1.0 / sigma* math.sqrt(2 * math.Pi ))* math.exp(- 0.5 * (x* x + y* y)/ (sigma* sigma))
39+ }
40+
3241// Plot a snapshot of the solution as a 3d plot.
3342def plot_field_snapshot (
3443 t : Tensor ,
@@ -65,7 +74,7 @@ def plot_field(
6574 xDomain : (Double , Double ) = (- 5.0 , 5.0 ),
6675 yDomain : (Double , Double ) = (- 5.0 , 5.0 )): Seq [DelauneySurface ] = {
6776
68- val indices = utils.range(1.0 , solution.length.toDouble, num_snapshots).map(_.toInt).filter(_ < 1000 )
77+ val indices = utils.range(1.0 , solution.length.toDouble, num_snapshots).map(_.toInt)
6978
7079 indices.map(i => plot_field_snapshot(
7180 if (quantity == " displacement" ) solution(i)._1 else solution(i)._2,
@@ -78,14 +87,29 @@ def main(
7887 size : Int = 500 ,
7988 num_iterations : Int = 1000 ,
8089 eps : Float = 0.001f ,
81- damping : Float = 0.04f ): Seq [(Tensor , Tensor )] = {
90+ damping : Float = 0.04f ,
91+ xDomain : (Double , Double ) = (- 5.0 , 5.0 ),
92+ yDomain : (Double , Double ) = (- 5.0 , 5.0 ),
93+ u_0 : (Double , Double ) => Double = mexican(1.0 )): Seq [(Tensor , Tensor )] = {
8294
8395 // Start Tensorflow session
8496 val sess = Session ()
8597
98+ val (x_grid, y_grid) = (
99+ utils.range(xDomain._1, xDomain._2, size),
100+ utils.range(yDomain._1, yDomain._2, size))
101+
102+ val encoder = TupleIntegerEncoder (List (size, size))
103+
86104 // Initial Conditions -- some rain drops hit a pond
87105 val (u_init, ut_init) = (
88- Seq .tabulate[Double ](size* size)(_ => if (Random .nextDouble() <= 0.95 ) 0d else Random .nextDouble()),
106+ Seq .tabulate[Double ](size* size)(k => {
107+ val List (i, j) = encoder.i(k)
108+
109+ val (x, y) = (x_grid(i), y_grid(j))
110+
111+ u_0(x, y)
112+ }),
89113 Seq .fill[Double ](size* size)(0d )
90114 )
91115
0 commit comments