diff --git a/app/src/main/java/to/bitkit/ext/Numbers.kt b/app/src/main/java/to/bitkit/ext/Numbers.kt index a746adeca..92de7037f 100644 --- a/app/src/main/java/to/bitkit/ext/Numbers.kt +++ b/app/src/main/java/to/bitkit/ext/Numbers.kt @@ -9,10 +9,3 @@ fun ULong.toActivityItemDate(): String { fun ULong.toActivityItemTime(): String { return Instant.ofEpochSecond(this.toLong()).formatted(DatePattern.ACTIVITY_TIME) } - -// TODO replace all usages of faulty `(ULong - ULong).coerceAtLeast(0u)` -/** - * Safely subtracts [other] from this ULong, returning 0 if the result would be negative, - * to prevent ULong wraparound by checking before subtracting, same as `x.saturating_sub(y)` in Rust. - */ -infix fun ULong.minusOrZero(other: ULong): ULong = if (this >= other) this - other else 0uL diff --git a/app/src/main/java/to/bitkit/models/USat.kt b/app/src/main/java/to/bitkit/models/USat.kt new file mode 100644 index 000000000..7737ebd1d --- /dev/null +++ b/app/src/main/java/to/bitkit/models/USat.kt @@ -0,0 +1,31 @@ +package to.bitkit.models + +/** + * A wrapper for [ULong] that provides saturating arithmetic operations. + * All operations prevent overflow/underflow by clamping to valid range [0, [ULong.MAX_VALUE]]. + * Similar to Rust's saturating arithmetic (e.g., `x.saturating_sub(y)`). + */ +@JvmInline +value class USat(val value: ULong) : Comparable { + + override fun compareTo(other: USat): Int = value.compareTo(other.value) + + /** Saturating subtraction: returns 0 if result would be negative. */ + operator fun minus(other: USat): ULong = + if (value >= other.value) value - other.value else 0uL + + /** Saturating addition: caps at ULong.MAX_VALUE if result would overflow. */ + operator fun plus(other: USat): ULong = + if (value <= ULong.MAX_VALUE - other.value) value + other.value else ULong.MAX_VALUE +} + +/** + * Wraps this ULong in a [USat] for saturating arithmetic operations. + * Use this when performing arithmetic that could overflow/underflow. + * + * Example: + * ``` + * val result = a.safe() - b.safe() // Returns 0 if a < b instead of wrapping + * ``` + */ +fun ULong.safe(): USat = USat(this) diff --git a/app/src/main/java/to/bitkit/usecases/DeriveBalanceStateUseCase.kt b/app/src/main/java/to/bitkit/usecases/DeriveBalanceStateUseCase.kt index a239c3ca7..fd2cd74dd 100644 --- a/app/src/main/java/to/bitkit/usecases/DeriveBalanceStateUseCase.kt +++ b/app/src/main/java/to/bitkit/usecases/DeriveBalanceStateUseCase.kt @@ -7,9 +7,9 @@ import to.bitkit.data.SettingsStore import to.bitkit.data.entities.TransferEntity import to.bitkit.ext.amountSats import to.bitkit.ext.channelId -import to.bitkit.ext.minusOrZero import to.bitkit.ext.totalNextOutboundHtlcLimitSats import to.bitkit.models.BalanceState +import to.bitkit.models.safe import to.bitkit.repositories.LightningRepo import to.bitkit.repositories.TransferRepo import to.bitkit.utils.Logger @@ -32,12 +32,11 @@ class DeriveBalanceStateUseCase @Inject constructor( val pendingChannelsSats = getPendingChannelsSats(activeTransfers, channels, balanceDetails) val toSavingsAmount = getTransferToSavingsSats(activeTransfers, channels, balanceDetails) - val toSpendingAmount = paidOrdersSats + pendingChannelsSats + val toSpendingAmount = paidOrdersSats.safe() + pendingChannelsSats.safe() val totalOnchainSats = balanceDetails.totalOnchainBalanceSats - val totalLightningSats = balanceDetails.totalLightningBalanceSats - .minusOrZero(pendingChannelsSats) - .minusOrZero(toSavingsAmount) + val afterPendingChannels = balanceDetails.totalLightningBalanceSats.safe() - pendingChannelsSats.safe() + val totalLightningSats = afterPendingChannels.safe() - toSavingsAmount.safe() val balanceState = BalanceState( totalOnchainSats = totalOnchainSats, @@ -113,7 +112,7 @@ class DeriveBalanceStateUseCase @Inject constructor( Logger.debug("Could not calculate max send amount, using fallback of: $fallback", context = TAG) }.getOrDefault(fallback) - return spendableOnchainSats.minusOrZero(fee) + return spendableOnchainSats.safe() - fee.safe() } companion object { diff --git a/app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt b/app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt index 5aa8fd1b3..ac2cac6ff 100644 --- a/app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt +++ b/app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt @@ -36,6 +36,7 @@ import to.bitkit.models.EUR_CURRENCY import to.bitkit.models.Toast import to.bitkit.models.TransactionSpeed import to.bitkit.models.TransferType +import to.bitkit.models.safe import to.bitkit.repositories.BlocktankRepo import to.bitkit.repositories.CurrencyRepo import to.bitkit.repositories.LightningRepo @@ -304,7 +305,7 @@ class TransferViewModel @Inject constructor( maxLspFee = estimate.feeSat // Calculate the available balance to send after LSP fee - val balanceAfterLspFee = availableAmount - maxLspFee + val balanceAfterLspFee = availableAmount.safe() - maxLspFee.safe() _spendingUiState.update { // Calculate the max available to send considering the current balance and LSP policy @@ -380,11 +381,11 @@ class TransferViewModel @Inject constructor( val maxChannelSize1 = (maxChannelSizeSat.toDouble() * 0.98).roundToLong().toULong() // The maximum channel size the user can open including existing channels - val maxChannelSize2 = (maxChannelSize1 - channelsSize).coerceAtLeast(0u) + val maxChannelSize2 = maxChannelSize1.safe() - channelsSize.safe() val maxChannelSizeAvailableToIncrease = min(maxChannelSize1, maxChannelSize2) val minLspBalance = getMinLspBalance(clientBalanceSat, minChannelSizeSat) - val maxLspBalance = (maxChannelSizeAvailableToIncrease - clientBalanceSat).coerceAtLeast(0u) + val maxLspBalance = maxChannelSizeAvailableToIncrease.safe() - clientBalanceSat.safe() val defaultLspBalance = getDefaultLspBalance(clientBalanceSat, maxLspBalance) val maxClientBalance = getMaxClientBalance(maxChannelSizeAvailableToIncrease) @@ -436,11 +437,11 @@ class TransferViewModel @Inject constructor( } val lspBalance = if (clientBalanceSat < threshold1) { // 0-225€: LSP balance = 450€ - client balance - defaultLspBalanceSats - clientBalanceSat + defaultLspBalanceSats.safe() - clientBalanceSat.safe() } else if (clientBalanceSat < threshold2) { // 225-495€: LSP balance = client balance clientBalanceSat } else if (clientBalanceSat < maxLspBalance) { // 495-950€: LSP balance = max - client balance - maxLspBalance - clientBalanceSat + maxLspBalance.safe() - clientBalanceSat.safe() } else { maxLspBalance } @@ -452,7 +453,7 @@ class TransferViewModel @Inject constructor( // LSP balance must be at least 2.5% of the channel size for LDK to accept (reserve balance) val ldkMinimum = (clientBalance.toDouble() * 0.025).toULong() // Channel size must be at least minChannelSize - val lspMinimum = if (minChannelSize > clientBalance) minChannelSize - clientBalance else 0u + val lspMinimum = minChannelSize.safe() - clientBalance.safe() return max(ldkMinimum, lspMinimum) } @@ -461,7 +462,7 @@ class TransferViewModel @Inject constructor( // Remote balance must be at least 2.5% of the channel size for LDK to accept (reserve balance) val minRemoteBalance = (maxChannelSize.toDouble() * 0.025).toULong() - return maxChannelSize - minRemoteBalance + return maxChannelSize.safe() - minRemoteBalance.safe() } /** Calculates the total value of channels connected to Blocktank nodes */ diff --git a/app/src/test/java/to/bitkit/models/USatTest.kt b/app/src/test/java/to/bitkit/models/USatTest.kt new file mode 100644 index 000000000..bb30c9ccb --- /dev/null +++ b/app/src/test/java/to/bitkit/models/USatTest.kt @@ -0,0 +1,127 @@ +package to.bitkit.models + +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class USatTest { + + // region Subtraction + @Test + fun `minus returns difference when a greater than b`() { + val result = USat(10uL) - USat(5uL) + assertEquals(5uL, result) + } + + @Test + fun `minus returns zero when a equals b`() { + val result = USat(5uL) - USat(5uL) + assertEquals(0uL, result) + } + + @Test + fun `minus returns zero when would underflow`() { + val result = USat(5uL) - USat(10uL) + assertEquals(0uL, result) + } + + @Test + fun `minus handles max ULong values`() { + val result = USat(0uL) - USat(ULong.MAX_VALUE) + assertEquals(0uL, result) + } + + @Test + fun `chained minus operations work correctly`() { + val intermediate = 100uL.safe() - 30uL.safe() + val result = intermediate.safe() - 20uL.safe() + assertEquals(50uL, result) + } + + @Test + fun `chained minus returns zero when intermediate would underflow`() { + val intermediate = 10uL.safe() - 20uL.safe() + val result = intermediate.safe() - 5uL.safe() + assertEquals(0uL, result) + } + // endregion + + // region Addition + @Test + fun `plus returns sum`() { + val result = USat(10uL) + USat(5uL) + assertEquals(15uL, result) + } + + @Test + fun `plus saturates at max when would overflow`() { + val result = USat(ULong.MAX_VALUE) + USat(1uL) + assertEquals(ULong.MAX_VALUE, result) + } + + @Test + fun `plus saturates when both values are large`() { + val result = USat(ULong.MAX_VALUE - 10uL) + USat(20uL) + assertEquals(ULong.MAX_VALUE, result) + } + + @Test + fun `chained plus operations work correctly`() { + val intermediate = 10uL.safe() + 20uL.safe() + val result = intermediate.safe() + 30uL.safe() + assertEquals(60uL, result) + } + // endregion + + // region Comparisons + @Test + fun `compareTo returns negative when less than`() { + assertTrue(USat(5uL) < USat(10uL)) + } + + @Test + fun `compareTo returns positive when greater than`() { + assertTrue(USat(10uL) > USat(5uL)) + } + + @Test + fun `compareTo returns zero when equal`() { + assertEquals(0, USat(10uL).compareTo(USat(10uL))) + } + + @Test + fun `comparison operators work correctly`() { + assertTrue(USat(5uL) <= USat(10uL)) + assertTrue(USat(10uL) >= USat(5uL)) + assertTrue(USat(10uL) <= USat(10uL)) + assertTrue(USat(10uL) >= USat(10uL)) + assertFalse(USat(10uL) < USat(10uL)) + assertFalse(USat(10uL) > USat(10uL)) + } + // endregion + + // region Realistic scenarios + @Test + fun `realistic bitcoin calculation`() { + val channelSize = 10_000_000uL // 0.1 BTC in sats + val balance = 1_000_000uL // 0.01 BTC in sats + + val maxSend = channelSize.safe() - balance.safe() + + assertEquals(9_000_000uL, maxSend) + } + + @Test + fun `minus prevents the coerceAtLeast bug`() { + // The old pattern: (5u - 10u).coerceAtLeast(0u) + // Would incorrectly return ULong.MAX_VALUE - 4 + val old = (5uL - 10uL).coerceAtLeast(0u) + assertTrue(old > 1000000u) // WRONG! Shows the bug + + // The new pattern: 5u.safe() - 10u.safe() + val new = 5uL.safe() - 10uL.safe() + assertEquals(0uL, new) // CORRECT! + } + // endregion +}