Skip to content

Commit 2ae9051

Browse files
committed
Refactor SocketClient
Since the SocketClient was not complex to begin with, combining the start-poll-stop lifecycle within one function is much easier to maintain cleaner, rather than having the start/stop contract be unclear.
1 parent a0f22da commit 2ae9051

File tree

4 files changed

+96
-89
lines changed

4 files changed

+96
-89
lines changed

workflow-trace-viewer/src/jvmMain/kotlin/com/squareup/workflow1/traceviewer/App.kt

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@ import androidx.compose.ui.geometry.Offset
1717
import com.squareup.workflow1.traceviewer.model.Node
1818
import com.squareup.workflow1.traceviewer.model.NodeUpdate
1919
import com.squareup.workflow1.traceviewer.ui.FrameSelectTab
20-
import com.squareup.workflow1.traceviewer.util.RenderTrace
2120
import com.squareup.workflow1.traceviewer.ui.RightInfoPanel
2221
import com.squareup.workflow1.traceviewer.ui.TraceModeToggleSwitch
22+
import com.squareup.workflow1.traceviewer.util.RenderTrace
2323
import com.squareup.workflow1.traceviewer.util.SandboxBackground
24-
import com.squareup.workflow1.traceviewer.util.SocketClient
2524
import com.squareup.workflow1.traceviewer.util.UploadFile
2625
import io.github.vinceglb.filekit.PlatformFile
2726

@@ -40,7 +39,6 @@ internal fun App(
4039
// Default to File mode, and can be toggled to be in Live mode.
4140
var traceMode by remember { mutableStateOf<TraceMode>(TraceMode.File(null)) }
4241
var selectedTraceFile by remember { mutableStateOf<PlatformFile?>(null) }
43-
val socket = remember { SocketClient() }
4442

4543
LaunchedEffect(sandboxState) {
4644
snapshotFlow { frameIndex }.collect {
@@ -52,7 +50,6 @@ internal fun App(
5250
modifier = modifier
5351
) {
5452
fun resetStates() {
55-
socket.close()
5653
selectedTraceFile = null
5754
selectedNode = null
5855
frameIndex = 0
@@ -66,7 +63,6 @@ internal fun App(
6663
// if there is not a file selected and trace mode is live, then don't render anything.
6764
val readyForFileTrace = traceMode is TraceMode.File && selectedTraceFile != null
6865
val readyForLiveTrace = traceMode is TraceMode.Live
69-
7066
if (readyForFileTrace || readyForLiveTrace) {
7167
RenderTrace(
7268
traceSource = traceMode,
@@ -106,8 +102,7 @@ internal fun App(
106102
frames get populated, so we avoid off by one when indexing into the frames.
107103
*/
108104
frameIndex = -1
109-
socket.open()
110-
TraceMode.Live(socket)
105+
TraceMode.Live
111106
}
112107
},
113108
traceMode = traceMode,
@@ -139,5 +134,5 @@ internal class SandboxState {
139134

140135
internal sealed interface TraceMode {
141136
data class File(val file: PlatformFile?) : TraceMode
142-
data class Live(val socket: SocketClient) : TraceMode
137+
data object Live : TraceMode
143138
}

workflow-trace-viewer/src/jvmMain/kotlin/com/squareup/workflow1/traceviewer/Main.kt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,6 @@ import androidx.compose.ui.window.singleWindowApplication
88
* Main entry point for the desktop application, see [README.md] for more details.
99
*/
1010
fun main() {
11-
Runtime.getRuntime().addShutdownHook(
12-
Thread {
13-
ProcessBuilder("adb", "forward", "--remove-all")
14-
.start().waitFor()
15-
}
16-
)
1711
singleWindowApplication(title = "Workflow Trace Viewer", exitProcessOnExit = false) {
1812
App(Modifier.fillMaxSize())
1913
}
Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,99 @@
11
package com.squareup.workflow1.traceviewer.util
22

33
import kotlinx.coroutines.Dispatchers
4-
import kotlinx.coroutines.channels.Channel
4+
import kotlinx.coroutines.awaitCancellation
5+
import kotlinx.coroutines.coroutineScope
6+
import kotlinx.coroutines.ensureActive
7+
import kotlinx.coroutines.launch
8+
import kotlinx.coroutines.runInterruptible
59
import kotlinx.coroutines.withContext
10+
import okio.IOException
611
import java.net.Socket
7-
import java.net.SocketException
812

913
/**
10-
* This is a client that can connect to any server socket that sends render pass data while using
11-
* the Workflow framework.
14+
* Collects data from a server socket and serves them back to the caller via callback.
1215
*
13-
* [start] and [close] are idempotent commands, so this socket can only be started and closed once.
16+
* Two cases that are guaranteed to fail:
17+
* 1) The app is not running
18+
* 2) A reattempt at establishing socket connection without restarting the app
1419
*
15-
* Since this app is on JVM and the server is on Android, we use ADB to forward the port onto the socket.
20+
* @param onNewRenderPass is called from an arbitrary thread, so it is important to ensure that the
21+
* caller is thread safe
1622
*/
17-
internal class SocketClient {
18-
private lateinit var socket: Socket
19-
private var initialized = false
20-
val renderPassChannel: Channel<String> = Channel(Channel.BUFFERED)
21-
22-
/**
23-
* We use any available ports on the host machine to connect to the emulator.
24-
*
25-
* `workflow-trace` is the name of the unix socket created, and since Android uses
26-
* `LocalServerSocket` -- which creates a unix socket on the linux abstract namespace -- we use
27-
* `localabstract:` to connect to it.
28-
*/
29-
fun open() {
30-
if (initialized) {
31-
return
23+
suspend fun pollSocket(onNewRenderPass: suspend (String) -> Unit) {
24+
withContext(Dispatchers.IO) {
25+
try {
26+
runForwardingPortThroughAdb { port ->
27+
Socket("localhost", port).useWithCancellation { socket ->
28+
val reader = socket.getInputStream().bufferedReader()
29+
do {
30+
ensureActive()
31+
val input = reader.readLine()
32+
if (input != null) {
33+
onNewRenderPass(input)
34+
}
35+
} while (input != null)
36+
}
37+
}
38+
} catch (e: IOException) {
39+
// NoOp
3240
}
33-
initialized = true
34-
val process = ProcessBuilder(
35-
"adb", "forward", "tcp:0", "localabstract:workflow-trace"
36-
).start()
41+
}
42+
}
3743

38-
// The adb forward command will output the port number it picks to connect.
39-
process.waitFor()
40-
val port = process.inputStream.bufferedReader().readText()
41-
.trim().toInt()
44+
/**
45+
* Force [pollSocket] to exit with exception if the coroutine is cancelled. See comment below.
46+
*/
47+
private suspend fun Socket.useWithCancellation(block: suspend (Socket) -> Unit) {
48+
val socket = this
49+
coroutineScope {
50+
// This coroutine is responsible for forcibly closing the socket when the coroutine is
51+
// cancelled. This causes any code reading from the socket to throw a CancellationException.
52+
// We also need to explicitly cancel this coroutine if the block returns on its own, otherwise
53+
// the coroutineScope will never exit.
54+
val socketJob = launch {
55+
socket.use {
56+
awaitCancellation()
57+
}
58+
}
4259

43-
socket = Socket("localhost", port)
60+
block(socket)
61+
socketJob.cancel()
4462
}
63+
}
4564

46-
fun close() {
47-
if (!initialized) {
48-
return
49-
}
50-
socket.close()
65+
/**
66+
* Call adb to setup a port forwarding to the server socket, and calls block with the allocated
67+
* port number if successful.
68+
*
69+
* If block throws or returns on finish, the port forwarding is removed via adb (best effort).
70+
*/
71+
@Suppress("BlockingMethodInNonBlockingContext")
72+
private suspend inline fun runForwardingPortThroughAdb(block: (port: Int) -> Unit) {
73+
val process = ProcessBuilder(
74+
"adb", "forward", "tcp:0", "localabstract:workflow-trace"
75+
).start()
76+
77+
// The adb forward command will output the port number it picks to connect.
78+
val forwardReturnCode = runInterruptible {
79+
process.waitFor()
80+
}
81+
if (forwardReturnCode != 0) {
82+
return
5183
}
5284

53-
/**
54-
* Polls the socket's input stream and sends the data into [renderPassChannel].
55-
* The caller should handle the scope of the coroutine that this function is called in.
56-
*
57-
* To better separate the responsibility of reading from the socket, we use a channel for the caller
58-
* to handle parsing and amalgamating the render passes.
59-
*/
60-
suspend fun pollSocket() {
61-
withContext(Dispatchers.IO) {
62-
val reader = socket.getInputStream().bufferedReader()
63-
reader.use {
64-
try {
65-
while (true) {
66-
val input = reader.readLine()
67-
renderPassChannel.trySend(input)
68-
}
69-
} catch (e: SocketException) {
70-
e.printStackTrace()
71-
}
72-
}
85+
val port = process.inputStream.bufferedReader().readText()
86+
.trim().toInt()
87+
88+
try {
89+
block(port)
90+
} finally {
91+
// We don't care if this fails since there's nothing we can do then anyway. It just means
92+
// there's an extra forward left open, but that's not a big deal.
93+
runCatching {
94+
ProcessBuilder(
95+
"adb", "forward", "--remove", "tcp:$port"
96+
).start()
7397
}
7498
}
7599
}

workflow-trace-viewer/src/jvmMain/kotlin/com/squareup/workflow1/traceviewer/util/TraceParser.kt

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ import com.squareup.moshi.JsonAdapter
1414
import com.squareup.workflow1.traceviewer.TraceMode
1515
import com.squareup.workflow1.traceviewer.model.Node
1616
import com.squareup.workflow1.traceviewer.ui.DrawTree
17-
import kotlinx.coroutines.Dispatchers
17+
import kotlinx.coroutines.channels.Channel
1818
import kotlinx.coroutines.launch
19-
import kotlinx.coroutines.withContext
2019
import java.net.SocketException
2120

2221
/**
@@ -35,7 +34,7 @@ internal fun RenderTrace(
3534
modifier: Modifier = Modifier
3635
) {
3736
var isLoading by remember(traceSource) { mutableStateOf(true) }
38-
var error by remember(traceSource) { mutableStateOf<Throwable?>(null) }
37+
var error by remember(traceSource) { mutableStateOf<String?>(null) }
3938
val frames = remember { mutableStateListOf<Node>() }
4039
val fullTree = remember { mutableStateListOf<Node>() }
4140
val affectedNodes = remember { mutableStateListOf<Set<Node>>() }
@@ -57,7 +56,7 @@ internal fun RenderTrace(
5756
) {
5857
when (parseResult) {
5958
is ParseResult.Failure -> {
60-
error = parseResult.error
59+
error = parseResult.error.toString()
6160
}
6261
is ParseResult.Success -> {
6362
addToStates(
@@ -81,34 +80,29 @@ internal fun RenderTrace(
8180
}
8281

8382
is TraceMode.Live -> {
84-
val socket = traceSource.socket
83+
val renderPassChannel: Channel<String> = Channel(Channel.BUFFERED)
8584
launch {
8685
try {
87-
socket.pollSocket()
88-
} catch (e: SocketException) {
89-
error = SocketException("Socket has already been closed or is not available: ${e.message}")
90-
return@launch
86+
pollSocket(onNewRenderPass = renderPassChannel::send)
87+
error = "Socket has already been closed or is not available."
88+
} finally {
89+
renderPassChannel.close()
9190
}
9291
}
93-
if (error != null) {
94-
return@LaunchedEffect
95-
}
9692
val adapter: JsonAdapter<List<Node>> = createMoshiAdapter<Node>()
9793

98-
withContext(Dispatchers.Default) {
99-
// Since channel implements ChannelIterator, we can for-loop through on the receiver end.
100-
for (renderPass in socket.renderPassChannel) {
101-
val currentTree = fullTree.lastOrNull()
102-
val parseResult = parseLiveTrace(renderPass, adapter, currentTree)
103-
handleParseResult(parseResult, onNewFrame)
104-
}
94+
// Since channel implements ChannelIterator, we can for-loop through on the receiver end.
95+
for (renderPass in renderPassChannel) {
96+
val currentTree = fullTree.lastOrNull()
97+
val parseResult = parseLiveTrace(renderPass, adapter, currentTree)
98+
handleParseResult(parseResult, onNewFrame)
10599
}
106100
}
107101
}
108102
}
109103

110104
if (error != null) {
111-
Text("Error parsing: ${error?.message}")
105+
Text("Error parsing: ${error}")
112106
return
113107
}
114108

0 commit comments

Comments
 (0)