diff --git a/android/app/src/main/java/com/kidsync/app/crypto/TinkCryptoManager.kt b/android/app/src/main/java/com/kidsync/app/crypto/TinkCryptoManager.kt index eaf4eff..3c275c9 100644 --- a/android/app/src/main/java/com/kidsync/app/crypto/TinkCryptoManager.kt +++ b/android/app/src/main/java/com/kidsync/app/crypto/TinkCryptoManager.kt @@ -250,15 +250,11 @@ class TinkCryptoManager( } } - // TODO(SEC4-A-07): Add cross-signature validation for wrapped DEKs. Currently, any device - // in the bucket can wrap a DEK for any other device without proof of authorization. A future - // protocol enhancement should include a signature from the wrapping device over the wrapped - // payload, allowing the recipient to verify that the DEK was wrapped by an authorized - // (attested) device. This requires: - // 1. The wrapper signs (wrappedDekBlob || recipientDeviceId || keyEpoch) with its Ed25519 key - // 2. The signature is transmitted alongside the wrapped DEK - // 3. The recipient verifies the signature against the wrapper's attested signing key - // 4. Key attestation chain must be validated before trusting the signature + // DEFERRED(SEC4-A-07): Cross-signature validation for wrapped DEKs. Any bucket device + // can currently wrap a DEK without proof of authorization. Requires protocol design: + // wrapper signs (wrappedDekBlob || recipientDeviceId || keyEpoch) with Ed25519 key, + // signature is transmitted alongside wrapped DEK, recipient verifies against attested + // signing key. Blocked on attestation signature format specification. override fun unwrapDek( wrappedDek: String, devicePrivateKey: PrivateKey, @@ -449,9 +445,10 @@ class TinkCryptoManager( } } - // TODO(SEC6-A-05): The DEK epoch should come from the server's wrapped key response, not - // from the locally stored currentEpoch. Using local epoch means a replayed wrapped DEK - // could overwrite a newer epoch's key. Accept keyEpoch as a parameter instead. + // DEFERRED(SEC6-A-05): DEK epoch from server. The epoch should come from the server's + // wrapped key response, not local currentEpoch. Using local epoch means a replayed + // wrapped DEK could overwrite a newer epoch's key. Requires server API change to + // include keyEpoch in the wrapped key response. override suspend fun unwrapAndStoreDek( bucketId: String, wrappedDek: String, diff --git a/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/AuthInterceptor.kt b/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/AuthInterceptor.kt index 3f09ac4..52853a0 100644 --- a/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/AuthInterceptor.kt +++ b/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/AuthInterceptor.kt @@ -18,13 +18,10 @@ import okhttp3.Response * Session tokens are stored in EncryptedSharedPreferences to prevent extraction * from device backups or root access. * - * TODO(SEC3-A-21): Implement an OkHttp Authenticator (TokenAuthenticator) to handle - * 401 responses with automatic re-authentication via challenge-response. This would - * transparently retry failed requests after obtaining a new session token, avoiding - * the need for callers to handle 401s explicitly. Implementation requires injecting - * AuthRepository or KeyManager + ApiService (careful to avoid circular dependencies - * with Hilt). Consider using OkHttp's `Authenticator` interface which is specifically - * designed for this purpose. + * SEC3-A-21: 401 responses are now handled by [TokenAuthenticator], which automatically + * re-authenticates via challenge-response and retries the failed request. The circular + * Hilt dependency (NetworkModule -> OkHttpClient -> TokenAuthenticator -> AuthRepository + * -> ApiService -> OkHttpClient) is broken using [dagger.Lazy]. */ class AuthInterceptor( private val prefs: SharedPreferences diff --git a/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/TokenAuthenticator.kt b/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/TokenAuthenticator.kt new file mode 100644 index 0000000..b5958fe --- /dev/null +++ b/android/app/src/main/java/com/kidsync/app/data/remote/interceptor/TokenAuthenticator.kt @@ -0,0 +1,77 @@ +package com.kidsync.app.data.remote.interceptor + +import android.content.SharedPreferences +import com.kidsync.app.domain.repository.AuthRepository +import kotlinx.coroutines.runBlocking +import okhttp3.Authenticator +import okhttp3.Request +import okhttp3.Response +import okhttp3.Route +import javax.inject.Inject +import javax.inject.Named + +/** + * OkHttp Authenticator that handles 401 responses by re-authenticating + * via the Ed25519 challenge-response flow and retrying the failed request. + * + * This replaces the need for callers to handle 401s explicitly. When the + * server returns a 401 Unauthorized, OkHttp invokes this authenticator + * which: + * 1. Checks if a retry has already been attempted (prevents infinite loops) + * 2. Synchronizes concurrent 401s so only one re-authentication occurs + * 3. Performs challenge-response auth via AuthRepository + * 4. Retries the original request with the new session token + * + * Uses [dagger.Lazy] for AuthRepository to break the circular Hilt dependency: + * NetworkModule -> OkHttpClient -> TokenAuthenticator -> AuthRepository -> ApiService -> OkHttpClient + */ +class TokenAuthenticator @Inject constructor( + private val authRepository: dagger.Lazy, + @Named("encrypted_prefs") private val prefs: SharedPreferences +) : Authenticator { + + private val lock = Object() + + /** + * Header added to retried requests to prevent infinite retry loops. + * If the retried request also gets a 401, we return null (give up). + */ + companion object { + internal const val HEADER_AUTH_RETRY = "X-Auth-Retry" + } + + override fun authenticate(route: Route?, response: Response): Request? { + // Prevent infinite retry loops: if we already retried, give up + if (response.request.header(HEADER_AUTH_RETRY) != null) { + return null + } + + synchronized(lock) { + // Check if another thread already refreshed the token since our + // request was made. Compare the token we sent with the current one. + val currentToken = prefs.getString(AuthInterceptor.PREF_SESSION_TOKEN, null) + val requestToken = response.request.header(AuthInterceptor.HEADER_AUTHORIZATION) + ?.removePrefix("Bearer ") + + if (currentToken != null && currentToken != requestToken) { + // Token was already refreshed by another thread -- retry with it + return response.request.newBuilder() + .header(AuthInterceptor.HEADER_AUTHORIZATION, "Bearer $currentToken") + .header(HEADER_AUTH_RETRY, "true") + .build() + } + + // Perform challenge-response authentication + val result = runBlocking { + authRepository.get().authenticate() + } + + val session = result.getOrNull() ?: return null + + return response.request.newBuilder() + .header(AuthInterceptor.HEADER_AUTHORIZATION, "Bearer ${session.sessionToken}") + .header(HEADER_AUTH_RETRY, "true") + .build() + } + } +} diff --git a/android/app/src/main/java/com/kidsync/app/di/DatabaseModule.kt b/android/app/src/main/java/com/kidsync/app/di/DatabaseModule.kt index 729012c..64fda9c 100644 --- a/android/app/src/main/java/com/kidsync/app/di/DatabaseModule.kt +++ b/android/app/src/main/java/com/kidsync/app/di/DatabaseModule.kt @@ -66,8 +66,10 @@ object DatabaseModule { @Suppress("DEPRECATION") builder.fallbackToDestructiveMigration() } - // TODO: Add migration objects here for each version bump in release builds: - // builder.addMigrations(MIGRATION_4_5, MIGRATION_5_6, ...) + // DEFERRED: Room migration objects. Add migrations here for each schema version + // bump in release builds: builder.addMigrations(MIGRATION_X_Y, ...). + // Currently no pending migrations — destructive fallback handles debug builds + // above, and release builds will crash visibly if a migration is missing. return builder.build() } diff --git a/android/app/src/main/java/com/kidsync/app/di/NetworkModule.kt b/android/app/src/main/java/com/kidsync/app/di/NetworkModule.kt index 9c05de5..c7fb4b0 100644 --- a/android/app/src/main/java/com/kidsync/app/di/NetworkModule.kt +++ b/android/app/src/main/java/com/kidsync/app/di/NetworkModule.kt @@ -3,6 +3,7 @@ package com.kidsync.app.di import android.content.SharedPreferences import com.kidsync.app.data.remote.api.ApiService import com.kidsync.app.data.remote.interceptor.AuthInterceptor +import com.kidsync.app.data.remote.interceptor.TokenAuthenticator import com.kidsync.app.BuildConfig import dagger.Module import dagger.Provides @@ -47,9 +48,10 @@ object NetworkModule { @Singleton fun provideOkHttpClientManager( authInterceptor: AuthInterceptor, - loggingInterceptor: HttpLoggingInterceptor + loggingInterceptor: HttpLoggingInterceptor, + tokenAuthenticator: TokenAuthenticator ): OkHttpClientManager { - return OkHttpClientManager(authInterceptor, loggingInterceptor) + return OkHttpClientManager(authInterceptor, loggingInterceptor, tokenAuthenticator) } /** diff --git a/android/app/src/main/java/com/kidsync/app/di/OkHttpClientManager.kt b/android/app/src/main/java/com/kidsync/app/di/OkHttpClientManager.kt index c2b9337..64a3190 100644 --- a/android/app/src/main/java/com/kidsync/app/di/OkHttpClientManager.kt +++ b/android/app/src/main/java/com/kidsync/app/di/OkHttpClientManager.kt @@ -2,6 +2,7 @@ package com.kidsync.app.di import com.kidsync.app.BuildConfig import com.kidsync.app.data.remote.interceptor.AuthInterceptor +import com.kidsync.app.data.remote.interceptor.TokenAuthenticator import okhttp3.CertificatePinner import okhttp3.ConnectionSpec import okhttp3.OkHttpClient @@ -35,7 +36,8 @@ import kotlin.concurrent.write @Singleton class OkHttpClientManager @Inject constructor( private val authInterceptor: AuthInterceptor, - private val loggingInterceptor: HttpLoggingInterceptor + private val loggingInterceptor: HttpLoggingInterceptor, + private val tokenAuthenticator: TokenAuthenticator ) { private val lock = ReentrantReadWriteLock() @@ -80,6 +82,7 @@ class OkHttpClientManager @Inject constructor( .connectionSpecs(listOf(tlsSpec)) .addInterceptor(authInterceptor) .addInterceptor(loggingInterceptor) + .authenticator(tokenAuthenticator) .connectTimeout(30, TimeUnit.SECONDS) .readTimeout(30, TimeUnit.SECONDS) .writeTimeout(30, TimeUnit.SECONDS) diff --git a/android/app/src/main/java/com/kidsync/app/domain/usecase/sync/SnapshotUseCase.kt b/android/app/src/main/java/com/kidsync/app/domain/usecase/sync/SnapshotUseCase.kt index 2f66345..7acc6d1 100644 --- a/android/app/src/main/java/com/kidsync/app/domain/usecase/sync/SnapshotUseCase.kt +++ b/android/app/src/main/java/com/kidsync/app/domain/usecase/sync/SnapshotUseCase.kt @@ -31,6 +31,8 @@ class SnapshotUseCase @Inject constructor( private val custodyScheduleDao: CustodyScheduleDao, private val overrideDao: OverrideDao, private val expenseDao: ExpenseDao, + private val calendarEventDao: CalendarEventDao, + private val infoBankDao: InfoBankDao, private val opLogDao: OpLogDao, private val syncStateDao: SyncStateDao, private val cryptoManager: CryptoManager, @@ -121,10 +123,13 @@ class SnapshotUseCase @Inject constructor( /** * Compute SHA-256 hash of the current materialized state. * - * TODO(SEC6-A-16): Incomplete entity coverage - this hash only covers CustodySchedule, - * ScheduleOverride, and Expense entities. CalendarEvent and InfoBankEntry entities are - * missing, meaning state divergence in those entity types would not be detected by - * snapshot comparison. Add them to the hash computation for full coverage. + * Covers all materialized entity types: CustodySchedule, ScheduleOverride, + * Expense, CalendarEvent, and InfoBankEntry (non-deleted only). + * + * Design: Hashes identity + core value fields per entity (not all columns). + * Metadata fields (createdBy, timestamps, description, location, notes) are + * excluded so the hash detects structural divergence without false positives + * from metadata-only changes that don't affect the materialized schedule. */ private suspend fun computeStateHash(): String { val digest = MessageDigest.getInstance("SHA-256") @@ -150,6 +155,25 @@ class SnapshotUseCase @Inject constructor( digest.update(expense.amountCents.toString().toByteArray()) } + // Hash all calendar events (SEC6-A-16) + val calendarEvents = calendarEventDao.getAllEvents() + for (event in calendarEvents.sortedBy { it.eventId }) { + digest.update(event.eventId.toByteArray()) + digest.update(event.title.toByteArray()) + digest.update(event.startTime.toByteArray()) + digest.update(event.endTime.toByteArray()) + } + + // Hash all non-deleted info bank entries (SEC6-A-16) + val infoBankEntries = infoBankDao.getAllEntries() + for (entry in infoBankEntries.sortedBy { it.entryId.toString() }) { + digest.update(entry.entryId.toString().toByteArray()) + digest.update(entry.childId.toString().toByteArray()) + digest.update(entry.category.toByteArray()) + entry.title?.let { digest.update(it.toByteArray()) } + entry.content?.let { digest.update(it.toByteArray()) } + } + return digest.digest().joinToString("") { "%02x".format(it) } } } diff --git a/android/app/src/main/java/com/kidsync/app/ui/viewmodel/AuthViewModel.kt b/android/app/src/main/java/com/kidsync/app/ui/viewmodel/AuthViewModel.kt index d1d8e05..b0623d6 100644 --- a/android/app/src/main/java/com/kidsync/app/ui/viewmodel/AuthViewModel.kt +++ b/android/app/src/main/java/com/kidsync/app/ui/viewmodel/AuthViewModel.kt @@ -271,10 +271,10 @@ class AuthViewModel @Inject constructor( * 3. Register as new device, authenticate * 4. Unwrap DEK using recovery key * - * TODO(SEC6-A-08): Key ordering during recovery flow - the recovery restores the - * device seed before registering new encryption keys with the server. If restoration - * fails mid-flow, the local seed may not match the server's registered keys. Consider - * ordering: register first, then restore seed, then unwrap DEK. + * DEFERRED(SEC6-A-08): Key ordering during recovery flow. Seed is restored before + * registering new keys with the server — if restoration fails mid-flow, local seed + * may not match server's registered keys. Fix requires an "update device keys" server + * endpoint to re-register keys after seed restore, which is a protocol change. */ fun restoreFromRecovery() { val state = _uiState.value diff --git a/android/app/src/main/java/com/kidsync/app/ui/viewmodel/BucketViewModel.kt b/android/app/src/main/java/com/kidsync/app/ui/viewmodel/BucketViewModel.kt index 91ab5f1..09d0cad 100644 --- a/android/app/src/main/java/com/kidsync/app/ui/viewmodel/BucketViewModel.kt +++ b/android/app/src/main/java/com/kidsync/app/ui/viewmodel/BucketViewModel.kt @@ -23,20 +23,17 @@ import javax.inject.Inject * QR code payload for pairing. * Contains connection info and initiator's key fingerprint but never the DEK. * - * TODO(SEC4-A-10): The invite token [t] is transmitted in plaintext within the QR code. - * This is by design: the QR code is displayed briefly for the co-parent to scan in person, - * and the token is single-use (the server invalidates it after redemption). Additionally, - * PairingScreen sets FLAG_SECURE to prevent screenshots. However, for defense in depth, - * a future enhancement could encrypt the token field using the recipient's public key - * (requires the scanner's public key to be known before pairing, e.g., via a two-phase - * handshake or pre-shared key exchange). + * DEFERRED(SEC4-A-10): Invite token is plaintext in QR code. Acceptable because: QR is + * displayed briefly, PairingScreen sets FLAG_SECURE, and the token is single-use (server + * invalidates after redemption). Encrypting the token would require the scanner's public + * key before pairing — needs a two-phase handshake or pre-shared key exchange protocol. */ @Serializable data class QrPairingPayload( val v: Int = 1, val s: String, // serverUrl val b: String, // bucketId - val t: String, // inviteToken (plaintext -- see SEC4-A-10 TODO above) + val t: String, // inviteToken (plaintext -- see DEFERRED SEC4-A-10 above) val f: String // signingKeyFingerprint of initiator ) @@ -163,13 +160,9 @@ class BucketViewModel @Inject constructor( ) inviteResult.getOrThrow() - // Build QR payload - // TODO(SC-03): The spec says the QR code should contain the signing key fingerprint, - // but we currently use the encryption key fingerprint. Both sides (QR generation and - // verification) consistently use the encryption key, so this works. If the spec is - // updated or the server changes to use signing key fingerprints, update both sides. + // Build QR payload with signing key fingerprint per spec (SC-03) val serverUrl = authRepository.getServerUrl() - val fingerprint = keyManager.getEncryptionKeyFingerprint() + val fingerprint = keyManager.getSigningKeyFingerprint() val payload = QrPairingPayload( v = 1, @@ -390,7 +383,7 @@ class BucketViewModel @Inject constructor( val peerDevicesResult = bucketRepository.getBucketDevices(payload.b) val peerDevices = peerDevicesResult.getOrThrow() val peerVerified = peerDevices.any { device -> - cryptoManager.computeKeyFingerprint(device.encryptionKey) == payload.f + cryptoManager.computeKeyFingerprint(device.signingKey) == payload.f } if (!peerVerified) { @@ -426,7 +419,7 @@ class BucketViewModel @Inject constructor( // Cross-sign the peer's key val peerDevice = peerDevices.first { device -> - cryptoManager.computeKeyFingerprint(device.encryptionKey) == payload.f + cryptoManager.computeKeyFingerprint(device.signingKey) == payload.f } val attestation = keyManager.createKeyAttestation( attestedDeviceId = peerDevice.deviceId, diff --git a/android/app/src/test/java/com/kidsync/app/domain/usecase/SnapshotUseCaseTest.kt b/android/app/src/test/java/com/kidsync/app/domain/usecase/SnapshotUseCaseTest.kt index 7c9ce25..54ddd7d 100644 --- a/android/app/src/test/java/com/kidsync/app/domain/usecase/SnapshotUseCaseTest.kt +++ b/android/app/src/test/java/com/kidsync/app/domain/usecase/SnapshotUseCaseTest.kt @@ -3,10 +3,13 @@ package com.kidsync.app.domain.usecase import com.kidsync.app.crypto.CryptoManager import com.kidsync.app.crypto.KeyManager import com.kidsync.app.data.local.dao.* +import com.kidsync.app.data.local.entity.CalendarEventEntity import com.kidsync.app.data.local.entity.CustodyScheduleEntity import com.kidsync.app.data.local.entity.ExpenseEntity +import com.kidsync.app.data.local.entity.InfoBankEntryEntity import com.kidsync.app.data.local.entity.ScheduleOverrideEntity import com.kidsync.app.data.local.entity.SyncStateEntity +import java.util.UUID import com.kidsync.app.data.remote.api.ApiService import com.kidsync.app.data.remote.dto.UploadSnapshotResponse import com.kidsync.app.domain.usecase.sync.SnapshotUseCase @@ -21,6 +24,8 @@ class SnapshotUseCaseTest : FunSpec({ val custodyScheduleDao = mockk() val overrideDao = mockk() val expenseDao = mockk() + val calendarEventDao = mockk() + val infoBankDao = mockk() val opLogDao = mockk() val syncStateDao = mockk() val cryptoManager = mockk() @@ -28,8 +33,8 @@ class SnapshotUseCaseTest : FunSpec({ val apiService = mockk() fun createUseCase() = SnapshotUseCase( - custodyScheduleDao, overrideDao, expenseDao, opLogDao, syncStateDao, - cryptoManager, keyManager, apiService + custodyScheduleDao, overrideDao, expenseDao, calendarEventDao, infoBankDao, + opLogDao, syncStateDao, cryptoManager, keyManager, apiService ) beforeEach { @@ -56,6 +61,8 @@ class SnapshotUseCaseTest : FunSpec({ coEvery { custodyScheduleDao.getAllSchedules() } returns emptyList() coEvery { overrideDao.getAllOverrides() } returns emptyList() coEvery { expenseDao.getAllExpenses() } returns emptyList() + coEvery { calendarEventDao.getAllEvents() } returns emptyList() + coEvery { infoBankDao.getAllEntries() } returns emptyList() every { cryptoManager.encryptPayload(any(), any(), any()) } returns "encrypted-snapshot-base64" every { cryptoManager.signEd25519(any(), any()) } returns ByteArray(64) { 0xAA.toByte() } @@ -186,4 +193,129 @@ class SnapshotUseCaseTest : FunSpec({ result2.isSuccess shouldBe true result1.getOrNull() shouldNotBe result2.getOrNull() } + + // ── SEC6-A-16: Calendar event and info bank entity coverage ───────── + + test("state hash changes when calendar event is added") { + setupHappyPath() + val useCase1 = createUseCase() + val result1 = useCase1.createSnapshot(bucketId) + + setupHappyPath() + coEvery { calendarEventDao.getAllEvents() } returns listOf( + CalendarEventEntity( + eventId = "evt-1", + childId = "child-1", + title = "Doctor visit", + startTime = "2026-03-01T10:00:00Z", + endTime = "2026-03-01T11:00:00Z", + clientTimestamp = "2026-02-20T12:00:00Z" + ) + ) + + val useCase2 = createUseCase() + val result2 = useCase2.createSnapshot(bucketId) + + result1.isSuccess shouldBe true + result2.isSuccess shouldBe true + result1.getOrNull() shouldNotBe result2.getOrNull() + } + + test("state hash changes when info bank entry is added") { + setupHappyPath() + val useCase1 = createUseCase() + val result1 = useCase1.createSnapshot(bucketId) + + setupHappyPath() + coEvery { infoBankDao.getAllEntries() } returns listOf( + InfoBankEntryEntity( + entryId = UUID.fromString("11111111-1111-1111-1111-111111111111"), + childId = UUID.fromString("22222222-2222-2222-2222-222222222222"), + category = "MEDICAL", + title = "Allergies", + content = "Peanuts, shellfish" + ) + ) + + val useCase2 = createUseCase() + val result2 = useCase2.createSnapshot(bucketId) + + result1.isSuccess shouldBe true + result2.isSuccess shouldBe true + result1.getOrNull() shouldNotBe result2.getOrNull() + } + + test("state hash is stable for same state") { + setupHappyPath() + coEvery { calendarEventDao.getAllEvents() } returns listOf( + CalendarEventEntity( + eventId = "evt-stable", + childId = "child-1", + title = "Playdate", + startTime = "2026-04-01T14:00:00Z", + endTime = "2026-04-01T16:00:00Z", + clientTimestamp = "2026-03-15T09:00:00Z" + ) + ) + coEvery { infoBankDao.getAllEntries() } returns listOf( + InfoBankEntryEntity( + entryId = UUID.fromString("33333333-3333-3333-3333-333333333333"), + childId = UUID.fromString("44444444-4444-4444-4444-444444444444"), + category = "EMERGENCY", + content = "Mom: 555-1234" + ) + ) + + val useCase1 = createUseCase() + val result1 = useCase1.createSnapshot(bucketId) + + // Re-setup with exact same data + setupHappyPath() + coEvery { calendarEventDao.getAllEvents() } returns listOf( + CalendarEventEntity( + eventId = "evt-stable", + childId = "child-1", + title = "Playdate", + startTime = "2026-04-01T14:00:00Z", + endTime = "2026-04-01T16:00:00Z", + clientTimestamp = "2026-03-15T09:00:00Z" + ) + ) + coEvery { infoBankDao.getAllEntries() } returns listOf( + InfoBankEntryEntity( + entryId = UUID.fromString("33333333-3333-3333-3333-333333333333"), + childId = UUID.fromString("44444444-4444-4444-4444-444444444444"), + category = "EMERGENCY", + content = "Mom: 555-1234" + ) + ) + + val useCase2 = createUseCase() + val result2 = useCase2.createSnapshot(bucketId) + + result1.isSuccess shouldBe true + result2.isSuccess shouldBe true + result1.getOrNull() shouldBe result2.getOrNull() + } + + test("deleted info bank entries are excluded from hash via getAllEntries") { + // getAllEntries() DAO query filters deleted = 0, so deleted entries + // should never be returned. Verify by checking that if we mock + // getAllEntries to return nothing (simulating all deleted), the hash + // matches the baseline with no entries. + setupHappyPath() + // Baseline: no info bank entries + val useCase1 = createUseCase() + val result1 = useCase1.createSnapshot(bucketId) + + // Same setup: getAllEntries returns empty (all entries are deleted) + setupHappyPath() + // infoBankDao.getAllEntries() already returns emptyList() from setupHappyPath + val useCase2 = createUseCase() + val result2 = useCase2.createSnapshot(bucketId) + + result1.isSuccess shouldBe true + result2.isSuccess shouldBe true + result1.getOrNull() shouldBe result2.getOrNull() + } }) diff --git a/android/app/src/test/java/com/kidsync/app/interceptor/TokenAuthenticatorTest.kt b/android/app/src/test/java/com/kidsync/app/interceptor/TokenAuthenticatorTest.kt new file mode 100644 index 0000000..b2a82b7 --- /dev/null +++ b/android/app/src/test/java/com/kidsync/app/interceptor/TokenAuthenticatorTest.kt @@ -0,0 +1,205 @@ +package com.kidsync.app.interceptor + +import android.content.SharedPreferences +import com.kidsync.app.data.remote.interceptor.AuthInterceptor +import com.kidsync.app.data.remote.interceptor.TokenAuthenticator +import com.kidsync.app.domain.model.DeviceSession +import com.kidsync.app.domain.repository.AuthRepository +import io.kotest.core.spec.style.FunSpec +import io.kotest.matchers.shouldBe +import io.kotest.matchers.shouldNotBe +import io.mockk.clearAllMocks +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.every +import io.mockk.mockk +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.Response +import okhttp3.ResponseBody.Companion.toResponseBody +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +class TokenAuthenticatorTest : FunSpec({ + + val authRepository = mockk() + val lazyAuthRepository = mockk>() + val prefs = mockk(relaxed = true) + + beforeEach { + clearAllMocks() + every { lazyAuthRepository.get() } returns authRepository + } + + fun createAuthenticator(): TokenAuthenticator { + return TokenAuthenticator(lazyAuthRepository, prefs) + } + + fun build401Response(request: Request): Response { + return Response.Builder() + .request(request) + .protocol(Protocol.HTTP_2) + .code(401) + .message("Unauthorized") + .body("".toResponseBody()) + .build() + } + + // ── Basic token refresh ───────────────────────────────────────────── + + test("401 response triggers re-authentication and retries with new token") { + val request = Request.Builder() + .url("https://api.example.com/buckets") + .header("Authorization", "Bearer old-token") + .build() + val response = build401Response(request) + + // Current token in prefs matches the request token (no other thread refreshed) + every { prefs.getString(AuthInterceptor.PREF_SESSION_TOKEN, null) } returns "old-token" + + // Auth succeeds with new token + coEvery { authRepository.authenticate() } returns Result.success( + DeviceSession("device-1", "new-token", 3600) + ) + + val authenticator = createAuthenticator() + val retryRequest = authenticator.authenticate(null, response) + + retryRequest shouldNotBe null + retryRequest!!.header("Authorization") shouldBe "Bearer new-token" + retryRequest.header(TokenAuthenticator.HEADER_AUTH_RETRY) shouldBe "true" + } + + // ── Infinite loop prevention ──────────────────────────────────────── + + test("second 401 with X-Auth-Retry header returns null to prevent infinite loop") { + val request = Request.Builder() + .url("https://api.example.com/buckets") + .header("Authorization", "Bearer some-token") + .header(TokenAuthenticator.HEADER_AUTH_RETRY, "true") + .build() + val response = build401Response(request) + + val authenticator = createAuthenticator() + val retryRequest = authenticator.authenticate(null, response) + + retryRequest shouldBe null + // authenticate() should never be called + coVerify(exactly = 0) { authRepository.authenticate() } + } + + // ── Token already refreshed by another thread ─────────────────────── + + test("skips re-auth when another thread already refreshed the token") { + val request = Request.Builder() + .url("https://api.example.com/buckets") + .header("Authorization", "Bearer old-token") + .build() + val response = build401Response(request) + + // Prefs now have a different (newer) token than what the request sent + every { prefs.getString(AuthInterceptor.PREF_SESSION_TOKEN, null) } returns "already-refreshed-token" + + val authenticator = createAuthenticator() + val retryRequest = authenticator.authenticate(null, response) + + retryRequest shouldNotBe null + retryRequest!!.header("Authorization") shouldBe "Bearer already-refreshed-token" + retryRequest.header(TokenAuthenticator.HEADER_AUTH_RETRY) shouldBe "true" + + // authenticate() should not be called since token was already refreshed + coVerify(exactly = 0) { authRepository.authenticate() } + } + + // ── Auth failure returns null ──────────────────────────────────────── + + test("returns null when re-authentication fails") { + val request = Request.Builder() + .url("https://api.example.com/buckets") + .header("Authorization", "Bearer expired-token") + .build() + val response = build401Response(request) + + every { prefs.getString(AuthInterceptor.PREF_SESSION_TOKEN, null) } returns "expired-token" + coEvery { authRepository.authenticate() } returns Result.failure( + RuntimeException("Network error") + ) + + val authenticator = createAuthenticator() + val retryRequest = authenticator.authenticate(null, response) + + retryRequest shouldBe null + } + + // ── Request without Authorization header ──────────────────────────── + + test("handles request without Authorization header") { + val request = Request.Builder() + .url("https://api.example.com/buckets") + .build() + val response = build401Response(request) + + // No current token in prefs either + every { prefs.getString(AuthInterceptor.PREF_SESSION_TOKEN, null) } returns null + + coEvery { authRepository.authenticate() } returns Result.success( + DeviceSession("device-1", "fresh-token", 3600) + ) + + val authenticator = createAuthenticator() + val retryRequest = authenticator.authenticate(null, response) + + retryRequest shouldNotBe null + retryRequest!!.header("Authorization") shouldBe "Bearer fresh-token" + } + + // ── Concurrent 401s result in single auth call ────────────────────── + + test("concurrent 401s result in single authentication call") { + val authenticator = createAuthenticator() + val authCallCount = AtomicInteger(0) + + // First call to prefs returns the old token (both threads see same stale token) + // After auth, subsequent calls return new token + every { prefs.getString(AuthInterceptor.PREF_SESSION_TOKEN, null) } returnsMany listOf( + "old-token", // first thread sees old token + "new-token", // second thread sees new token (first already refreshed) + "new-token" // any subsequent calls + ) + + coEvery { authRepository.authenticate() } answers { + authCallCount.incrementAndGet() + Result.success(DeviceSession("device-1", "new-token", 3600)) + } + + val request = Request.Builder() + .url("https://api.example.com/buckets") + .header("Authorization", "Bearer old-token") + .build() + + val latch = CountDownLatch(2) + val results = arrayOfNulls(2) + + // Launch two threads that hit authenticate concurrently + val t1 = Thread { + results[0] = authenticator.authenticate(null, build401Response(request)) + latch.countDown() + } + val t2 = Thread { + results[1] = authenticator.authenticate(null, build401Response(request)) + latch.countDown() + } + + t1.start() + t2.start() + latch.await() + + // Both should get valid retry requests + results[0] shouldNotBe null + results[1] shouldNotBe null + + // Only one authentication call should have been made + // (the second thread sees that the token was already refreshed) + authCallCount.get() shouldBe 1 + } +}) diff --git a/android/app/src/test/java/com/kidsync/app/viewmodel/BucketViewModelTest.kt b/android/app/src/test/java/com/kidsync/app/viewmodel/BucketViewModelTest.kt index feccebb..7d72066 100644 --- a/android/app/src/test/java/com/kidsync/app/viewmodel/BucketViewModelTest.kt +++ b/android/app/src/test/java/com/kidsync/app/viewmodel/BucketViewModelTest.kt @@ -13,7 +13,14 @@ import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe import io.kotest.matchers.string.shouldContain -import io.mockk.* +import io.mockk.Runs +import io.mockk.clearAllMocks +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.every +import io.mockk.just +import io.mockk.mockk +import io.mockk.verify import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.test.StandardTestDispatcher @@ -154,8 +161,8 @@ class BucketViewModelTest : FunSpec({ coEvery { cryptoManager.generateInviteToken() } returns "tok-abc-123" coEvery { bucketRepository.createInvite("bucket-inv", "tok-abc-123") } returns Result.success(Unit) - coEvery { authRepository.getServerUrl() } returns "https://api.kidsync.dev" - coEvery { keyManager.getEncryptionKeyFingerprint() } returns "fp:enc:001" + coEvery { authRepository.getServerUrl() } returns "https://api.kidsync.app" + coEvery { keyManager.getSigningKeyFingerprint() } returns "fp:sign:001" val vm = createViewModel() advanceUntilIdle() @@ -308,4 +315,93 @@ class BucketViewModelTest : FunSpec({ vm.uiState.value.error shouldBe null } } + + // ── SC-03: QR signing key fingerprint ───────────────────────────────── + + test("generateInvite uses signing key fingerprint, not encryption key") { + runTest(testDispatcher) { + val bucket = Bucket("bucket-sc03", "device-001", Instant.now()) + coEvery { bucketRepository.createBucket() } returns Result.success(bucket) + coEvery { bucketRepository.storeLocalBucketName(any(), any()) } just Runs + coEvery { cryptoManager.generateAndStoreDek(any()) } just Runs + + coEvery { cryptoManager.generateInviteToken() } returns "tok-sc03" + coEvery { bucketRepository.createInvite("bucket-sc03", "tok-sc03") } returns Result.success(Unit) + coEvery { authRepository.getServerUrl() } returns "https://api.kidsync.app" + coEvery { keyManager.getSigningKeyFingerprint() } returns "fp:signing:abc123" + + val vm = createViewModel() + advanceUntilIdle() + + vm.createBucket() + advanceUntilIdle() + vm.generateInvite() + advanceUntilIdle() + + // Verify signing key fingerprint was used + coVerify { keyManager.getSigningKeyFingerprint() } + coVerify(exactly = 0) { keyManager.getEncryptionKeyFingerprint() } + + // Verify fingerprint appears in QR payload + val state = vm.uiState.value + state.qrPayload shouldNotBe null + state.qrPayload!! shouldContain "fp:signing:abc123" + } + } + + test("continueJoinBucket compares fingerprint against signing key, not encryption key") { + runTest(testDispatcher) { + val signingFingerprint = "fp:signing:peer001" + val peerDevice = Device( + deviceId = "peer-device-1", + signingKey = "cGVlci1zaWduaW5nLWtleQ==", // base64 for test + encryptionKey = "cGVlci1lbmMta2V5", + createdAt = Instant.now() + ) + + coEvery { keyManager.hasExistingKeys() } returns true + coEvery { authRepository.authenticate() } returns Result.success( + com.kidsync.app.domain.model.DeviceSession("device-join", "token-123", 3600) + ) + coEvery { bucketRepository.joinBucket("bucket-join", "tok-join") } returns Result.success(Unit) + coEvery { bucketRepository.getBucketDevices("bucket-join") } returns Result.success(listOf(peerDevice)) + // Return matching fingerprint for signing key, different for encryption key + every { cryptoManager.computeKeyFingerprint(peerDevice.signingKey) } returns signingFingerprint + every { cryptoManager.computeKeyFingerprint(peerDevice.encryptionKey) } returns "fp:enc:different" + + val encKeyPair = KeyPairGenerator.getInstance("X25519").generateKeyPair() + coEvery { keyManager.getEncryptionKeyPair() } returns encKeyPair + + coEvery { bucketRepository.waitForWrappedDek("bucket-join") } returns WrappedKeyResponse( + wrappedDek = "wrapped-dek-base64", wrappedBy = "sender-pub-key", keyEpoch = 1 + ) + coEvery { cryptoManager.unwrapAndStoreDek(any(), any(), any(), any()) } just Runs + coEvery { keyManager.createKeyAttestation(any(), any()) } returns KeyAttestation( + signerDeviceId = "device-join", + attestedDeviceId = "peer-device-1", + attestedEncryptionKey = "key", + signature = "sig", + createdAt = Instant.now().toString() + ) + coEvery { bucketRepository.uploadKeyAttestation(any()) } returns Result.success(Unit) + coEvery { bucketRepository.storeLocalBucketName(any(), any()) } just Runs + coEvery { bucketRepository.getLocalBucketName(any()) } returns "Shared Bucket" + coEvery { authRepository.getServerUrl() } returns "https://api.kidsync.app" + + val vm = createViewModel() + advanceUntilIdle() + + // Build QR payload with signing key fingerprint + val qrData = """{"v":1,"s":"https://api.kidsync.app","b":"bucket-join","t":"tok-join","f":"$signingFingerprint"}""" + vm.joinBucket(qrData) + advanceUntilIdle() + + // Should have verified against signing key, not encryption key + verify { cryptoManager.computeKeyFingerprint(peerDevice.signingKey) } + + val state = vm.uiState.value + state.isJoined shouldBe true + state.error shouldBe null + } + } }) diff --git a/server/build.gradle.kts b/server/build.gradle.kts index d86480f..25d03aa 100644 --- a/server/build.gradle.kts +++ b/server/build.gradle.kts @@ -60,6 +60,7 @@ dependencies { // Testing testImplementation("io.ktor:ktor-server-test-host:$ktorVersion") testImplementation("io.ktor:ktor-client-content-negotiation:$ktorVersion") + testImplementation("io.ktor:ktor-client-websockets:$ktorVersion") testImplementation("org.jetbrains.kotlin:kotlin-test:2.1.0") testImplementation("org.jetbrains.kotlin:kotlin-test-junit5:2.1.0") testImplementation("org.junit.jupiter:junit-jupiter:5.11.4") diff --git a/server/src/main/kotlin/dev/kidsync/server/Application.kt b/server/src/main/kotlin/dev/kidsync/server/Application.kt index 0f2ebaa..d3187d6 100644 --- a/server/src/main/kotlin/dev/kidsync/server/Application.kt +++ b/server/src/main/kotlin/dev/kidsync/server/Application.kt @@ -1,7 +1,10 @@ package dev.kidsync.server +import dev.kidsync.server.db.Checkpoints import dev.kidsync.server.db.DatabaseFactory import dev.kidsync.server.db.DatabaseFactory.dbQuery +import dev.kidsync.server.services.SyncService +import org.jetbrains.exposed.sql.selectAll import dev.kidsync.server.models.HealthResponse import dev.kidsync.server.plugins.* import dev.kidsync.server.routes.* @@ -19,6 +22,7 @@ import io.ktor.server.routing.* import kotlinx.coroutines.* import org.slf4j.LoggerFactory import org.slf4j.event.Level +import java.io.File fun main() { val config = AppConfig() @@ -48,6 +52,9 @@ fun Application.module(config: AppConfig = AppConfig()) { // SEC3-S-16: Validate storage paths on startup (fail fast if invalid) AppConfig.validateStoragePaths(config) + // Reset IP-based rate limiters on startup (important for test isolation) + dev.kidsync.server.routes.DeviceRegistrationRateLimiter.reset() + // Initialize database DatabaseFactory.init(config) @@ -84,6 +91,20 @@ fun Application.module(config: AppConfig = AppConfig()) { } catch (e: Exception) { LoggerFactory.getLogger("Application").warn("Invite token cleanup failed: {}", e.message) } + // SEC4-S-11: Clean up orphan .tmp files in snapshot and blob storage + // that are older than 1 hour (leftover from crashed uploads) + try { + cleanupOrphanTempFiles(config.snapshotStoragePath) + cleanupOrphanTempFiles(config.blobStoragePath) + } catch (e: Exception) { + LoggerFactory.getLogger("Application").warn("Temp file cleanup failed: {}", e.message) + } + // SEC5-S-14: Prune ops covered by fully-acknowledged checkpoints + try { + pruneAllBuckets(syncService) + } catch (e: Exception) { + LoggerFactory.getLogger("Application").warn("Op pruning failed: {}", e.message) + } } } @@ -107,9 +128,10 @@ fun Application.module(config: AppConfig = AppConfig()) { // this server MUST be deployed behind a trusted reverse proxy (nginx, Caddy, etc.) that // strips/overwrites X-Forwarded-* headers from untrusted clients. Without this, an // attacker can spoof their IP address to bypass rate limiting. - // TODO: When Ktor adds support for configuring trusted proxy addresses, restrict this - // to only trust the known reverse proxy IPs. Alternatively, set KIDSYNC_TRUST_PROXY=false - // to disable forwarded headers entirely if not behind a proxy. + // DEFERRED: Ktor framework limitation — XForwardedHeaders trusts all sources and does not + // support configuring trusted proxy addresses. When Ktor adds this support, restrict to + // known reverse proxy IPs. Workaround: deploy behind a reverse proxy that strips/overwrites + // X-Forwarded-* headers from untrusted clients. install(XForwardedHeaders) // SEC-S-18: Configure CallLogging with a custom format to avoid logging sensitive headers @@ -185,3 +207,36 @@ fun Application.module(config: AppConfig = AppConfig()) { keyRoutes(keyService) } } + +/** + * SEC4-S-11: Clean up orphan .tmp files in a storage directory that are older than 1 hour. + * These are leftover from uploads that crashed between writing the temp file and committing + * the DB row (or renaming to the final filename). + */ +private fun cleanupOrphanTempFiles(storagePath: String) { + val dir = File(storagePath) + if (!dir.exists() || !dir.isDirectory) return + + val oneHourAgo = System.currentTimeMillis() - 3_600_000L + val logger = LoggerFactory.getLogger("Application") + + dir.listFiles()?.filter { it.name.endsWith(".tmp") && it.lastModified() < oneHourAgo }?.forEach { file -> + if (file.delete()) { + logger.info("Cleaned up orphan temp file: {}", file.name) + } + } +} + +/** + * SEC5-S-14: Prune acknowledged ops for all buckets that have checkpoints. + */ +private suspend fun pruneAllBuckets(syncService: SyncService) { + val bucketIds = dbQuery { + Checkpoints.selectAll() + .map { row -> row[Checkpoints.bucketId] } + .distinct() + } + for (bucketId in bucketIds) { + syncService.pruneAcknowledgedOps(bucketId) + } +} diff --git a/server/src/main/kotlin/dev/kidsync/server/db/DatabaseFactory.kt b/server/src/main/kotlin/dev/kidsync/server/db/DatabaseFactory.kt index a564941..cd8c7b2 100644 --- a/server/src/main/kotlin/dev/kidsync/server/db/DatabaseFactory.kt +++ b/server/src/main/kotlin/dev/kidsync/server/db/DatabaseFactory.kt @@ -77,6 +77,7 @@ object DatabaseFactory { KeyAttestations, Sessions, Challenges, + CheckpointAcknowledgments, ) } } diff --git a/server/src/main/kotlin/dev/kidsync/server/db/Tables.kt b/server/src/main/kotlin/dev/kidsync/server/db/Tables.kt index b6b2ddf..3d36867 100644 --- a/server/src/main/kotlin/dev/kidsync/server/db/Tables.kt +++ b/server/src/main/kotlin/dev/kidsync/server/db/Tables.kt @@ -168,10 +168,9 @@ object InviteTokens : Table("invite_tokens") { // SEC3-S-01: Session tokens are stored as SHA-256 hashes, not plaintext. // The raw token is returned to the client; only the hash is persisted. -// SEC4-S-16: TODO - The signingKey column is redundant here since it can be looked up -// via the Devices table using deviceId. Removing it would reduce storage and eliminate -// a potential data inconsistency if the device's signing key is rotated. This requires -// a DB migration and should be addressed when migration tooling is added. +// SEC4-S-16: The signingKey column is no longer written to (empty string) and no longer +// read from. Session validation now joins with Devices to get the current signing key. +// The column is retained for schema compatibility until DB migration tooling is added. object Sessions : Table("sessions") { val tokenHash = varchar("token_hash", 64) val deviceId = varchar("device_id", 36) @@ -197,6 +196,23 @@ object Challenges : Table("challenges") { init { index(false, signingKey) } } +// ---- Checkpoint Acknowledgments ---- +// SEC5-S-14: Track per-device acknowledgment of checkpoints to enable +// safe op pruning once all active devices have acknowledged. + +object CheckpointAcknowledgments : Table("checkpoint_acknowledgments") { + val id = integer("id").autoIncrement() + val checkpointId = integer("checkpoint_id").references(Checkpoints.id) + val deviceId = varchar("device_id", 255) + val acknowledgedAt = long("acknowledged_at") + + override val primaryKey = PrimaryKey(id) + + init { + uniqueIndex(checkpointId, deviceId) + } +} + // ---- Key Attestations ---- object KeyAttestations : Table("key_attestations") { diff --git a/server/src/main/kotlin/dev/kidsync/server/models/Requests.kt b/server/src/main/kotlin/dev/kidsync/server/models/Requests.kt index f910c04..ed09cc8 100644 --- a/server/src/main/kotlin/dev/kidsync/server/models/Requests.kt +++ b/server/src/main/kotlin/dev/kidsync/server/models/Requests.kt @@ -92,6 +92,20 @@ data class PushTokenRequest( val platform: String, ) +// ---- Checkpoint Acknowledgment ---- + +@Serializable +data class AcknowledgeCheckpointRequest( + val checkpointId: Int, +) + +// ---- Bucket Creator Transfer ---- + +@Serializable +data class TransferCreatorRequest( + val targetDeviceId: String, +) + // ---- Snapshot metadata (multipart JSON part) ---- @Serializable diff --git a/server/src/main/kotlin/dev/kidsync/server/routes/BucketRoutes.kt b/server/src/main/kotlin/dev/kidsync/server/routes/BucketRoutes.kt index f51ea4b..d2dada8 100644 --- a/server/src/main/kotlin/dev/kidsync/server/routes/BucketRoutes.kt +++ b/server/src/main/kotlin/dev/kidsync/server/routes/BucketRoutes.kt @@ -118,6 +118,24 @@ fun Route.bucketRoutes(bucketService: BucketService, wsManager: WebSocketManager call.respond(HttpStatusCode.OK, devices) } + /** + * PATCH /buckets/{id}/creator + * SEC3-S-08: Transfer bucket creator role to another device. + * Only the current creator can transfer ownership. + */ + patch("/creator") { + val principal = call.devicePrincipal() + val bucketId = ValidationUtil.requireUuidPathParam(call, "id", "bucket id") + val request = call.receive() + + if (request.targetDeviceId.isBlank()) { + throw ApiException(HttpStatusCode.BadRequest.value, "INVALID_REQUEST", "targetDeviceId is required") + } + + bucketService.transferCreator(bucketId, principal.deviceId, request.targetDeviceId) + call.respond(HttpStatusCode.OK, mapOf("status" to "transferred")) + } + /** * DELETE /buckets/{id}/devices/me * Self-revoke: remove own access from this bucket. diff --git a/server/src/main/kotlin/dev/kidsync/server/routes/DeviceRoutes.kt b/server/src/main/kotlin/dev/kidsync/server/routes/DeviceRoutes.kt index ccd8f57..6619a02 100644 --- a/server/src/main/kotlin/dev/kidsync/server/routes/DeviceRoutes.kt +++ b/server/src/main/kotlin/dev/kidsync/server/routes/DeviceRoutes.kt @@ -21,6 +21,51 @@ import org.jetbrains.exposed.sql.selectAll import java.time.LocalDateTime import java.time.ZoneOffset import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong + +/** + * SEC-S-10: IP-based rate limiter for device registration. + * Limits each IP to 5 registrations per hour to prevent mass device creation attacks. + */ +object DeviceRegistrationRateLimiter { + private const val MAX_REGISTRATIONS_PER_IP_PER_HOUR = 5 + private const val WINDOW_MS = 3_600_000L + + private data class IpWindow(val count: AtomicInteger = AtomicInteger(0), val windowStart: AtomicLong = AtomicLong(0)) + private val windows = ConcurrentHashMap() + + fun checkAndIncrement(ip: String): Boolean { + cleanup() + val now = System.currentTimeMillis() + val window = windows.computeIfAbsent(ip) { IpWindow() } + + synchronized(window) { + val start = window.windowStart.get() + if (now - start > WINDOW_MS) { + window.windowStart.set(now) + window.count.set(1) + return true + } + val current = window.count.incrementAndGet() + return current <= MAX_REGISTRATIONS_PER_IP_PER_HOUR + } + } + + private fun cleanup() { + val now = System.currentTimeMillis() + val threshold = now - (2 * WINDOW_MS) + windows.entries.removeIf { (_, window) -> + window.windowStart.get() < threshold + } + } + + /** Visible for testing: reset all rate limit state. */ + fun reset() { + windows.clear() + } +} fun Route.deviceRoutes(sessionUtil: SessionUtil) { rateLimit(RateLimitName("auth")) { @@ -34,6 +79,12 @@ fun Route.deviceRoutes(sessionUtil: SessionUtil) { * provides basic protection for now. */ post("/register") { + // SEC-S-10: IP-based rate limiting for device registration + val clientIp = call.request.local.remoteAddress + if (!DeviceRegistrationRateLimiter.checkAndIncrement(clientIp)) { + throw ApiException(HttpStatusCode.TooManyRequests.value, "RATE_LIMITED", "Too many registration attempts. Try again later.") + } + val request = call.receive() if (request.signingKey.isBlank()) { diff --git a/server/src/main/kotlin/dev/kidsync/server/routes/SyncRoutes.kt b/server/src/main/kotlin/dev/kidsync/server/routes/SyncRoutes.kt index 2ec74ae..d648e92 100644 --- a/server/src/main/kotlin/dev/kidsync/server/routes/SyncRoutes.kt +++ b/server/src/main/kotlin/dev/kidsync/server/routes/SyncRoutes.kt @@ -38,6 +38,11 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicLong +/** WebSocket close codes for application-level errors (RFC 6455 §7.4.2: 4000-4999). */ +private const val WS_CLOSE_INVALID_PARAMS: Short = 4000 +private const val WS_CLOSE_AUTH_FAILED: Short = 4001 +private const val WS_CLOSE_RATE_LIMITED: Short = 4003 + /** * SEC4-S-10: IP-based rate limiter for WebSocket upgrade attempts. * Limits each IP to MAX_WS_CONNECTIONS_PER_IP_PER_MINUTE WebSocket connections per minute. @@ -147,6 +152,18 @@ fun Route.syncRoutes( } } + // SEC5-S-14: Checkpoint acknowledgment for op pruning + rateLimit(RateLimitName("general")) { + post("/checkpoints/acknowledge") { + val principal = call.devicePrincipal() + val bucketId = ValidationUtil.requireUuidPathParam(call, "id", "bucket id") + val request = call.receive() + + syncService.acknowledgeCheckpoint(bucketId, principal.deviceId, request.checkpointId) + call.respond(HttpStatusCode.OK, mapOf("status" to "acknowledged")) + } + } + // SEC6-S-06: Per-bucket snapshot quota enforced below. // POST /buckets/{id}/snapshots @@ -281,19 +298,18 @@ fun Route.syncRoutes( val sizeBytes = blob.size.toLong() val now = LocalDateTime.now(ZoneOffset.UTC) - // SEC4-S-11: TODO - If the server crashes between writing the file to disk - // and committing the DB row, an orphaned file remains. A background cleanup - // job should periodically scan the snapshot directory for files not referenced - // in the Snapshots table and delete them after a grace period (e.g., 1 hour). + // SEC4-S-11: Write to temp file first, rename after DB commit to + // prevent orphaned final files on crash between write and commit. val snapshotId = UUID.randomUUID().toString() val snapshotDir = File(config.snapshotStoragePath) snapshotDir.mkdirs() - val snapshotFile = File(snapshotDir, snapshotId) - snapshotFile.writeBytes(blob) + val tempFile = File(snapshotDir, "$snapshotId.tmp") + val finalFile = File(snapshotDir, snapshotId) + tempFile.writeBytes(blob) // SEC6-S-15: Set file permissions to 600 (owner read/write only) try { - Files.setPosixFilePermissions(snapshotFile.toPath(), PosixFilePermissions.fromString("rw-------")) + Files.setPosixFilePermissions(tempFile.toPath(), PosixFilePermissions.fromString("rw-------")) } catch (_: UnsupportedOperationException) { // Windows doesn't support POSIX file permissions } @@ -314,9 +330,11 @@ fun Route.syncRoutes( it[createdAt] = now } } + // DB commit succeeded -- rename temp to final + tempFile.renameTo(finalFile) } catch (e: Exception) { - // Clean up orphaned file on DB insert failure - snapshotFile.delete() + // Clean up temp file on DB insert failure + tempFile.delete() throw e } @@ -332,9 +350,50 @@ fun Route.syncRoutes( } } - // SEC3-S-06: TODO - A snapshot download endpoint (GET /buckets/{id}/snapshots/{snapshotId}/download) - // is needed to allow clients to download the actual snapshot binary data, not just metadata. - // This should serve the file from snapshotStoragePath with path traversal protection. + // SEC3-S-06: Snapshot download endpoint + rateLimit(RateLimitName("general")) { + get("/snapshots/{snapshotId}/download") { + val principal = call.devicePrincipal() + val bucketId = ValidationUtil.requireUuidPathParam(call, "id", "bucket id") + val snapshotId = ValidationUtil.requireUuidPathParam(call, "snapshotId", "snapshot id") + + // Verify bucket access + dbQuery { BucketService.requireBucketAccess(bucketId, principal.deviceId) } + + // Look up snapshot and verify it belongs to this bucket + val snapshot = dbQuery { + Snapshots.selectAll() + .where { Snapshots.id eq snapshotId } + .firstOrNull() + } ?: throw ApiException(HttpStatusCode.NotFound.value, "NOT_FOUND", "Snapshot not found") + + if (snapshot[Snapshots.bucketId] != bucketId) { + throw ApiException(HttpStatusCode.Forbidden.value, "BUCKET_ACCESS_DENIED", "Snapshot does not belong to this bucket") + } + + // Resolve file path with path traversal protection + val snapshotDir = File(config.snapshotStoragePath).canonicalFile + val snapshotFile = File(config.snapshotStoragePath, snapshot[Snapshots.filePath]) + if (!snapshotFile.canonicalFile.startsWith(snapshotDir)) { + throw ApiException(HttpStatusCode.Forbidden.value, "BUCKET_ACCESS_DENIED", "Invalid file path") + } + + if (!snapshotFile.exists()) { + // Recover from crash between DB commit and rename (SEC4-S-11) + val tempFile = File(config.snapshotStoragePath, "${snapshot[Snapshots.filePath]}.tmp") + if (tempFile.exists()) { + tempFile.renameTo(snapshotFile) + } + if (!snapshotFile.exists()) { + throw ApiException(HttpStatusCode.NotFound.value, "NOT_FOUND", "Snapshot file not found on disk") + } + } + + // Return the verified SHA-256 stored at upload time (avoids loading entire file into memory) + call.response.header("X-Snapshot-SHA256", snapshot[Snapshots.sha256Hash]) + call.respondFile(snapshotFile) + } + } // GET /buckets/{id}/snapshots/latest rateLimit(RateLimitName("general")) { @@ -379,24 +438,23 @@ fun Route.syncRoutes( // SEC6-S-07: The session token is hashed with SHA-256 before being stored in the WebSocket // connection's in-memory state. This limits exposure if the server process memory is dumped. - // SEC5-S-03: TODO - Accept WebSocket auth token as a query parameter (?token=...) in addition - // to the current in-band auth message. This would allow the server to reject unauthenticated - // connections at upgrade time (before allocating WebSocket resources), reducing the attack - // surface for resource exhaustion. The token should still be validated via SessionUtil. + // SEC5-S-03: WebSocket auth supports both query param (?token=...) and in-band auth message. + // Query param auth allows the server to reject unauthenticated connections at upgrade time, + // reducing the attack surface for resource exhaustion. // WebSocket /buckets/{id}/ws webSocket("/buckets/{id}/ws") { // SEC4-S-10: IP-based rate limiting for WebSocket upgrade attempts val clientIp = call.request.local.remoteAddress if (!WebSocketConnectionRateLimiter.checkAndIncrement(clientIp)) { - close(CloseReason(4003, "Rate limit exceeded")) + close(CloseReason(WS_CLOSE_RATE_LIMITED, "Rate limit exceeded")) return@webSocket } val json = Json { ignoreUnknownKeys = true } val bucketId = call.parameters["id"] if (bucketId == null || !ValidationUtil.isValidUUID(bucketId)) { - close(CloseReason(4000, "Missing or invalid bucket id")) + close(CloseReason(WS_CLOSE_INVALID_PARAMS, "Missing or invalid bucket id")) return@webSocket } @@ -405,50 +463,64 @@ fun Route.syncRoutes( var sessionTokenHash: String? = null try { - // Wait for auth message with timeout - val authFrame = withTimeoutOrNull(5000) { incoming.receive() } - if (authFrame == null) { - close(CloseReason(4001, "Auth timeout")) - return@webSocket - } - if (authFrame !is Frame.Text) { - close(CloseReason(4001, "Expected text frame for auth")) - return@webSocket - } + // SEC5-S-03: Check for query param auth first, fall back to in-band auth + val queryToken = call.request.queryParameters["token"] + val session: dev.kidsync.server.util.Session? + + if (queryToken != null) { + // Query param auth: validate token from URL + session = sessionUtil.validateSession(queryToken) + if (session == null) { + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Invalid token")) + return@webSocket + } + sessionTokenHash = dev.kidsync.server.util.HashUtil.sha256HexString(queryToken) + } else { + // In-band auth: wait for auth message (backward compatible) + val authFrame = withTimeoutOrNull(5000) { incoming.receive() } + if (authFrame == null) { + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Auth timeout")) + return@webSocket + } + if (authFrame !is Frame.Text) { + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Expected text frame for auth")) + return@webSocket + } - val authText = authFrame.readText() - val authJsonObj = json.parseToJsonElement(authText).jsonObject - val authType = authJsonObj["type"]?.jsonPrimitive?.content + val authText = authFrame.readText() + val authJsonObj = json.parseToJsonElement(authText).jsonObject + val authType = authJsonObj["type"]?.jsonPrimitive?.content - if (authType != "auth") { - close(CloseReason(4001, "Expected auth message")) - return@webSocket - } + if (authType != "auth") { + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Expected auth message")) + return@webSocket + } - val token = authJsonObj["token"]?.jsonPrimitive?.content - if (token == null) { - close(CloseReason(4001, "Missing auth token")) - return@webSocket - } + val token = authJsonObj["token"]?.jsonPrimitive?.content + if (token == null) { + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Missing auth token")) + return@webSocket + } - val session = sessionUtil.validateSession(token) - if (session == null) { - send( - Frame.Text( - json.encodeToString( - WsAuthFailed.serializer(), - WsAuthFailed(error = "TOKEN_INVALID", message = "Session token invalid or expired") + session = sessionUtil.validateSession(token) + if (session == null) { + send( + Frame.Text( + json.encodeToString( + WsAuthFailed.serializer(), + WsAuthFailed(error = "TOKEN_INVALID", message = "Session token invalid or expired") + ) ) ) - ) - close(CloseReason(4001, "Auth failed")) - return@webSocket - } + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Auth failed")) + return@webSocket + } - // SEC6-S-07: Hash the token immediately; discard the raw value - sessionTokenHash = dev.kidsync.server.util.HashUtil.sha256HexString(token) + // SEC6-S-07: Hash the token immediately; discard the raw value + sessionTokenHash = dev.kidsync.server.util.HashUtil.sha256HexString(token) + } - // Verify bucket access + // Verify bucket access (shared for both auth paths) val hasAccess = try { dbQuery { BucketService.requireBucketAccess(bucketId, session.deviceId) } true @@ -465,7 +537,7 @@ fun Route.syncRoutes( ) ) ) - close(CloseReason(4001, "No access")) + close(CloseReason(WS_CLOSE_AUTH_FAILED, "No access")) return@webSocket } @@ -474,7 +546,7 @@ fun Route.syncRoutes( // SEC-S-06: Enforce connection limits if (!wsManager.addConnection(bucketId, connection!!)) { - close(CloseReason(4003, "Connection limit exceeded")) + close(CloseReason(WS_CLOSE_RATE_LIMITED, "Connection limit exceeded")) return@webSocket } @@ -499,7 +571,7 @@ fun Route.syncRoutes( // SEC6-S-07: Re-validate using the stored hash instead of raw token val revalidatedSession = sessionUtil.validateSessionByHash(sessionTokenHash!!) if (revalidatedSession == null) { - close(CloseReason(4001, "Session expired")) + close(CloseReason(WS_CLOSE_AUTH_FAILED, "Session expired")) return@launch } val stillHasAccess = try { diff --git a/server/src/main/kotlin/dev/kidsync/server/services/BucketService.kt b/server/src/main/kotlin/dev/kidsync/server/services/BucketService.kt index a17b24a..e17f433 100644 --- a/server/src/main/kotlin/dev/kidsync/server/services/BucketService.kt +++ b/server/src/main/kotlin/dev/kidsync/server/services/BucketService.kt @@ -5,6 +5,7 @@ import dev.kidsync.server.db.DatabaseFactory.dbQuery import dev.kidsync.server.models.* import dev.kidsync.server.util.HashUtil import dev.kidsync.server.util.SessionUtil +import io.ktor.http.* import org.jetbrains.exposed.sql.* import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq import org.jetbrains.exposed.sql.SqlExpressionBuilder.lessEq @@ -57,9 +58,7 @@ class BucketService( /** * Create a new anonymous bucket. The creator device automatically gets access. * - * SEC3-S-08: TODO - The bucket creator role cannot be transferred to another device. - * If the creator device is lost, the bucket cannot be deleted by any other device. - * Consider adding a creator transfer mechanism or multi-admin support in a future version. + * SEC3-S-08: The bucket creator role can be transferred via PATCH /buckets/{id}/creator. */ suspend fun createBucket(deviceId: String): BucketResponse { val bucketId = UUID.randomUUID().toString() @@ -428,6 +427,42 @@ class BucketService( } } + /** + * SEC3-S-08: Transfer the bucket creator role to another device. + * Only the current creator can transfer ownership. The target must have + * active (non-revoked) access to the bucket. + */ + @Suppress("ThrowsCount") // Validation throws are intentional guard clauses + suspend fun transferCreator(bucketId: String, callerDeviceId: String, targetDeviceId: String) { + if (callerDeviceId == targetDeviceId) { + throw ApiException(HttpStatusCode.BadRequest.value, "INVALID_REQUEST", "Cannot transfer creator role to yourself") + } + + dbQuery { + val bucket = Buckets.selectAll().where { Buckets.id eq bucketId }.firstOrNull() + ?: throw ApiException(HttpStatusCode.NotFound.value, "NOT_FOUND", "Bucket not found") + + if (bucket[Buckets.createdBy] != callerDeviceId) { + throw ApiException(HttpStatusCode.Forbidden.value, "NOT_BUCKET_CREATOR", "Only the bucket creator can transfer ownership") + } + + // Verify target has active (non-revoked) bucket access + val hasAccess = BucketAccess.selectAll().where { + (BucketAccess.bucketId eq bucketId) and + (BucketAccess.deviceId eq targetDeviceId) and + BucketAccess.revokedAt.isNull() + }.any() + if (!hasAccess) { + throw ApiException(HttpStatusCode.NotFound.value, "NOT_FOUND", "Target device does not have active access to this bucket") + } + + // Transfer ownership + Buckets.update({ Buckets.id eq bucketId }) { + it[createdBy] = targetDeviceId + } + } + } + /** * SEC5-S-08: Creator-driven device revocation. Only the bucket creator can remove * another device from the bucket. Removes bucket access and deletes wrapped keys diff --git a/server/src/main/kotlin/dev/kidsync/server/services/KeyService.kt b/server/src/main/kotlin/dev/kidsync/server/services/KeyService.kt index 5523c48..dc3ae7b 100644 --- a/server/src/main/kotlin/dev/kidsync/server/services/KeyService.kt +++ b/server/src/main/kotlin/dev/kidsync/server/services/KeyService.kt @@ -19,9 +19,10 @@ class KeyService { private val isoFormatter = DateTimeFormatter.ISO_INSTANT // ---- Wrapped Keys ---- - // SEC3-S-17: TODO - No key rotation mechanism exists. When a device is revoked from a - // bucket, the remaining devices should rotate to a new key epoch and re-wrap the DEK - // for each remaining device. This is a planned feature for a future version. + // SEC3-S-17: DEFERRED - Key rotation mechanism. When a device is revoked, remaining + // devices should rotate to a new epoch and re-wrap the DEK. Requires protocol design: + // epoch semantics, backward compatibility for offline devices, and rotation triggers. + // Tracked as future protocol-level work. /** * Upload a wrapped DEK for a target device. diff --git a/server/src/main/kotlin/dev/kidsync/server/services/SyncService.kt b/server/src/main/kotlin/dev/kidsync/server/services/SyncService.kt index 2608b2e..56458f8 100644 --- a/server/src/main/kotlin/dev/kidsync/server/services/SyncService.kt +++ b/server/src/main/kotlin/dev/kidsync/server/services/SyncService.kt @@ -7,6 +7,9 @@ import dev.kidsync.server.models.* import dev.kidsync.server.util.HashUtil import dev.kidsync.server.util.ValidationUtil import org.jetbrains.exposed.sql.* +import org.jetbrains.exposed.sql.SqlExpressionBuilder.eq +import org.jetbrains.exposed.sql.SqlExpressionBuilder.lessEq +import org.slf4j.LoggerFactory import java.time.LocalDateTime import java.time.ZoneOffset import java.time.format.DateTimeFormatter @@ -33,13 +36,15 @@ data class CheckpointCreated(val startSequence: Long, val endSequence: Long) * devices. This is by design: checkpoints provide an integrity anchor for the global * op stream, not per-device streams. */ -// SEC5-S-14: TODO - Add op table pruning after checkpoints. Once a checkpoint covers a range -// of ops, the individual ops in that range could be pruned to save storage. This requires -// ensuring all devices have acknowledged the checkpoint before pruning, which needs a -// per-device checkpoint acknowledgment tracking mechanism. +// SEC5-S-14: Op table pruning is implemented via CheckpointAcknowledgments table and +// pruneAcknowledgedOps(). Ops covered by fully-acknowledged checkpoints (all active +// devices have acknowledged) are pruned, preserving the latest checkpoint's ops as +// a safety margin. class SyncService(private val config: AppConfig) { + private val logger = LoggerFactory.getLogger(SyncService::class.java) + private val isoFormatter = DateTimeFormatter.ISO_INSTANT /** @@ -311,4 +316,97 @@ class SyncService(private val config: AppConfig) { } return null } + + /** + * SEC5-S-14: Record a device's acknowledgment of a checkpoint. + * Upserts to handle idempotent re-acknowledgment. + */ + suspend fun acknowledgeCheckpoint(bucketId: String, deviceId: String, checkpointId: Int) { + dbQuery { + BucketService.requireBucketAccess(bucketId, deviceId) + + // Verify checkpoint exists and belongs to this bucket + val checkpoint = Checkpoints.selectAll() + .where { (Checkpoints.id eq checkpointId) and (Checkpoints.bucketId eq bucketId) } + .firstOrNull() + ?: throw ApiException(404, "NOT_FOUND", "Checkpoint not found in this bucket") + + // Upsert: insert or ignore if already acknowledged + val existing = CheckpointAcknowledgments.selectAll() + .where { + (CheckpointAcknowledgments.checkpointId eq checkpointId) and + (CheckpointAcknowledgments.deviceId eq deviceId) + } + .firstOrNull() + + if (existing == null) { + CheckpointAcknowledgments.insert { + it[CheckpointAcknowledgments.checkpointId] = checkpointId + it[CheckpointAcknowledgments.deviceId] = deviceId + it[acknowledgedAt] = java.time.Instant.now().epochSecond + } + } + } + } + + /** + * SEC5-S-14: Prune ops covered by fully-acknowledged checkpoints. + * + * A checkpoint is "fully acknowledged" when ALL active (non-revoked) devices + * in the bucket have acknowledged it. For safety, the latest fully-acknowledged + * checkpoint's ops are preserved -- only ops covered by older checkpoints are pruned. + * + * Returns the number of ops pruned. + */ + suspend fun pruneAcknowledgedOps(bucketId: String): Long { + return dbQuery { + // Get all active (non-revoked) devices for this bucket + val activeDeviceIds = BucketAccess.selectAll() + .where { + (BucketAccess.bucketId eq bucketId) and BucketAccess.revokedAt.isNull() + } + .map { it[BucketAccess.deviceId] } + .toSet() + + if (activeDeviceIds.isEmpty()) return@dbQuery 0L + + // Get all checkpoints for this bucket ordered by endSequence + val checkpoints = Checkpoints.selectAll() + .where { Checkpoints.bucketId eq bucketId } + .orderBy(Checkpoints.endSequence, SortOrder.ASC) + .toList() + + if (checkpoints.isEmpty()) return@dbQuery 0L + + // Find checkpoints where ALL active devices have acknowledged + val fullyAcknowledged = checkpoints.filter { cp -> + val cpId = cp[Checkpoints.id] + val acknowledgedDevices = CheckpointAcknowledgments.selectAll() + .where { CheckpointAcknowledgments.checkpointId eq cpId } + .map { it[CheckpointAcknowledgments.deviceId] } + .toSet() + activeDeviceIds.all { it in acknowledgedDevices } + } + + if (fullyAcknowledged.isEmpty()) return@dbQuery 0L + + // Keep the latest fully-acknowledged checkpoint's ops as safety margin. + // Only prune ops covered by older fully-acknowledged checkpoints. + val prunableCps = fullyAcknowledged.dropLast(1) + if (prunableCps.isEmpty()) return@dbQuery 0L + + val maxPrunableSeq = prunableCps.last()[Checkpoints.endSequence] + + // Delete ops up to and including maxPrunableSeq + val deleted = Ops.deleteWhere { + (Ops.bucketId eq bucketId) and (Ops.sequence lessEq maxPrunableSeq) + }.toLong() + + if (deleted > 0) { + logger.info("Pruned {} ops from bucket {} (up to seq {})", deleted, bucketId, maxPrunableSeq) + } + + deleted + } + } } diff --git a/server/src/main/kotlin/dev/kidsync/server/util/SessionUtil.kt b/server/src/main/kotlin/dev/kidsync/server/util/SessionUtil.kt index 9d7a479..561c28c 100644 --- a/server/src/main/kotlin/dev/kidsync/server/util/SessionUtil.kt +++ b/server/src/main/kotlin/dev/kidsync/server/util/SessionUtil.kt @@ -2,6 +2,7 @@ package dev.kidsync.server.util import dev.kidsync.server.AppConfig import dev.kidsync.server.db.Challenges +import dev.kidsync.server.db.Devices import dev.kidsync.server.db.Sessions import dev.kidsync.server.db.DatabaseFactory.dbQuery import org.jetbrains.exposed.sql.* @@ -169,7 +170,9 @@ class SessionUtil(private val config: AppConfig) { Sessions.insert { it[Sessions.tokenHash] = hashedToken it[Sessions.deviceId] = deviceId - it[Sessions.signingKey] = signingKey + // SEC4-S-16: signingKey is now read from Devices table via join; + // column retained for schema compatibility but no longer used. + it[Sessions.signingKey] = "" it[createdAt] = now.epochSecond it[expiresAt] = session.expiresAt.epochSecond } @@ -181,13 +184,16 @@ class SessionUtil(private val config: AppConfig) { /** * Validate a session token. Returns the session if valid, null otherwise. * SEC3-S-01: Hashes the input token with SHA-256 and queries against the stored hash. + * SEC4-S-16: Reads signingKey from Devices table (via join) instead of the redundant + * Sessions.signingKey column, ensuring the current device key is always returned. */ suspend fun validateSession(token: String): Session? { // SEC6-S-05: Reject tokens without the session prefix (prevents challenge token cross-use) if (!token.startsWith(SESSION_TOKEN_PREFIX)) return null val hashedToken = HashUtil.sha256HexString(token) return dbQuery { - val row = Sessions.selectAll() + val row = Sessions.join(Devices, JoinType.INNER, Sessions.deviceId, Devices.id) + .selectAll() .where { Sessions.tokenHash eq hashedToken } .firstOrNull() ?: return@dbQuery null @@ -199,7 +205,7 @@ class SessionUtil(private val config: AppConfig) { Session( deviceId = row[Sessions.deviceId], - signingKey = row[Sessions.signingKey], + signingKey = row[Devices.signingKey], createdAt = Instant.ofEpochSecond(row[Sessions.createdAt]), expiresAt = expiresAt, ) @@ -209,10 +215,13 @@ class SessionUtil(private val config: AppConfig) { /** * SEC6-S-07: Validate a session by its pre-computed SHA-256 hash. * Used by WebSocket connections that store only the hash to avoid holding raw tokens in memory. + * SEC4-S-16: Reads signingKey from Devices table (via join) instead of the redundant + * Sessions.signingKey column. */ suspend fun validateSessionByHash(tokenHash: String): Session? { return dbQuery { - val row = Sessions.selectAll() + val row = Sessions.join(Devices, JoinType.INNER, Sessions.deviceId, Devices.id) + .selectAll() .where { Sessions.tokenHash eq tokenHash } .firstOrNull() ?: return@dbQuery null @@ -224,7 +233,7 @@ class SessionUtil(private val config: AppConfig) { Session( deviceId = row[Sessions.deviceId], - signingKey = row[Sessions.signingKey], + signingKey = row[Devices.signingKey], createdAt = Instant.ofEpochSecond(row[Sessions.createdAt]), expiresAt = expiresAt, ) diff --git a/server/src/test/kotlin/dev/kidsync/server/BucketCreatorTransferTest.kt b/server/src/test/kotlin/dev/kidsync/server/BucketCreatorTransferTest.kt new file mode 100644 index 0000000..83a3ed0 --- /dev/null +++ b/server/src/test/kotlin/dev/kidsync/server/BucketCreatorTransferTest.kt @@ -0,0 +1,153 @@ +package dev.kidsync.server + +import dev.kidsync.server.TestHelper.createJsonClient +import dev.kidsync.server.models.ErrorResponse +import dev.kidsync.server.models.TransferCreatorRequest +import io.ktor.client.call.body +import io.ktor.client.request.delete +import io.ktor.client.request.header +import io.ktor.client.request.patch +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.testing.testApplication +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals + +/** + * Integration tests for SEC3-S-08: Bucket creator role transfer. + */ +class BucketCreatorTransferTest { + + @Test + fun `creator successfully transfers to active member`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Transfer creator from A to B + val resp = client.patch("/buckets/$bucketId/creator") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + setBody(TransferCreatorRequest(targetDeviceId = deviceB.deviceId)) + } + assertEquals(HttpStatusCode.OK, resp.status) + + // Verify B is now the creator: B can delete the bucket + val deleteResp = client.delete("/buckets/$bucketId") { + header(HttpHeaders.Authorization, "Bearer ${deviceB.sessionToken}") + } + assertEquals(HttpStatusCode.NoContent, deleteResp.status) + } + + @Test + fun `non-creator gets 403`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Device B (non-creator) tries to transfer creator + val resp = client.patch("/buckets/$bucketId/creator") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${deviceB.sessionToken}") + setBody(TransferCreatorRequest(targetDeviceId = deviceA.deviceId)) + } + assertEquals(HttpStatusCode.Forbidden, resp.status) + val body = resp.body() + assertEquals("NOT_BUCKET_CREATOR", body.error) + } + + @Test + fun `target not in bucket gets 404`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val deviceA = TestHelper.setupDeviceWithBucket(client) + val bucketId = deviceA.bucketId!! + + // Create outsider device (not in bucket) + val outsiderReg = TestHelper.registerDevice(client) + val outsider = TestHelper.authenticateDevice(client, outsiderReg) + + val resp = client.patch("/buckets/$bucketId/creator") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + setBody(TransferCreatorRequest(targetDeviceId = outsider.deviceId)) + } + assertEquals(HttpStatusCode.NotFound, resp.status) + } + + @Test + fun `revoked target gets 404`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Revoke device B + val revokeResp = client.delete("/buckets/$bucketId/devices/${deviceB.deviceId}") { + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + } + assertEquals(HttpStatusCode.NoContent, revokeResp.status) + + // Try to transfer to revoked device B + val resp = client.patch("/buckets/$bucketId/creator") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + setBody(TransferCreatorRequest(targetDeviceId = deviceB.deviceId)) + } + assertEquals(HttpStatusCode.NotFound, resp.status) + } + + @Test + fun `after transfer new creator can perform creator-only operations`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Transfer from A to B + val transferResp = client.patch("/buckets/$bucketId/creator") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + setBody(TransferCreatorRequest(targetDeviceId = deviceB.deviceId)) + } + assertEquals(HttpStatusCode.OK, transferResp.status) + + // Old creator (A) cannot revoke B + val revokeByA = client.delete("/buckets/$bucketId/devices/${deviceB.deviceId}") { + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + } + assertEquals(HttpStatusCode.Forbidden, revokeByA.status) + + // New creator (B) can revoke A + val revokeByB = client.delete("/buckets/$bucketId/devices/${deviceA.deviceId}") { + header(HttpHeaders.Authorization, "Bearer ${deviceB.sessionToken}") + } + assertEquals(HttpStatusCode.NoContent, revokeByB.status) + } + + @Test + fun `cannot transfer to self`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + val bucketId = device.bucketId!! + + val resp = client.patch("/buckets/$bucketId/creator") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + setBody(TransferCreatorRequest(targetDeviceId = device.deviceId)) + } + assertEquals(HttpStatusCode.BadRequest, resp.status) + } +} diff --git a/server/src/test/kotlin/dev/kidsync/server/DeviceRegistrationRateLimitTest.kt b/server/src/test/kotlin/dev/kidsync/server/DeviceRegistrationRateLimitTest.kt new file mode 100644 index 0000000..5fe9da9 --- /dev/null +++ b/server/src/test/kotlin/dev/kidsync/server/DeviceRegistrationRateLimitTest.kt @@ -0,0 +1,110 @@ +package dev.kidsync.server + +import dev.kidsync.server.TestHelper.createJsonClient +import dev.kidsync.server.models.ErrorResponse +import dev.kidsync.server.models.RegisterRequest +import dev.kidsync.server.models.RegisterResponse +import dev.kidsync.server.routes.DeviceRegistrationRateLimiter +import io.ktor.client.call.body +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.testing.testApplication +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals + +/** + * Tests for SEC-S-10: Device registration IP-based rate limiting. + */ +class DeviceRegistrationRateLimitTest { + + @BeforeEach + fun resetRateLimiter() { + DeviceRegistrationRateLimiter.reset() + } + + private suspend fun registerNewDevice(client: io.ktor.client.HttpClient): io.ktor.client.statement.HttpResponse { + val signingKeyPair = TestHelper.generateSigningKeyPair() + val encryptionKeyPair = TestHelper.generateEncryptionKeyPair() + return client.post("/register") { + contentType(ContentType.Application.Json) + setBody(RegisterRequest( + signingKey = TestHelper.encodePublicKey(signingKeyPair.public), + encryptionKey = TestHelper.encodePublicKey(encryptionKeyPair.public), + )) + } + } + + @Test + fun `5 registrations from same IP succeed`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + for (i in 1..5) { + val resp = registerNewDevice(client) + assertEquals(HttpStatusCode.Created, resp.status, + "Registration $i should succeed") + } + } + + @Test + fun `6th registration returns 429`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + // First 5 succeed + for (i in 1..5) { + val resp = registerNewDevice(client) + assertEquals(HttpStatusCode.Created, resp.status, + "Registration $i should succeed") + } + + // 6th should be rate limited + val resp = registerNewDevice(client) + assertEquals(HttpStatusCode.TooManyRequests, resp.status) + val body = resp.body() + assertEquals("RATE_LIMITED", body.error) + } + + @Test + fun `rate limiter unit test - independent IPs`() { + DeviceRegistrationRateLimiter.reset() + + // IP A uses 5 slots + for (i in 1..5) { + assertEquals(true, DeviceRegistrationRateLimiter.checkAndIncrement("1.2.3.4"), + "IP A registration $i should pass") + } + // IP A is now limited + assertEquals(false, DeviceRegistrationRateLimiter.checkAndIncrement("1.2.3.4"), + "IP A should be rate limited after 5") + + // IP B should still be fine + for (i in 1..5) { + assertEquals(true, DeviceRegistrationRateLimiter.checkAndIncrement("5.6.7.8"), + "IP B registration $i should pass") + } + } + + @Test + fun `rate limiter unit test - cleanup removes stale entries`() { + DeviceRegistrationRateLimiter.reset() + + // Fill up IP A + for (i in 1..5) { + DeviceRegistrationRateLimiter.checkAndIncrement("10.0.0.1") + } + assertEquals(false, DeviceRegistrationRateLimiter.checkAndIncrement("10.0.0.1")) + + // Reset simulates window expiry (in practice, time passes) + DeviceRegistrationRateLimiter.reset() + + // After reset, IP A can register again + assertEquals(true, DeviceRegistrationRateLimiter.checkAndIncrement("10.0.0.1")) + } +} diff --git a/server/src/test/kotlin/dev/kidsync/server/OpPruningTest.kt b/server/src/test/kotlin/dev/kidsync/server/OpPruningTest.kt new file mode 100644 index 0000000..aabbe98 --- /dev/null +++ b/server/src/test/kotlin/dev/kidsync/server/OpPruningTest.kt @@ -0,0 +1,223 @@ +package dev.kidsync.server + +import dev.kidsync.server.TestHelper.createJsonClient +import dev.kidsync.server.TestHelper.uploadOpsBatch +import dev.kidsync.server.models.AcknowledgeCheckpointRequest +import dev.kidsync.server.models.PullOpsResponse +import io.ktor.client.call.body +import io.ktor.client.request.delete +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.testing.testApplication +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals + +/** + * Integration tests for SEC5-S-14: Op table pruning after checkpoints. + */ +class OpPruningTest { + + private suspend fun acknowledgeCheckpoint( + client: io.ktor.client.HttpClient, + device: TestDevice, + checkpointId: Int, + ): HttpResponse { + val bucketId = device.bucketId!! + return client.post("/buckets/$bucketId/checkpoints/acknowledge") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + setBody(AcknowledgeCheckpointRequest(checkpointId = checkpointId)) + } + } + + private suspend fun getOpsCount( + client: io.ktor.client.HttpClient, + device: TestDevice, + ): Int { + val bucketId = device.bucketId!! + val resp = client.get("/buckets/$bucketId/ops?since=0&limit=1000") { + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + } + assertEquals(HttpStatusCode.OK, resp.status) + return resp.body().ops.size + } + + @Test + fun `acknowledge checkpoint returns 200`() = testApplication { + // checkpointInterval=100 so 100 ops creates 1 checkpoint + val config = testConfig().copy(checkpointInterval = 100) + application { module(config) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + uploadOpsBatch(client, device, 100) + + // Get checkpoint + val cpResp = client.get("/buckets/${device.bucketId}/checkpoint") { + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + } + assertEquals(HttpStatusCode.OK, cpResp.status) + + // The checkpoint ID is the auto-increment ID from the Checkpoints table. + // We need to get it. We'll use checkpointId=1 since it's the first checkpoint. + val ackResp = acknowledgeCheckpoint(client, device, 1) + assertEquals(HttpStatusCode.OK, ackResp.status) + } + + @Test + fun `idempotent checkpoint acknowledgment`() = testApplication { + val config = testConfig().copy(checkpointInterval = 100) + application { module(config) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + uploadOpsBatch(client, device, 100) + + // Acknowledge twice + val ack1 = acknowledgeCheckpoint(client, device, 1) + assertEquals(HttpStatusCode.OK, ack1.status) + + val ack2 = acknowledgeCheckpoint(client, device, 1) + assertEquals(HttpStatusCode.OK, ack2.status) + } + + @Test + fun `acknowledge nonexistent checkpoint returns 404`() = testApplication { + val config = testConfig().copy(checkpointInterval = 100) + application { module(config) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + + val ackResp = acknowledgeCheckpoint(client, device, 9999) + assertEquals(HttpStatusCode.NotFound, ackResp.status) + } + + @Test + fun `all devices acknowledge - ops pruned after pruneAcknowledgedOps`() = testApplication { + // Use small checkpoint interval for testing + val config = testConfig().copy(checkpointInterval = 10) + application { module(config) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Upload 30 ops to create 3 checkpoints (interval=10) + var prevHash = "0".repeat(64) + prevHash = uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a1") + prevHash = uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a2") + uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a3") + + // Verify 30 ops exist + assertEquals(30, getOpsCount(client, deviceA)) + + // Both devices acknowledge checkpoints 1, 2, 3 + for (cpId in 1..3) { + acknowledgeCheckpoint(client, deviceA, cpId) + acknowledgeCheckpoint(client, deviceB, cpId) + } + + // Trigger pruning via the sync service + val syncService = dev.kidsync.server.services.SyncService(config) + val pruned = syncService.pruneAcknowledgedOps(bucketId) + + // Should have pruned ops from checkpoints 1 and 2 (not 3, safety margin) + // That's 20 ops pruned + assertEquals(20, pruned) + + // Remaining ops should be 10 + assertEquals(10, getOpsCount(client, deviceA)) + } + + @Test + fun `partial acknowledgment - no pruning`() = testApplication { + val config = testConfig().copy(checkpointInterval = 10) + application { module(config) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Upload 20 ops to create 2 checkpoints + var prevHash = "0".repeat(64) + prevHash = uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a1") + uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a2") + + assertEquals(20, getOpsCount(client, deviceA)) + + // Only device A acknowledges - device B hasn't acknowledged + acknowledgeCheckpoint(client, deviceA, 1) + acknowledgeCheckpoint(client, deviceA, 2) + + // Trigger pruning + val syncService = dev.kidsync.server.services.SyncService(config) + val pruned = syncService.pruneAcknowledgedOps(bucketId) + + // No pruning because B hasn't acknowledged + assertEquals(0, pruned) + assertEquals(20, getOpsCount(client, deviceA)) + } + + @Test + fun `revoked device does not block pruning`() = testApplication { + val config = testConfig().copy(checkpointInterval = 10) + application { module(config) } + val client = createJsonClient() + + val (deviceA, deviceB) = TestHelper.setupTwoDeviceBucket(client) + val bucketId = deviceA.bucketId!! + + // Upload 20 ops to create 2 checkpoints + var prevHash = "0".repeat(64) + prevHash = uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a1") + uploadOpsBatch(client, deviceA, 10, startPrevHash = prevHash, localIdPrefix = "a2") + + // Revoke device B + client.delete("/buckets/$bucketId/devices/${deviceB.deviceId}") { + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + } + + // Only device A needs to acknowledge (B is revoked) + acknowledgeCheckpoint(client, deviceA, 1) + acknowledgeCheckpoint(client, deviceA, 2) + + // Trigger pruning + val syncService = dev.kidsync.server.services.SyncService(config) + val pruned = syncService.pruneAcknowledgedOps(bucketId) + + // Checkpoint 1 pruned (checkpoint 2 preserved as safety margin) + assertEquals(10, pruned) + } + + @Test + fun `latest checkpoint ops preserved as safety margin`() = testApplication { + val config = testConfig().copy(checkpointInterval = 10) + application { module(config) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + val bucketId = device.bucketId!! + + // Upload 10 ops (1 checkpoint) + uploadOpsBatch(client, device, 10) + + // Acknowledge the only checkpoint + acknowledgeCheckpoint(client, device, 1) + + // Trigger pruning + val syncService = dev.kidsync.server.services.SyncService(config) + val pruned = syncService.pruneAcknowledgedOps(bucketId) + + // Should NOT prune - the latest fully-acknowledged checkpoint is preserved + assertEquals(0, pruned) + assertEquals(10, getOpsCount(client, device)) + } +} diff --git a/server/src/test/kotlin/dev/kidsync/server/SnapshotDownloadTest.kt b/server/src/test/kotlin/dev/kidsync/server/SnapshotDownloadTest.kt new file mode 100644 index 0000000..9c5097a --- /dev/null +++ b/server/src/test/kotlin/dev/kidsync/server/SnapshotDownloadTest.kt @@ -0,0 +1,169 @@ +package dev.kidsync.server + +import dev.kidsync.server.TestHelper.createJsonClient +import dev.kidsync.server.TestHelper.uploadOpsChain +import dev.kidsync.server.models.BucketResponse +import dev.kidsync.server.models.CreateBucketRequest +import dev.kidsync.server.models.SnapshotMetadata +import dev.kidsync.server.models.UploadSnapshotResponse +import io.ktor.client.call.body +import io.ktor.client.request.forms.formData +import io.ktor.client.request.forms.submitFormWithBinaryData +import io.ktor.client.request.get +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.readRawBytes +import io.ktor.http.ContentType +import io.ktor.http.Headers +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.testing.testApplication +import kotlinx.serialization.json.Json +import org.junit.jupiter.api.Test +import java.security.MessageDigest +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +/** + * Integration tests for SEC3-S-06: Snapshot download endpoint. + */ +class SnapshotDownloadTest { + + private val encoder = java.util.Base64.getEncoder() + + private fun createSnapshotData(content: String = "snapshot-content-${System.nanoTime()}"): Triple { + val data = content.toByteArray() + val sha256 = MessageDigest.getInstance("SHA-256") + .digest(data) + .joinToString("") { "%02x".format(it) } + val signature = encoder.encodeToString("test-signature".toByteArray()) + return Triple(data, sha256, signature) + } + + private suspend fun uploadSnapshot( + client: io.ktor.client.HttpClient, + device: TestDevice, + atSequence: Long, + snapshotData: Triple? = null, + ): Pair> { + val bucketId = device.bucketId!! + val dataTriple = snapshotData ?: createSnapshotData() + val (data, sha256, signature) = dataTriple + val metadata = Json.encodeToString( + SnapshotMetadata.serializer(), + SnapshotMetadata(atSequence = atSequence, keyEpoch = 1, sha256 = sha256, signature = signature) + ) + + val response = client.submitFormWithBinaryData( + url = "/buckets/$bucketId/snapshots", + formData = formData { + append("metadata", metadata) + append("snapshot", data, Headers.build { + append(HttpHeaders.ContentType, "application/octet-stream") + append(HttpHeaders.ContentDisposition, "filename=\"snapshot.bin\"") + }) + } + ) { + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + } + + return Pair(response, dataTriple) + } + + @Test + fun `download matches upload content`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + uploadOpsChain(client, device, 1) + + val snapshotData = createSnapshotData("test-snapshot-binary-data") + val (uploadResp, _) = uploadSnapshot(client, device, 1, snapshotData) + assertEquals(HttpStatusCode.Created, uploadResp.status) + val uploadBody = uploadResp.body() + val snapshotId = uploadBody.snapshotId + + // Download + val downloadResp = client.get("/buckets/${device.bucketId}/snapshots/$snapshotId/download") { + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + } + assertEquals(HttpStatusCode.OK, downloadResp.status) + + val downloadedBytes = downloadResp.readRawBytes() + assertTrue(snapshotData.first.contentEquals(downloadedBytes), + "Downloaded content should match uploaded content") + + // Verify SHA-256 header + val sha256Header = downloadResp.headers["X-Snapshot-SHA256"] + assertNotNull(sha256Header) + assertEquals(snapshotData.second, sha256Header) + } + + @Test + fun `wrong bucket returns 403`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + // Device A creates bucket and uploads snapshot + val deviceA = TestHelper.setupDeviceWithBucket(client) + uploadOpsChain(client, deviceA, 1) + val (uploadResp, _) = uploadSnapshot(client, deviceA, 1) + assertEquals(HttpStatusCode.Created, uploadResp.status) + val snapshotId = uploadResp.body().snapshotId + + // Device A creates a second bucket + val bucket2Resp = client.post("/buckets") { + contentType(ContentType.Application.Json) + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + setBody(CreateBucketRequest()) + } + assertEquals(HttpStatusCode.Created, bucket2Resp.status) + val bucket2Id = bucket2Resp.body().bucketId + + // Try to download snapshot from wrong bucket + val downloadResp = client.get("/buckets/$bucket2Id/snapshots/$snapshotId/download") { + header(HttpHeaders.Authorization, "Bearer ${deviceA.sessionToken}") + } + assertEquals(HttpStatusCode.Forbidden, downloadResp.status) + } + + @Test + fun `non-existent snapshot returns 404`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + val fakeSnapshotId = java.util.UUID.randomUUID().toString() + + val downloadResp = client.get("/buckets/${device.bucketId}/snapshots/$fakeSnapshotId/download") { + header(HttpHeaders.Authorization, "Bearer ${device.sessionToken}") + } + assertEquals(HttpStatusCode.NotFound, downloadResp.status) + } + + @Test + fun `device without bucket access gets 403`() = testApplication { + application { module(testConfig()) } + val client = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(client) + uploadOpsChain(client, device, 1) + val (uploadResp, _) = uploadSnapshot(client, device, 1) + assertEquals(HttpStatusCode.Created, uploadResp.status) + val snapshotId = uploadResp.body().snapshotId + + // Create an outsider device + val outsiderReg = TestHelper.registerDevice(client) + val outsider = TestHelper.authenticateDevice(client, outsiderReg) + + val downloadResp = client.get("/buckets/${device.bucketId}/snapshots/$snapshotId/download") { + header(HttpHeaders.Authorization, "Bearer ${outsider.sessionToken}") + } + assertEquals(HttpStatusCode.Forbidden, downloadResp.status) + } +} diff --git a/server/src/test/kotlin/dev/kidsync/server/WebSocketQueryParamAuthTest.kt b/server/src/test/kotlin/dev/kidsync/server/WebSocketQueryParamAuthTest.kt new file mode 100644 index 0000000..2082caa --- /dev/null +++ b/server/src/test/kotlin/dev/kidsync/server/WebSocketQueryParamAuthTest.kt @@ -0,0 +1,110 @@ +package dev.kidsync.server + +import dev.kidsync.server.TestHelper.createJsonClient +import io.ktor.client.plugins.websocket.WebSockets +import io.ktor.client.plugins.websocket.webSocket +import io.ktor.server.testing.testApplication +import io.ktor.websocket.Frame +import io.ktor.websocket.readText +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +/** + * Integration tests for SEC5-S-03: WebSocket query parameter authentication. + */ +class WebSocketQueryParamAuthTest { + + private val json = Json { ignoreUnknownKeys = true } + + /** + * Parse the auth response and verify it contains deviceId + bucketId (auth_ok structure). + * Note: the "type" field may be omitted by the server's JSON encoder when encodeDefaults=false, + * so we verify the presence of deviceId + bucketId fields as the auth_ok indicator. + */ + private fun assertAuthOk(text: String, expectedDeviceId: String, expectedBucketId: String) { + val obj = json.parseToJsonElement(text).jsonObject + assertEquals(expectedDeviceId, obj["deviceId"]?.jsonPrimitive?.content, + "deviceId mismatch in response: $text") + assertEquals(expectedBucketId, obj["bucketId"]?.jsonPrimitive?.content, + "bucketId mismatch in response: $text") + assertNotNull(obj["latestSequence"], "latestSequence missing in response: $text") + } + + @Test + fun `valid query param token authenticates successfully`() = testApplication { + application { module(testConfig()) } + val httpClient = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(httpClient) + val bucketId = device.bucketId!! + + val wsClient = createClient { + install(WebSockets) + } + + wsClient.webSocket("/buckets/$bucketId/ws?token=${device.sessionToken}") { + // Should receive auth_ok without needing to send an auth message + val frame = incoming.receive() + assertTrue(frame is Frame.Text, "Expected Text frame, got ${frame.frameType}") + assertAuthOk((frame as Frame.Text).readText(), device.deviceId, bucketId) + } + } + + @Test + fun `invalid query param token closes with 4001`() = testApplication { + application { module(testConfig()) } + val httpClient = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(httpClient) + val bucketId = device.bucketId!! + + val wsClient = createClient { + install(WebSockets) + } + + wsClient.webSocket("/buckets/$bucketId/ws?token=sess_invalidtoken12345") { + // Should receive close with code 4001 + val reason = closeReason.await() + assertNotNull(reason) + assertEquals(4001, reason.code) + } + } + + @Test + fun `absent query param falls through to in-band auth`() = testApplication { + application { module(testConfig()) } + val httpClient = createJsonClient() + + val device = TestHelper.setupDeviceWithBucket(httpClient) + val bucketId = device.bucketId!! + + val wsClient = createClient { + install(WebSockets) + } + + wsClient.webSocket("/buckets/$bucketId/ws") { + // Send in-band auth message (existing behavior) + val authMsg = json.encodeToString( + JsonObject.serializer(), + buildJsonObject { + put("type", "auth") + put("token", device.sessionToken) + } + ) + send(Frame.Text(authMsg)) + + // Should receive auth_ok + val frame = incoming.receive() + assertTrue(frame is Frame.Text, "Expected Text frame, got ${frame.frameType}") + assertAuthOk((frame as Frame.Text).readText(), device.deviceId, bucketId) + } + } +}