Skip to content
Open
15 changes: 15 additions & 0 deletions app/src/main/java/to/bitkit/di/HttpModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ import io.ktor.client.plugins.defaultRequest
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.plugins.logging.LoggingConfig
import io.ktor.client.request.head
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import kotlinx.serialization.json.Json
import to.bitkit.utils.UrlValidator
import to.bitkit.utils.AppError
import to.bitkit.utils.Logger
import javax.inject.Singleton
import io.ktor.client.plugins.logging.Logger as KtorLogger
Expand Down Expand Up @@ -43,6 +47,17 @@ object HttpModule {
}
}

@Provides
@Singleton
fun provideUrlValidator(httpClient: HttpClient) = UrlValidator { url ->
runCatching {
val response = httpClient.head(url)
if (!response.status.isSuccess()) {
throw AppError("Server returned '${response.status}'")
}
}
}

@Suppress("MagicNumber")
private fun HttpTimeoutConfig.defaultTimeoutConfig() {
requestTimeoutMillis = 60_000
Expand Down
12 changes: 12 additions & 0 deletions app/src/main/java/to/bitkit/repositories/LightningRepo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ import to.bitkit.services.NodeEventHandler
import to.bitkit.utils.AppError
import to.bitkit.utils.Logger
import to.bitkit.utils.ServiceError
import to.bitkit.utils.UrlValidator
import java.io.File
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
Expand All @@ -105,6 +106,7 @@ class LightningRepo @Inject constructor(
private val preActivityMetadataRepo: PreActivityMetadataRepo,
private val connectivityRepo: ConnectivityRepo,
private val vssBackupClientLdk: VssBackupClientLdk,
private val urlValidator: UrlValidator,
) {
private val _lightningState = MutableStateFlow(LightningState())
val lightningState = _lightningState.asStateFlow()
Expand Down Expand Up @@ -619,6 +621,8 @@ class LightningRepo @Inject constructor(
suspend fun restartWithRgsServer(newRgsUrl: String): Result<Unit> = withContext(bgDispatcher) {
Logger.info("Changing ldk-node RGS server to: '$newRgsUrl'", context = TAG)

validateRgsUrl(newRgsUrl).onFailure { return@withContext Result.failure(it) }

waitForNodeToStop().onFailure { return@withContext Result.failure(it) }
stop().onFailure {
Logger.error("Failed to stop node during RGS server change", it, context = TAG)
Expand All @@ -640,6 +644,14 @@ class LightningRepo @Inject constructor(
}
}

private suspend fun validateRgsUrl(url: String): Result<Unit> = withContext(bgDispatcher) {
val initialTimestamp = 0
val testUrl = "${url.trimEnd('/')}/$initialTimestamp"
urlValidator.validate(testUrl).onFailure {
Logger.warn("RGS server unreachable at '$testUrl'", it, context = TAG)
}
}

suspend fun getBalanceForAddressType(addressType: AddressType): Result<ULong> = withContext(bgDispatcher) {
executeWhenNodeRunning("getBalanceForAddressType") {
runCatching {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package to.bitkit.ui.settings.advanced

import androidx.compose.runtime.Stable
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
Expand All @@ -15,7 +18,9 @@ import to.bitkit.data.SettingsStore
import to.bitkit.di.BgDispatcher
import to.bitkit.env.Env
import to.bitkit.repositories.LightningRepo
import java.net.URI
import javax.inject.Inject
import kotlin.time.Duration.Companion.seconds

@HiltViewModel
class RgsServerViewModel @Inject constructor(
Expand All @@ -24,9 +29,20 @@ class RgsServerViewModel @Inject constructor(
private val lightningRepo: LightningRepo,
) : ViewModel() {

companion object {
private val HOSTNAME_PATTERN = Regex(
"^([a-z\\d]([a-z\\d-]*[a-z\\d])*\\.)+[a-z]{2,}|(\\d{1,3}\\.){3}\\d{1,3}$",
RegexOption.IGNORE_CASE,
)
private val PATH_PATTERN = Regex("^(/[a-zA-Z\\d_.~%+-]*)*$")
private val VALIDATION_DEBOUNCE = 1.seconds
}

private val _uiState = MutableStateFlow(RgsServerUiState())
val uiState: StateFlow<RgsServerUiState> = _uiState.asStateFlow()

private var validationJob: Job? = null

init {
observeState()
}
Expand All @@ -47,17 +63,20 @@ class RgsServerViewModel @Inject constructor(
}

fun setRgsUrl(url: String) {
_uiState.update {
val newState = it.copy(rgsUrl = url.trim())
computeState(newState)
}
_uiState.update { it.copy(rgsUrl = url.trim()) }
debounceValidation()
}

fun resetToDefault() {
val defaultUrl = Env.ldkRgsServerUrl ?: ""
_uiState.update {
val newState = it.copy(rgsUrl = defaultUrl)
computeState(newState)
_uiState.update { it.copy(rgsUrl = Env.ldkRgsServerUrl ?: "") }
debounceValidation()
}

private fun debounceValidation() {
validationJob?.cancel()
validationJob = viewModelScope.launch(bgDispatcher) {
delay(VALIDATION_DEBOUNCE)
_uiState.update { computeState(it) }
}
}

Expand Down Expand Up @@ -110,23 +129,27 @@ class RgsServerViewModel @Inject constructor(
}

private fun isValidURL(data: String): Boolean {
val pattern = Regex(
"^(https?://)?" + // protocol
"((([a-z\\d]([a-z\\d-]*[a-z\\d])*)\\.)+[a-z]{2,}|" + // domain name
"((\\d{1,3}\\.){3}\\d{1,3}))" + // IP (v4) address
"(:\\d+)?(/[-a-z\\d%_.~+]*)*", // port and path
RegexOption.IGNORE_CASE
)

// Allow localhost in development mode
if (Env.isDebug && data.contains("localhost")) {
return true
val normalized = if (!data.startsWith("http://") && !data.startsWith("https://")) {
"https://$data"
} else {
data
}

return pattern.matches(data)
return runCatching {
val uri = URI(normalized)
val hostname = uri.host ?: return false

if (Env.isDebug && hostname == "localhost") return true

if (!HOSTNAME_PATTERN.matches(hostname)) return false

val path = uri.path.orEmpty()
path.isEmpty() || PATH_PATTERN.matches(path)
}.getOrDefault(false)
}
}

@Stable
data class RgsServerUiState(
val connectedRgsUrl: String? = null,
val rgsUrl: String = "",
Expand Down
5 changes: 5 additions & 0 deletions app/src/main/java/to/bitkit/utils/UrlValidator.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package to.bitkit.utils

fun interface UrlValidator {
suspend fun validate(url: String): Result<Unit>
}
47 changes: 0 additions & 47 deletions app/src/main/java/to/bitkit/viewmodels/WalletViewModel.kt
Original file line number Diff line number Diff line change
Expand Up @@ -328,53 +328,6 @@ class WalletViewModel @Inject constructor(
}
}

private suspend fun checkForOrphanedChannelMonitorRecovery() {
if (migrationService.isChannelRecoveryChecked()) return

Logger.info("Running one-time channel monitor recovery check", context = TAG)

val allMonitorsRetrieved = runCatching {
val allRetrieved = migrationService.fetchRNRemoteLdkData()
// don't overwrite channel manager, we only need the monitors for the sweep
val channelMigration = buildChannelMigrationIfAvailable()?.let {
ChannelDataMigration(channelManager = null, channelMonitors = it.channelMonitors)
}

if (channelMigration == null) {
Logger.info("No channel monitors found on RN backup", context = TAG)
return@runCatching allRetrieved
}

Logger.info(
"Found ${channelMigration.channelMonitors.size} monitors on RN backup, attempting recovery",
context = TAG,
)

lightningRepo.stop().onFailure {
Logger.error("Failed to stop node for channel recovery", it, context = TAG)
}
delay(CHANNEL_RECOVERY_RESTART_DELAY_MS)
lightningRepo.start(channelMigration = channelMigration, shouldRetry = false)
.onSuccess {
migrationService.consumePendingChannelMigration()
walletRepo.syncNodeAndWallet()
walletRepo.syncBalances()
Logger.info("Channel monitor recovery complete", context = TAG)
}
.onFailure {
Logger.error("Failed to restart node after channel recovery", it, context = TAG)
}

allRetrieved
}.getOrDefault(false)

if (allMonitorsRetrieved) {
migrationService.markChannelRecoveryChecked()
} else {
Logger.warn("Some monitors failed to download, will retry on next startup", context = TAG)
}
}

fun stop() {
if (!walletExists) return

Expand Down
75 changes: 75 additions & 0 deletions app/src/test/java/to/bitkit/repositories/LightningRepoTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import to.bitkit.services.LightningService
import to.bitkit.services.LnurlService
import to.bitkit.services.LspNotificationsService
import to.bitkit.test.BaseUnitTest
import to.bitkit.utils.UrlValidator
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
Expand All @@ -72,6 +73,7 @@ class LightningRepoTest : BaseUnitTest() {
private val lnurlService = mock<LnurlService>()
private val connectivityRepo = mock<ConnectivityRepo>()
private val vssBackupClientLdk = mock<VssBackupClientLdk>()
private val urlValidator = UrlValidator { Result.success(Unit) }

@Before
fun setUp() = runBlocking {
Expand All @@ -94,6 +96,7 @@ class LightningRepoTest : BaseUnitTest() {
preActivityMetadataRepo = preActivityMetadataRepo,
connectivityRepo = connectivityRepo,
vssBackupClientLdk = vssBackupClientLdk,
urlValidator = urlValidator,
)
}

Expand Down Expand Up @@ -498,6 +501,78 @@ class LightningRepoTest : BaseUnitTest() {
assertTrue(result.isFailure)
}

@Test
fun `restartWithRgsServer should setup with new rgs server`() = test {
startNodeForTesting()
val customRgsUrl = "https://rgs.example.com/snapshot"
whenever(lightningService.node).thenReturn(null)
whenever(lightningService.stop()).thenReturn(Unit)

val result = sut.restartWithRgsServer(customRgsUrl)

assertTrue(result.isSuccess)
val inOrder = inOrder(lightningService)
inOrder.verify(lightningService).stop()
inOrder.verify(lightningService).setup(any(), isNull(), eq(customRgsUrl), anyOrNull(), anyOrNull())
inOrder.verify(lightningService).start(anyOrNull(), any())
assertEquals(NodeLifecycleState.Running, sut.lightningState.value.nodeLifecycleState)
}

@Test
fun `restartWithRgsServer should handle stop failure`() = test {
startNodeForTesting()
whenever(lightningService.stop()).thenThrow(RuntimeException("Stop failed"))

val result = sut.restartWithRgsServer("https://rgs.example.com/snapshot")

assertTrue(result.isFailure)
}

@Test
fun `restartWithRgsServer should handle start failure and recover`() = test {
startNodeForTesting()
whenever(lightningService.node).thenReturn(null)
whenever(lightningService.stop()).thenReturn(Unit)
whenever(lightningService.setup(any(), isNull(), eq("https://bad.rgs/snapshot"), anyOrNull(), anyOrNull()))
.thenThrow(RuntimeException("Failed to start node"))

val result = sut.restartWithRgsServer("https://bad.rgs/snapshot")

assertTrue(result.isFailure)
}

@Test
fun `restartWithRgsServer should fail when url is unreachable`() = test {
val failingValidator = UrlValidator { Result.failure(Exception("DNS resolution failed")) }
val sutWithFailingValidator = LightningRepo(
bgDispatcher = testDispatcher,
lightningService = lightningService,
settingsStore = settingsStore,
coreService = coreService,
lspNotificationsService = lspNotificationsService,
firebaseMessaging = firebaseMessaging,
keychain = keychain,
lnurlService = lnurlService,
cacheStore = cacheStore,
preActivityMetadataRepo = preActivityMetadataRepo,
connectivityRepo = connectivityRepo,
vssBackupClientLdk = vssBackupClientLdk,
urlValidator = failingValidator,
)
sutWithFailingValidator.setInitNodeLifecycleState()
whenever(lightningService.node).thenReturn(mock())
whenever(lightningService.sync()).thenReturn(Unit)
val blocktank = mock<BlocktankService>()
whenever(coreService.blocktank).thenReturn(blocktank)
whenever(blocktank.info(any())).thenReturn(null)
sutWithFailingValidator.start()

val result = sutWithFailingValidator.restartWithRgsServer("https://rapidsync.lightningdevkit/snapshot")

assertTrue(result.isFailure)
assertEquals("DNS resolution failed", result.exceptionOrNull()?.message)
}

@Test
fun `getFeeRateForSpeed should use provided feeRates`() = test {
val mockFeeRates = mock<FeeRates>()
Expand Down
Loading
Loading