11import scala .util .Random
22import org .platanios .tensorflow .api ._
3- import _root_ .io .github .mandar2812 .dynaml .tensorflow ._
3+ import _root_ .io .github .mandar2812 .dynaml .utils
4+ import _root_ .io .github .mandar2812 .dynaml .graphics .plot3d
5+ import _root_ .io .github .mandar2812 .dynaml .graphics .plot3d .DelauneySurface
6+ import _root_ .io .github .mandar2812 .dynaml .tensorflow .dtf
47import _root_ .io .github .mandar2812 .dynaml .repl .Router .main
58import org .platanios .tensorflow .api .ops .NN .SamePadding
69import org .platanios .tensorflow .api .ops .variables .ConstantInitializer
@@ -26,8 +29,56 @@ def laplace(x: Output): Output = {
2629 simple_conv(x, laplace_k)
2730}
2831
32+ // Plot a snapshot of the solution as a 3d plot.
33+ def plot_field_snapshot (
34+ t : Tensor ,
35+ xDomain : (Double , Double ),
36+ yDomain : (Double , Double )): DelauneySurface = {
37+
38+ val (rows, cols) = (t.shape(0 ), t.shape(1 ))
39+
40+ val stencil_y = utils.range(min = xDomain._1, max = xDomain._2, rows)
41+ val stencil_x = utils.range(min = yDomain._1, max = yDomain._2, cols)
42+
43+ val data = t.entriesIterator
44+ .map(_.asInstanceOf [Float ].toDouble)
45+ .toSeq.grouped(cols).zipWithIndex
46+ .flatMap(rowI => {
47+ val (row, row_index) = rowI
48+
49+ val y = stencil_y(row_index)
50+
51+ row.zipWithIndex.map(nI => {
52+ val (num, index) = nI
53+ val x = stencil_x(index)
54+ ((x, y), num)
55+ })
56+ }).toStream
57+
58+ plot3d.draw(data)
59+ }
60+
61+ def plot_field (
62+ solution : Seq [(Tensor , Tensor )])(
63+ num_snapshots : Int ,
64+ quantity : String = " displacement" ,
65+ xDomain : (Double , Double ) = (- 5.0 , 5.0 ),
66+ yDomain : (Double , Double ) = (- 5.0 , 5.0 )): Seq [DelauneySurface ] = {
67+
68+ val indices = utils.range(1.0 , solution.length.toDouble, num_snapshots).map(_.toInt).filter(_ < 1000 )
69+
70+ indices.map(i => plot_field_snapshot(
71+ if (quantity == " displacement" ) solution(i)._1 else solution(i)._2,
72+ xDomain, yDomain)
73+ )
74+ }
75+
2976@ main
30- def main (size : Int = 500 , num_iterations : Int = 1000 ) = {
77+ def main (
78+ size : Int = 500 ,
79+ num_iterations : Int = 1000 ,
80+ eps : Float = 0.001f ,
81+ damping : Float = 0.04f ): Seq [(Tensor , Tensor )] = {
3182
3283 // Start Tensorflow session
3384 val sess = Session ()
@@ -57,7 +108,7 @@ def main(size: Int = 500, num_iterations: Int = 1000) = {
57108 initializer = ConstantInitializer (dtf.tensor_f32(size, size)(ut_init:_* )))
58109
59110 // Discretized PDE update rules
60- val U_ = U + eps * Ut
111+ val U_ = U + eps * Ut
61112 val Ut_ = Ut + eps * (laplace(U ) - damping * Ut )
62113
63114 // Operation to update the state
@@ -73,7 +124,7 @@ def main(size: Int = 500, num_iterations: Int = 1000) = {
73124 pprint.pprintln(i)
74125
75126 val step_output : (Tensor , Tensor ) = sess.run(
76- feeds = Map (eps -> Tensor (0.03f ), damping -> Tensor (0.04f )),
127+ feeds = Map (eps -> Tensor (0.001f ), damping -> Tensor (0.04f )),
77128 fetches = (U_ , Ut_ ),
78129 targets = step)
79130
0 commit comments