Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions app/src/main/java/to/bitkit/ext/Numbers.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,3 @@ fun ULong.toActivityItemDate(): String {
fun ULong.toActivityItemTime(): String {
return Instant.ofEpochSecond(this.toLong()).formatted(DatePattern.ACTIVITY_TIME)
}

// TODO replace all usages of faulty `(ULong - ULong).coerceAtLeast(0u)`
/**
* Safely subtracts [other] from this ULong, returning 0 if the result would be negative,
* to prevent ULong wraparound by checking before subtracting, same as `x.saturating_sub(y)` in Rust.
*/
infix fun ULong.minusOrZero(other: ULong): ULong = if (this >= other) this - other else 0uL
31 changes: 31 additions & 0 deletions app/src/main/java/to/bitkit/models/USat.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package to.bitkit.models

/**
* A wrapper for [ULong] that provides saturating arithmetic operations.
* All operations prevent overflow/underflow by clamping to valid range [0, [ULong.MAX_VALUE]].
* Similar to Rust's saturating arithmetic (e.g., `x.saturating_sub(y)`).
*/
@JvmInline
value class USat(val value: ULong) : Comparable<USat> {

override fun compareTo(other: USat): Int = value.compareTo(other.value)

/** Saturating subtraction: returns 0 if result would be negative. */
operator fun minus(other: USat): ULong =
if (value >= other.value) value - other.value else 0uL

/** Saturating addition: caps at ULong.MAX_VALUE if result would overflow. */
operator fun plus(other: USat): ULong =
if (value <= ULong.MAX_VALUE - other.value) value + other.value else ULong.MAX_VALUE
}

/**
* Wraps this ULong in a [USat] for saturating arithmetic operations.
* Use this when performing arithmetic that could overflow/underflow.
*
* Example:
* ```
* val result = a.safe() - b.safe() // Returns 0 if a < b instead of wrapping
* ```
*/
fun ULong.safe(): USat = USat(this)
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import to.bitkit.data.SettingsStore
import to.bitkit.data.entities.TransferEntity
import to.bitkit.ext.amountSats
import to.bitkit.ext.channelId
import to.bitkit.ext.minusOrZero
import to.bitkit.ext.totalNextOutboundHtlcLimitSats
import to.bitkit.models.BalanceState
import to.bitkit.models.safe
import to.bitkit.repositories.LightningRepo
import to.bitkit.repositories.TransferRepo
import to.bitkit.utils.Logger
Expand All @@ -32,12 +32,11 @@ class DeriveBalanceStateUseCase @Inject constructor(
val pendingChannelsSats = getPendingChannelsSats(activeTransfers, channels, balanceDetails)

val toSavingsAmount = getTransferToSavingsSats(activeTransfers, channels, balanceDetails)
val toSpendingAmount = paidOrdersSats + pendingChannelsSats
val toSpendingAmount = paidOrdersSats.safe() + pendingChannelsSats.safe()

val totalOnchainSats = balanceDetails.totalOnchainBalanceSats
val totalLightningSats = balanceDetails.totalLightningBalanceSats
.minusOrZero(pendingChannelsSats)
.minusOrZero(toSavingsAmount)
val afterPendingChannels = balanceDetails.totalLightningBalanceSats.safe() - pendingChannelsSats.safe()
val totalLightningSats = afterPendingChannels.safe() - toSavingsAmount.safe()

val balanceState = BalanceState(
totalOnchainSats = totalOnchainSats,
Expand Down Expand Up @@ -113,7 +112,7 @@ class DeriveBalanceStateUseCase @Inject constructor(
Logger.debug("Could not calculate max send amount, using fallback of: $fallback", context = TAG)
}.getOrDefault(fallback)

return spendableOnchainSats.minusOrZero(fee)
return spendableOnchainSats.safe() - fee.safe()
}

companion object {
Expand Down
15 changes: 8 additions & 7 deletions app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import to.bitkit.models.EUR_CURRENCY
import to.bitkit.models.Toast
import to.bitkit.models.TransactionSpeed
import to.bitkit.models.TransferType
import to.bitkit.models.safe
import to.bitkit.repositories.BlocktankRepo
import to.bitkit.repositories.CurrencyRepo
import to.bitkit.repositories.LightningRepo
Expand Down Expand Up @@ -304,7 +305,7 @@ class TransferViewModel @Inject constructor(
maxLspFee = estimate.feeSat

// Calculate the available balance to send after LSP fee
val balanceAfterLspFee = availableAmount - maxLspFee
val balanceAfterLspFee = availableAmount.safe() - maxLspFee.safe()

_spendingUiState.update {
// Calculate the max available to send considering the current balance and LSP policy
Expand Down Expand Up @@ -380,11 +381,11 @@ class TransferViewModel @Inject constructor(
val maxChannelSize1 = (maxChannelSizeSat.toDouble() * 0.98).roundToLong().toULong()

// The maximum channel size the user can open including existing channels
val maxChannelSize2 = (maxChannelSize1 - channelsSize).coerceAtLeast(0u)
val maxChannelSize2 = maxChannelSize1.safe() - channelsSize.safe()
val maxChannelSizeAvailableToIncrease = min(maxChannelSize1, maxChannelSize2)

val minLspBalance = getMinLspBalance(clientBalanceSat, minChannelSizeSat)
val maxLspBalance = (maxChannelSizeAvailableToIncrease - clientBalanceSat).coerceAtLeast(0u)
val maxLspBalance = maxChannelSizeAvailableToIncrease.safe() - clientBalanceSat.safe()
val defaultLspBalance = getDefaultLspBalance(clientBalanceSat, maxLspBalance)
val maxClientBalance = getMaxClientBalance(maxChannelSizeAvailableToIncrease)

Expand Down Expand Up @@ -436,11 +437,11 @@ class TransferViewModel @Inject constructor(
}

val lspBalance = if (clientBalanceSat < threshold1) { // 0-225€: LSP balance = 450€ - client balance
defaultLspBalanceSats - clientBalanceSat
defaultLspBalanceSats.safe() - clientBalanceSat.safe()
} else if (clientBalanceSat < threshold2) { // 225-495€: LSP balance = client balance
clientBalanceSat
} else if (clientBalanceSat < maxLspBalance) { // 495-950€: LSP balance = max - client balance
maxLspBalance - clientBalanceSat
maxLspBalance.safe() - clientBalanceSat.safe()
} else {
maxLspBalance
}
Expand All @@ -452,7 +453,7 @@ class TransferViewModel @Inject constructor(
// LSP balance must be at least 2.5% of the channel size for LDK to accept (reserve balance)
val ldkMinimum = (clientBalance.toDouble() * 0.025).toULong()
// Channel size must be at least minChannelSize
val lspMinimum = if (minChannelSize > clientBalance) minChannelSize - clientBalance else 0u
val lspMinimum = minChannelSize.safe() - clientBalance.safe()

return max(ldkMinimum, lspMinimum)
}
Expand All @@ -461,7 +462,7 @@ class TransferViewModel @Inject constructor(
// Remote balance must be at least 2.5% of the channel size for LDK to accept (reserve balance)
val minRemoteBalance = (maxChannelSize.toDouble() * 0.025).toULong()

return maxChannelSize - minRemoteBalance
return maxChannelSize.safe() - minRemoteBalance.safe()
}

/** Calculates the total value of channels connected to Blocktank nodes */
Expand Down
127 changes: 127 additions & 0 deletions app/src/test/java/to/bitkit/models/USatTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package to.bitkit.models

import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue

class USatTest {

// region Subtraction
@Test
fun `minus returns difference when a greater than b`() {
val result = USat(10uL) - USat(5uL)
assertEquals(5uL, result)
}

@Test
fun `minus returns zero when a equals b`() {
val result = USat(5uL) - USat(5uL)
assertEquals(0uL, result)
}

@Test
fun `minus returns zero when would underflow`() {
val result = USat(5uL) - USat(10uL)
assertEquals(0uL, result)
}

@Test
fun `minus handles max ULong values`() {
val result = USat(0uL) - USat(ULong.MAX_VALUE)
assertEquals(0uL, result)
}

@Test
fun `chained minus operations work correctly`() {
val intermediate = 100uL.safe() - 30uL.safe()
val result = intermediate.safe() - 20uL.safe()
assertEquals(50uL, result)
}

@Test
fun `chained minus returns zero when intermediate would underflow`() {
val intermediate = 10uL.safe() - 20uL.safe()
val result = intermediate.safe() - 5uL.safe()
assertEquals(0uL, result)
}
// endregion

// region Addition
@Test
fun `plus returns sum`() {
val result = USat(10uL) + USat(5uL)
assertEquals(15uL, result)
}

@Test
fun `plus saturates at max when would overflow`() {
val result = USat(ULong.MAX_VALUE) + USat(1uL)
assertEquals(ULong.MAX_VALUE, result)
}

@Test
fun `plus saturates when both values are large`() {
val result = USat(ULong.MAX_VALUE - 10uL) + USat(20uL)
assertEquals(ULong.MAX_VALUE, result)
}

@Test
fun `chained plus operations work correctly`() {
val intermediate = 10uL.safe() + 20uL.safe()
val result = intermediate.safe() + 30uL.safe()
assertEquals(60uL, result)
}
// endregion

// region Comparisons
@Test
fun `compareTo returns negative when less than`() {
assertTrue(USat(5uL) < USat(10uL))
}

@Test
fun `compareTo returns positive when greater than`() {
assertTrue(USat(10uL) > USat(5uL))
}

@Test
fun `compareTo returns zero when equal`() {
assertEquals(0, USat(10uL).compareTo(USat(10uL)))
}

@Test
fun `comparison operators work correctly`() {
assertTrue(USat(5uL) <= USat(10uL))
assertTrue(USat(10uL) >= USat(5uL))
assertTrue(USat(10uL) <= USat(10uL))
assertTrue(USat(10uL) >= USat(10uL))
assertFalse(USat(10uL) < USat(10uL))
assertFalse(USat(10uL) > USat(10uL))
}
// endregion

// region Realistic scenarios
@Test
fun `realistic bitcoin calculation`() {
val channelSize = 10_000_000uL // 0.1 BTC in sats
val balance = 1_000_000uL // 0.01 BTC in sats

val maxSend = channelSize.safe() - balance.safe()

assertEquals(9_000_000uL, maxSend)
}

@Test
fun `minus prevents the coerceAtLeast bug`() {
// The old pattern: (5u - 10u).coerceAtLeast(0u)
// Would incorrectly return ULong.MAX_VALUE - 4
val old = (5uL - 10uL).coerceAtLeast(0u)
assertTrue(old > 1000000u) // WRONG! Shows the bug

// The new pattern: 5u.safe() - 10u.safe()
val new = 5uL.safe() - 10uL.safe()
assertEquals(0uL, new) // CORRECT!
}
// endregion
}
Loading