Skip to content

Commit 76bc708

Browse files
committed
dynaml-tensorflow: Improvements to tensorflow wave pde example
- added ability to plot field snapshots Signed-off-by: mandar2812 <[email protected]>
1 parent 8f1c559 commit 76bc708

File tree

1 file changed

+55
-4
lines changed

1 file changed

+55
-4
lines changed

scripts/tf-wave-pde.sc renamed to scripts/tf_wave_pde.sc

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import scala.util.Random
22
import org.platanios.tensorflow.api._
3-
import _root_.io.github.mandar2812.dynaml.tensorflow._
3+
import _root_.io.github.mandar2812.dynaml.utils
4+
import _root_.io.github.mandar2812.dynaml.graphics.plot3d
5+
import _root_.io.github.mandar2812.dynaml.graphics.plot3d.DelauneySurface
6+
import _root_.io.github.mandar2812.dynaml.tensorflow.dtf
47
import _root_.io.github.mandar2812.dynaml.repl.Router.main
58
import org.platanios.tensorflow.api.ops.NN.SamePadding
69
import org.platanios.tensorflow.api.ops.variables.ConstantInitializer
@@ -26,8 +29,56 @@ def laplace(x: Output): Output = {
2629
simple_conv(x, laplace_k)
2730
}
2831

32+
//Plot a snapshot of the solution as a 3d plot.
33+
def plot_field_snapshot(
34+
t: Tensor,
35+
xDomain: (Double, Double),
36+
yDomain: (Double, Double)): DelauneySurface = {
37+
38+
val (rows, cols) = (t.shape(0), t.shape(1))
39+
40+
val stencil_y = utils.range(min = xDomain._1, max = xDomain._2, rows)
41+
val stencil_x = utils.range(min = yDomain._1, max = yDomain._2, cols)
42+
43+
val data = t.entriesIterator
44+
.map(_.asInstanceOf[Float].toDouble)
45+
.toSeq.grouped(cols).zipWithIndex
46+
.flatMap(rowI => {
47+
val (row, row_index) = rowI
48+
49+
val y = stencil_y(row_index)
50+
51+
row.zipWithIndex.map(nI => {
52+
val (num, index) = nI
53+
val x = stencil_x(index)
54+
((x, y), num)
55+
})
56+
}).toStream
57+
58+
plot3d.draw(data)
59+
}
60+
61+
def plot_field(
62+
solution: Seq[(Tensor, Tensor)])(
63+
num_snapshots: Int,
64+
quantity: String = "displacement",
65+
xDomain: (Double, Double) = (-5.0, 5.0),
66+
yDomain: (Double, Double) = (-5.0, 5.0)): Seq[DelauneySurface] = {
67+
68+
val indices = utils.range(1.0, solution.length.toDouble, num_snapshots).map(_.toInt).filter(_ < 1000)
69+
70+
indices.map(i => plot_field_snapshot(
71+
if(quantity == "displacement") solution(i)._1 else solution(i)._2,
72+
xDomain, yDomain)
73+
)
74+
}
75+
2976
@main
30-
def main(size: Int = 500, num_iterations: Int = 1000) = {
77+
def main(
78+
size: Int = 500,
79+
num_iterations: Int = 1000,
80+
eps: Float = 0.001f,
81+
damping: Float = 0.04f): Seq[(Tensor, Tensor)] = {
3182

3283
//Start Tensorflow session
3384
val sess = Session()
@@ -57,7 +108,7 @@ def main(size: Int = 500, num_iterations: Int = 1000) = {
57108
initializer = ConstantInitializer(dtf.tensor_f32(size, size)(ut_init:_*)))
58109

59110
//Discretized PDE update rules
60-
val U_ = U + eps * Ut
111+
val U_ = U + eps * Ut
61112
val Ut_ = Ut + eps * (laplace(U) - damping * Ut)
62113

63114
//Operation to update the state
@@ -73,7 +124,7 @@ def main(size: Int = 500, num_iterations: Int = 1000) = {
73124
pprint.pprintln(i)
74125

75126
val step_output: (Tensor, Tensor) = sess.run(
76-
feeds = Map(eps -> Tensor(0.03f), damping -> Tensor(0.04f)),
127+
feeds = Map(eps -> Tensor(0.001f), damping -> Tensor(0.04f)),
77128
fetches = (U_, Ut_),
78129
targets = step)
79130

0 commit comments

Comments
 (0)