Skip to content

Commit 888a785

Browse files
committed
fix: add checks to prevent access to uninitialized natives
1 parent 5519e96 commit 888a785

File tree

4 files changed

+142
-56
lines changed

4 files changed

+142
-56
lines changed

build.gradle.kts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ tasks {
7474
useJUnitPlatform()
7575
}
7676

77-
register<Copy>("generateTemplates") {
77+
val generateTemplates = register<Copy>("generateTemplates") {
7878
from(templateSrc)
7979
into(templateDst)
8080
expand(templateProps)
@@ -84,6 +84,10 @@ tasks {
8484
outputs.dir(templateDst)
8585
}
8686

87+
withType<Jar> {
88+
dependsOn(generateTemplates)
89+
}
90+
8791
compileKotlin {
8892
dependsOn("generateTemplates")
8993
}
@@ -102,6 +106,11 @@ sourceSets.main {
102106
}
103107
}
104108

109+
java {
110+
withSourcesJar()
111+
withJavadocJar()
112+
}
113+
105114
allprojects {
106115
apply<MavenPublishPlugin>()
107116
apply<BasePlugin>()

src/main/kotlin/dev/silenium/multimedia/compose/player/VideoPlayer.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ class VideoPlayer(hwdec: Boolean = false) : AutoCloseable {
7878
}
7979

8080
override fun close() {
81-
mpv.command("stop")
8281
render?.close()
8382
mpv.close()
8483
}

src/main/kotlin/dev/silenium/multimedia/core/mpv/MPV.kt

Lines changed: 129 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ import java.util.concurrent.ConcurrentHashMap
1414
import java.util.concurrent.ConcurrentLinkedQueue
1515
import java.util.concurrent.atomic.AtomicBoolean
1616
import java.util.concurrent.atomic.AtomicLong
17+
import kotlin.contracts.ExperimentalContracts
18+
import kotlin.contracts.InvocationKind
19+
import kotlin.contracts.contract
1720
import kotlin.coroutines.EmptyCoroutineContext
1821
import kotlin.coroutines.resume
1922
import kotlin.reflect.KClass
@@ -45,13 +48,39 @@ class MPV : NativeCleanable, MPVAsyncListener {
4548

4649
private val propertyUpdates = MutableSharedFlow<Property<*>>()
4750

51+
@OptIn(ExperimentalContracts::class)
52+
private inline fun <R : Any> guard(other: R? = null, block: () -> Result<R>): Result<R> {
53+
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
54+
if (!initialized.get()) {
55+
if (other != null) {
56+
return Result.success(other)
57+
}
58+
return Result.failure(IllegalStateException("MPV is not initialized"))
59+
}
60+
return block()
61+
}
62+
63+
@OptIn(ExperimentalContracts::class)
64+
@JvmName("guardNonNull")
65+
private inline fun <R> guardNonNull(other: R? = null, block: () -> Result<R?>): Result<R?> {
66+
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
67+
if (!initialized.get()) {
68+
if (other != null) {
69+
return Result.success(other)
70+
}
71+
return Result.failure(IllegalStateException("MPV is not initialized"))
72+
}
73+
return block()
74+
}
75+
4876
override val nativePointer = NativePointer(createN().getOrThrow()) {
49-
callback?.let(::unsetCallbackN)
50-
listener.close()
5177
destroyN(it)
5278
}
5379

5480
fun setOption(name: String, value: String) {
81+
if (nativePointer.closed) {
82+
error("MPV is closed")
83+
}
5584
if (initialized.get()) {
5685
logger.warn("Cannot set option after initialization, ignoring")
5786
return
@@ -76,15 +105,17 @@ class MPV : NativeCleanable, MPVAsyncListener {
76105
name: String,
77106
value: T,
78107
fn: (Long, String, T, Long) -> Result<Unit>,
79-
): Result<Unit> = suspendCancellableCoroutine { continuation ->
80-
val subscriptionId = propertySetCallbackId.getAndIncrement()
81-
propertySetCallbacks[subscriptionId] = { result ->
82-
continuation.resume(result)
83-
}
84-
fn(nativePointer.address, name, value, subscriptionId).onFailure {
85-
propertySetCallbacks.remove(subscriptionId)
86-
logger.error("Failed to set property $name", it)
87-
continuation.resume(Result.failure(it))
108+
): Result<Unit> = guard(Unit) {
109+
suspendCancellableCoroutine { continuation ->
110+
val subscriptionId = propertySetCallbackId.getAndIncrement()
111+
propertySetCallbacks[subscriptionId] = { result ->
112+
continuation.resume(result)
113+
}
114+
fn(nativePointer.address, name, value, subscriptionId).onFailure {
115+
propertySetCallbacks.remove(subscriptionId)
116+
logger.error("Failed to set property $name", it)
117+
continuation.resume(Result.failure(it))
118+
}
88119
}
89120
}
90121

@@ -102,18 +133,20 @@ class MPV : NativeCleanable, MPVAsyncListener {
102133
name: String,
103134
type: KClass<T>,
104135
fn: (Long, String, Long) -> Result<Unit>,
105-
): Result<T?> = suspendCancellableCoroutine { continuation ->
106-
val subscriptionId = propertyGetCallbackId.getAndIncrement()
107-
propertyGetCallbacks[subscriptionId] = { result ->
108-
continuation.resume(result.map {
109-
logger.debug("Got property {}: {}", name, it)
110-
it?.let(type::cast)
111-
})
112-
}
113-
fn(nativePointer.address, name, subscriptionId).onFailure {
114-
propertyGetCallbacks.remove(subscriptionId)
115-
logger.error("Failed to get property $name", it)
116-
continuation.resume(Result.failure(it))
136+
): Result<T?> = guardNonNull<T>(null) {
137+
suspendCancellableCoroutine { continuation ->
138+
val subscriptionId = propertyGetCallbackId.getAndIncrement()
139+
propertyGetCallbacks[subscriptionId] = { result ->
140+
continuation.resume(result.map {
141+
logger.debug("Got property {}: {}", name, it)
142+
it?.let(type::cast)
143+
})
144+
}
145+
fn(nativePointer.address, name, subscriptionId).onFailure {
146+
propertyGetCallbacks.remove(subscriptionId)
147+
logger.error("Failed to get property $name", it)
148+
continuation.resume(Result.failure(it))
149+
}
117150
}
118151
}
119152

@@ -127,23 +160,24 @@ class MPV : NativeCleanable, MPVAsyncListener {
127160
fun getPropertyDouble(name: String) = getPropertyDoubleN(nativePointer.address, name)
128161
fun getPropertyFlag(name: String) = getPropertyFlagN(nativePointer.address, name)
129162

130-
private fun subscribe(name: String, type: KClass<*>, fn: (Long, String, Long) -> Result<Unit>): Result<Unit> {
131-
if (propertySubscriptions.containsKey(name)) {
132-
logger.debug("Property $name is already being observed")
133-
return Result.success(Unit)
163+
private fun subscribe(name: String, type: KClass<*>, fn: (Long, String, Long) -> Result<Unit>): Result<Unit> =
164+
guard(Unit) {
165+
if (propertySubscriptions.containsKey(name)) {
166+
logger.debug("Property $name is already being observed")
167+
return Result.success(Unit)
168+
}
169+
logger.debug("Observing property $name")
170+
val subscriptionId = subscriptionId.getAndIncrement()
171+
return fn(nativePointer.address, name, subscriptionId)
172+
.map { propertySubscriptions[name] = subscriptionId to type }
134173
}
135-
logger.debug("Observing property $name")
136-
val subscriptionId = subscriptionId.getAndIncrement()
137-
return fn(nativePointer.address, name, subscriptionId)
138-
.map { propertySubscriptions[name] = subscriptionId to type }
139-
}
140174

141175
fun observePropertyString(name: String) = subscribe(name, String::class, ::observePropertyStringN)
142176
fun observePropertyLong(name: String) = subscribe(name, Long::class, ::observePropertyLongN)
143177
fun observePropertyDouble(name: String) = subscribe(name, Double::class, ::observePropertyDoubleN)
144178
fun observePropertyFlag(name: String) = subscribe(name, Boolean::class, ::observePropertyFlagN)
145179

146-
fun unobserveProperty(name: String): Result<Unit> {
180+
fun unobserveProperty(name: String): Result<Unit> = guard(Unit) {
147181
val (id, _) = propertySubscriptions[name] ?: run {
148182
logger.debug("Property $name is not being observed")
149183
return Result.success(Unit)
@@ -160,17 +194,27 @@ class MPV : NativeCleanable, MPVAsyncListener {
160194
override fun command(subscriptionCount: StateFlow<Int>): Flow<SharingCommand> {
161195
return wrapped.command(subscriptionCount).onEach { command ->
162196
when (command) {
163-
SharingCommand.START -> subscribe(name).getOrThrow()
197+
SharingCommand.START -> subscribe(name).getOrElse {
198+
logger.error("Failed to subscribe to property $name", it)
199+
return@onEach
200+
}
201+
164202
SharingCommand.STOP,
165203
SharingCommand.STOP_AND_RESET_REPLAY_CACHE,
166-
-> unsubscribe(name).getOrThrow()
204+
-> unsubscribe(name).getOrElse {
205+
logger.error("Failed to unsubscribe from property $name", it)
206+
return@onEach
207+
}
167208
}
168209
}
169210
}
170211
}
171212

172213
suspend fun propertyFlowString(name: String): StateFlow<String?> {
173-
val initialValue = getPropertyStringAsync(name).getOrThrow()
214+
val initialValue = getPropertyStringAsync(name).getOrElse {
215+
logger.error("Failed to get initial value for property $name", it)
216+
null
217+
}
174218
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<StringProperty>().map { it.value }
175219
return flow.stateIn(
176220
CoroutineScope(EmptyCoroutineContext),
@@ -180,7 +224,10 @@ class MPV : NativeCleanable, MPVAsyncListener {
180224
}
181225

182226
suspend fun propertyFlowLong(name: String): StateFlow<Long?> {
183-
val initialValue = getPropertyLongAsync(name).getOrThrow()
227+
val initialValue = getPropertyLongAsync(name).getOrElse {
228+
logger.error("Failed to get initial value for property $name", it)
229+
null
230+
}
184231
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<LongProperty>().map { it.value }
185232
return flow.stateIn(
186233
CoroutineScope(EmptyCoroutineContext),
@@ -190,7 +237,10 @@ class MPV : NativeCleanable, MPVAsyncListener {
190237
}
191238

192239
suspend fun propertyFlowDouble(name: String): StateFlow<Double?> {
193-
val initialValue = getPropertyDoubleAsync(name).getOrThrow()
240+
val initialValue = getPropertyDoubleAsync(name).getOrElse {
241+
logger.error("Failed to get initial value for property $name", it)
242+
null
243+
}
194244
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<DoubleProperty>().map { it.value }
195245
return flow.stateIn(
196246
CoroutineScope(EmptyCoroutineContext),
@@ -200,7 +250,10 @@ class MPV : NativeCleanable, MPVAsyncListener {
200250
}
201251

202252
suspend fun propertyFlowFlag(name: String): StateFlow<Boolean?> {
203-
val initialValue = getPropertyFlagAsync(name).getOrThrow()
253+
val initialValue = getPropertyFlagAsync(name).getOrElse {
254+
logger.error("Failed to get initial value for property $name", it)
255+
null
256+
}
204257
val flow = propertyUpdates.filter { it.name == name }.filterIsInstance<FlagProperty>().map { it.value }
205258
return flow.stateIn(
206259
CoroutineScope(EmptyCoroutineContext),
@@ -251,20 +304,27 @@ class MPV : NativeCleanable, MPVAsyncListener {
251304
else -> error("Unsupported property type: ${value::class}")
252305
}
253306

254-
fun command(command: Array<String>) = commandN(nativePointer.address, command)
255-
fun command(command: String) = commandStringN(nativePointer.address, command)
307+
fun command(command: Array<String>) = guard(Unit) {
308+
commandN(nativePointer.address, command)
309+
}
310+
311+
fun command(command: String) = guard(Unit) {
312+
commandStringN(nativePointer.address, command)
313+
}
256314

257315
@JvmName("commandAsyncVararg")
258316
suspend fun commandAsync(vararg command: String) = commandAsync(command.toList().toTypedArray())
259-
suspend fun commandAsync(command: Array<String>): Result<Unit> = suspendCancellableCoroutine { continuation ->
260-
val subscriptionId = commandReplyCallbackId.getAndIncrement()
261-
commandReplyCallbacks[subscriptionId] = { result ->
262-
continuation.resume(result)
263-
}
264-
commandAsyncN(nativePointer.address, command, subscriptionId).onFailure {
265-
commandReplyCallbacks.remove(subscriptionId)
266-
logger.error("Failed to execute command $command", it)
267-
continuation.resume(Result.failure(it))
317+
suspend fun commandAsync(command: Array<String>): Result<Unit> = guard(Unit) {
318+
suspendCancellableCoroutine { continuation ->
319+
val subscriptionId = commandReplyCallbackId.getAndIncrement()
320+
commandReplyCallbacks[subscriptionId] = { result ->
321+
continuation.resume(result)
322+
}
323+
commandAsyncN(nativePointer.address, command, subscriptionId).onFailure {
324+
commandReplyCallbacks.remove(subscriptionId)
325+
logger.error("Failed to execute command $command", it)
326+
continuation.resume(Result.failure(it))
327+
}
268328
}
269329
}
270330

@@ -304,8 +364,20 @@ class MPV : NativeCleanable, MPVAsyncListener {
304364
createRenderN(mpv.nativePointer.address, this, advancedControl).getOrThrow()
305365
.asNativePointer(::destroyRenderN)
306366

307-
fun render(fbo: FBO): Result<Unit> {
308-
return renderN(
367+
@OptIn(ExperimentalContracts::class)
368+
private fun <R> guard(other: R? = null, block: () -> Result<R>): Result<R> {
369+
contract { callsInPlace(block, InvocationKind.AT_MOST_ONCE) }
370+
if (nativePointer.closed) {
371+
if (other != null) {
372+
return Result.success(other)
373+
}
374+
return Result.failure(IllegalStateException("Render is closed"))
375+
}
376+
return block()
377+
}
378+
379+
fun render(fbo: FBO): Result<Unit> = guard(Unit) {
380+
renderN(
309381
nativePointer.address,
310382
fbo.id,
311383
fbo.size.width,
@@ -325,6 +397,11 @@ class MPV : NativeCleanable, MPVAsyncListener {
325397
}
326398
}
327399

400+
override fun close() {
401+
initialized.set(false)
402+
super.close()
403+
}
404+
328405
companion object {
329406
private val logger = LoggerFactory.getLogger(MPV::class.java)
330407

src/main/kotlin/dev/silenium/multimedia/core/util/NativePointer.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ import org.slf4j.LoggerFactory
44
import java.util.concurrent.atomic.AtomicBoolean
55

66
data class NativePointer(val address: Long, val clean: (Long) -> Unit) : AutoCloseable {
7-
private val closed = AtomicBoolean(false)
7+
private val _closed = AtomicBoolean(false)
8+
val closed get() = _closed.get()
89
override fun close() {
910
if (address == 0L) {
1011
logger.warn("Attempt to close NULL NativePointer")
1112
return
1213
}
13-
if (closed.compareAndSet(false, true)) {
14+
if (_closed.compareAndSet(false, true)) {
1415
clean(address)
1516
} else {
1617
logger.warn("Attempt to close already closed NativePointer: $this")

0 commit comments

Comments
 (0)