diff --git a/Sources/SpeakerKit/DiarizationResult.swift b/Sources/SpeakerKit/DiarizationResult.swift index 0ad0e700..2422ddc8 100644 --- a/Sources/SpeakerKit/DiarizationResult.swift +++ b/Sources/SpeakerKit/DiarizationResult.swift @@ -31,26 +31,56 @@ public struct DiarizationResult: Sendable { public private(set) var segments: [SpeakerSegment] public var timings: (any DiarizationTimings)? + /// Per-speaker centroid embeddings keyed by `speakerId`, in the raw speaker-embedder output + /// space (unnormalised, pre-PLDA). Useful for linking the same speaker across independent + /// `diarize(...)` calls without re-running the embedder. + /// + /// Each centroid is the mean of the per-window embeddings that ended up under that + /// `speakerId`. Which embeddings contribute is controlled by + /// `PyannoteDiarizationOptions.centroidSource`; see ``SpeakerCentroidSource``. Under + /// `.trainableOnly`, some speakers in `segments` may not have a centroid in this + /// map; use `if let centroid = speakerCentroidEmbeddings[id]` rather than assuming + /// the key is present. + /// + /// Compare centroids with cosine distance via `centroidCosineDistance(between:and:)` or + /// `nearestSpeakerCentroid(to:)`, matching the convention used by + /// `MathOps.cosineDistanceMatrix` elsewhere in SpeakerKit. SpeakerKit does not define a + /// universal "same speaker" threshold for comparing centroids across independent runs; + /// callers should calibrate that policy for their model, audio, and application. + /// + /// This field is populated by the Pyannote backend (`PyannoteDiarizer`). Other backends + /// conforming to `Diarizer` may leave it as `[:]` if they do not expose per-cluster centroids. 
+ public private(set) var speakerCentroidEmbeddings: [Int: [Float]] + /// Pyannote init: builds segments from binary speaker activity matrix - init(binaryMatrix: [[Int]], diarizationFrameRate: Float) { + init(binaryMatrix: [[Int]], diarizationFrameRate: Float, speakerCentroidEmbeddings: [Int: [Float]] = [:]) { self.binaryMatrix = binaryMatrix self.frameRate = diarizationFrameRate self.speakerCount = binaryMatrix.count self.totalFrames = speakerCount > 0 ? binaryMatrix[0].count : 0 self.segments = [] self.timings = nil + self.speakerCentroidEmbeddings = speakerCentroidEmbeddings self.updateSegments(minActiveOffset: 0.0) } /// Generic init: for engines that produce segments directly - public init(speakerCount: Int, totalFrames: Int, frameRate: Float, segments: [SpeakerSegment], timings: (any DiarizationTimings)? = nil) { + public init( + speakerCount: Int, + totalFrames: Int, + frameRate: Float, + segments: [SpeakerSegment], + timings: (any DiarizationTimings)? = nil, + speakerCentroidEmbeddings: [Int: [Float]] = [:] + ) { self.binaryMatrix = [] self.speakerCount = speakerCount self.totalFrames = totalFrames self.frameRate = frameRate self.segments = segments self.timings = timings + self.speakerCentroidEmbeddings = speakerCentroidEmbeddings } public mutating func updateSegments(minActiveOffset: Float) { @@ -101,6 +131,52 @@ public struct DiarizationResult: Sendable { self.segments = segments.sorted { $0.startFrame < $1.startFrame } } + // MARK: - Speaker Centroid Comparison + + /// Cosine distance in `[0.0, 2.0]` between two speaker centroids from this result. + /// + /// Delegates to `MathOps.cosineDistance(_:_:)`, matching the convention used by + /// `MathOps.cosineDistanceMatrix` elsewhere in SpeakerKit. The result is clamped to + /// `[0, 2]` to absorb floating-point error near the extremes. A distance of `0` means + /// identical direction, `1` means orthogonal vectors (no directional similarity), and + /// `2` means opposite direction. 
+ /// + /// - Returns: `nil` if either `speakerId` is absent from + /// ``speakerCentroidEmbeddings``, the centroids have different dimensions, or either + /// vector is empty. Zero-magnitude centroids (unreachable in real diarization runs) + /// yield `MathOps.cosineDistance`'s sentinel of `1.0`. + public func centroidCosineDistance(between a: Int, and b: Int) -> Float? { + guard let lhs = speakerCentroidEmbeddings[a], + let rhs = speakerCentroidEmbeddings[b], + lhs.count == rhs.count, !lhs.isEmpty else { return nil } + return MathOps.cosineDistance(lhs, rhs) + } + + /// Nearest centroid in this result to an external speaker embedding. + /// + /// This is a pure nearest-neighbour lookup over ``speakerCentroidEmbeddings``. It does not + /// apply a same-speaker threshold; callers should interpret the returned distance according + /// to their own calibration. Ties resolve deterministically to the lowest `speakerId`. + /// + /// - Returns: The nearest compatible centroid, or `nil` when `embedding` is empty, no + /// centroid exists, or all stored centroids have different dimensions. + public func nearestSpeakerCentroid(to embedding: [Float]) -> (speakerId: Int, distance: Float)? { + guard !embedding.isEmpty else { return nil } + + var nearest: (speakerId: Int, distance: Float)? + for speakerId in speakerCentroidEmbeddings.keys.sorted() { + guard let centroid = speakerCentroidEmbeddings[speakerId], centroid.count == embedding.count else { + continue + } + let distance = MathOps.cosineDistance(embedding, centroid) + if nearest == nil || distance < (nearest?.distance ?? 
.infinity) { + nearest = (speakerId, distance) + } + } + + return nearest + } + // MARK: - Speaker Info Matching public func addSpeakerInfo(to transcription: [TranscriptionResult], strategy: SpeakerInfoStrategy = SpeakerInfoStrategy.subsegment) -> [[SpeakerSegment]] { diff --git a/Sources/SpeakerKit/Pyannote/PyannoteConfig.swift b/Sources/SpeakerKit/Pyannote/PyannoteConfig.swift index cac1d385..c7d65ae7 100644 --- a/Sources/SpeakerKit/Pyannote/PyannoteConfig.swift +++ b/Sources/SpeakerKit/Pyannote/PyannoteConfig.swift @@ -119,12 +119,26 @@ public class PyannoteConfig: SpeakerKitConfig, @unchecked Sendable { // MARK: - Diarization Options +public enum SpeakerCentroidSource: Equatable, Hashable, Sendable { + /// Mean of all embeddings under the final post-reassignment speaker labels, without + /// any filtering for quality of embeddings. + case finalAssignment + + /// Mean of embeddings under the final post-reassignment speaker labels, with + /// additional filtering by the clustering algorithm for purer voice embeddings. + /// May omit speakers whose members are entirely excluded by the filter. Use + /// `if let centroid = speakerCentroidEmbeddings[id]` rather than assuming the + /// key is present. + case trainableOnly +} + public struct PyannoteDiarizationOptions: DiarizationOptions { public var numberOfSpeakers: Int? public var minActiveOffset: Float? public var clusterDistanceThreshold: Float? public var minClusterSize: Int? public var useExclusiveReconciliation: Bool + public var centroidSource: SpeakerCentroidSource /// Optional seek boundaries in seconds; pairs define [start, end] clips. Empty means process full audio. public var clipTimestamps: [Float] @@ -134,6 +148,7 @@ public struct PyannoteDiarizationOptions: DiarizationOptions { clusterDistanceThreshold: Float? = nil, minClusterSize: Int? 
= nil, useExclusiveReconciliation: Bool = true, + centroidSource: SpeakerCentroidSource = .finalAssignment, clipTimestamps: [Float] = [] ) { self.numberOfSpeakers = numberOfSpeakers @@ -141,6 +156,7 @@ public struct PyannoteDiarizationOptions: DiarizationOptions { self.clusterDistanceThreshold = clusterDistanceThreshold self.minClusterSize = minClusterSize self.useExclusiveReconciliation = useExclusiveReconciliation + self.centroidSource = centroidSource self.clipTimestamps = clipTimestamps } } @@ -208,4 +224,3 @@ public struct PyannoteDiarizationTimings: DiarizationTimings, CustomStringConver """ } } - diff --git a/Sources/SpeakerKit/Pyannote/PyannoteDiarizer.swift b/Sources/SpeakerKit/Pyannote/PyannoteDiarizer.swift index d825dc80..8293ec3b 100644 --- a/Sources/SpeakerKit/Pyannote/PyannoteDiarizer.swift +++ b/Sources/SpeakerKit/Pyannote/PyannoteDiarizer.swift @@ -255,6 +255,7 @@ actor PyannoteDiarizerActor { progressObj.completedUnitCount = 80 progressCallback?(progressObj) var diarizationResult = postProcess(speakerEmbeddings: clusteringResult.speakerEmbeddings, + speakerCentroids: clusteringResult.speakerCentroids, originalLength: audioLength, useExclusiveReconciliation: resolvedOptions.useExclusiveReconciliation) timings.numberOfSpeakers = diarizationResult.speakerCount @@ -268,7 +269,10 @@ actor PyannoteDiarizerActor { return diarizationResult } - private func postProcess(speakerEmbeddings: [SpeakerEmbedding], originalLength: Int, useExclusiveReconciliation: Bool) -> DiarizationResult { + private func postProcess(speakerEmbeddings: [SpeakerEmbedding], + speakerCentroids: [Int: [Float]], + originalLength: Int, + useExclusiveReconciliation: Bool) -> DiarizationResult { let startTime = CFAbsoluteTimeGetCurrent() defer { let totalTime = (CFAbsoluteTimeGetCurrent() - startTime) * 1_000 @@ -360,7 +364,7 @@ actor PyannoteDiarizerActor { } } - return DiarizationResult(binaryMatrix: binaryDiarization, diarizationFrameRate: diarizationFrameRate) + return 
DiarizationResult(binaryMatrix: binaryDiarization, diarizationFrameRate: diarizationFrameRate, speakerCentroidEmbeddings: speakerCentroids) } func diarize(audioArray: [Float], options: (any DiarizationOptions)?, progressCallback: (@Sendable (Progress) -> Void)?) async throws -> DiarizationResult { diff --git a/Sources/SpeakerKit/Pyannote/SpeakerClustering.swift b/Sources/SpeakerKit/Pyannote/SpeakerClustering.swift index 6b6da6c7..bbe957e5 100644 --- a/Sources/SpeakerKit/Pyannote/SpeakerClustering.swift +++ b/Sources/SpeakerKit/Pyannote/SpeakerClustering.swift @@ -12,6 +12,7 @@ struct VBxClusteringConfig: Sendable { let maxIterations: Int let initialSmoothingFactor: Float let numSpeakers: Int? + let centroidSource: SpeakerCentroidSource private static let defaultThreshold: Float = 0.6 @@ -22,7 +23,8 @@ struct VBxClusteringConfig: Sendable { minActiveRatio: Float = 0.2, maxIterations: Int = 20, initialSmoothingFactor: Float = 7.0, - numSpeakers: Int? = nil) { + numSpeakers: Int? = nil, + centroidSource: SpeakerCentroidSource = .finalAssignment) { self.threshold = threshold self.speakerRelevanceFactorA = speakerRelevanceFactorA self.speakerRelevanceFactorB = speakerRelevanceFactorB @@ -31,12 +33,14 @@ struct VBxClusteringConfig: Sendable { self.maxIterations = maxIterations self.initialSmoothingFactor = initialSmoothingFactor self.numSpeakers = numSpeakers + self.centroidSource = centroidSource } init(from options: PyannoteDiarizationOptions) { self.init( threshold: options.clusterDistanceThreshold ?? 
Self.defaultThreshold, - numSpeakers: options.numberOfSpeakers + numSpeakers: options.numberOfSpeakers, + centroidSource: options.centroidSource ) } } @@ -44,11 +48,14 @@ struct VBxClusteringConfig: Sendable { struct ClusteringResult { let clusterIndices: [Int] let speakerEmbeddings: [SpeakerEmbedding] + let speakerCentroids: [Int: [Float]] init(clusterIndices: [Int], - speakerEmbeddings: [SpeakerEmbedding]) { + speakerEmbeddings: [SpeakerEmbedding], + speakerCentroids: [Int: [Float]] = [:]) { self.clusterIndices = clusterIndices self.speakerEmbeddings = speakerEmbeddings + self.speakerCentroids = speakerCentroids } } diff --git a/Sources/SpeakerKit/Pyannote/VBxClustering.swift b/Sources/SpeakerKit/Pyannote/VBxClustering.swift index 8506e005..4b069030 100644 --- a/Sources/SpeakerKit/Pyannote/VBxClustering.swift +++ b/Sources/SpeakerKit/Pyannote/VBxClustering.swift @@ -22,7 +22,7 @@ actor VBxClustering: Clusterer { _speakerEmbeddings.sort { ($0.windowIndex, $0.speakerIndex) < ($1.windowIndex, $1.speakerIndex) } - let (clusters, _) = cluster(embeddings: _speakerEmbeddings, config: config) + let (clusters, _, centroids) = cluster(embeddings: _speakerEmbeddings, config: config) for (clusterIndex, clusterId) in clusters.enumerated() { _speakerEmbeddings[clusterIndex].clusterId = clusterId @@ -30,7 +30,8 @@ actor VBxClustering: Clusterer { return ClusteringResult( clusterIndices: clusters, - speakerEmbeddings: _speakerEmbeddings + speakerEmbeddings: _speakerEmbeddings, + speakerCentroids: centroids ) } @@ -45,9 +46,10 @@ actor VBxClustering: Clusterer { func cluster( embeddings: [SpeakerEmbedding], config: VBxClusteringConfig - ) -> (clusters: [Int], linkageMatrix: [[Float]]) { + ) -> (clusters: [Int], linkageMatrix: [[Float]], centroids: [Int: [Float]]) { let trainableEmbeddings = embeddings.filter { $0.nonOverlappedFrameRatio > config.minActiveRatio } let embeddingsFloats = trainableEmbeddings.map { $0.embedding } + let allEmbeddingsFloats = embeddings.map { 
$0.embedding } let pldaEmbeddingsFloats = trainableEmbeddings.map { $0.pldaEmbedding ?? [] } @@ -88,6 +90,7 @@ actor VBxClustering: Clusterer { let clusterAssignments = speakerWeights.isEmpty ? clusters : MathOps.argmax(speakerWeights, axis: 0) var centroids = calculateCentroids(speakerWeights: speakerWeights, embeddings: embeddingsFloats) + // These centroids seed cluster reassignment; returned centroids are recomputed below. let autoSpeakerCount = centroids.count Logging.debug("VBx clustering completed with \(autoSpeakerCount) speakers") @@ -97,11 +100,14 @@ actor VBxClustering: Clusterer { if let requestedSpeakers = config.numSpeakers, autoSpeakerCount != requestedSpeakers { Logging.debug("K-Means correction: VBx gave \(autoSpeakerCount) speakers, requested \(requestedSpeakers)") let kAssignments = ClusterAlgorithms.kMeans(embeddings: embeddingsNormalized, clusterCount: requestedSpeakers) - centroids = centroidsFromAssignments(assignments: kAssignments, embeddings: embeddingsFloats, k: requestedSpeakers) + centroids = centroidsFromAssignments( + assignments: kAssignments, + embeddings: embeddingsFloats, + clusterCount: requestedSpeakers + ) } if !centroids.isEmpty { - let allEmbeddingsFloats = embeddings.map { $0.embedding } clusters = clusterReassignment(embeddings: allEmbeddingsFloats, centroids: centroids) Logging.debug("Cluster reassignment completed") } else { @@ -109,11 +115,14 @@ actor VBxClustering: Clusterer { // from those AHC assignments so clusterReassignment can cover all N embeddings to match path above. let numClusters = (clusterAssignments.max() ?? -1) + 1 let fallbackCentroids = numClusters > 0 - ? centroidsFromAssignments(assignments: clusterAssignments, embeddings: embeddingsFloats, k: numClusters) + ? 
centroidsFromAssignments( + assignments: clusterAssignments, + embeddings: embeddingsFloats, + clusterCount: numClusters + ) : [] if !fallbackCentroids.isEmpty { - let allEmbeddingsFloats = embeddings.map { $0.embedding } clusters = clusterReassignment(embeddings: allEmbeddingsFloats, centroids: fallbackCentroids) Logging.debug("Cluster reassignment from AHC fallback completed") } else { @@ -122,7 +131,17 @@ actor VBxClustering: Clusterer { } } - return (clusters, linkageMatrix) + // Returned centroids use final assignments and the caller-selected embedding source, + // uniform across all paths (VBx weighted, kMeans correction, AHC fallback). Empty + // clusters (no surviving members under .trainableOnly) are not keyed. + let finalCentroids = centroidsFromFinalAssignments( + assignments: clusters, + embeddings: embeddings, + source: config.centroidSource, + minActiveRatio: config.minActiveRatio + ) + + return (clusters, linkageMatrix, finalCentroids) } // MARK: - Internal Methods @@ -176,24 +195,63 @@ actor VBxClustering: Clusterer { return clusterIndices } - private func centroidsFromAssignments(assignments: [Int], embeddings: [[Float]], k: Int) -> [[Float]] { + func centroidsFromAssignments(assignments: [Int], embeddings: [[Float]], clusterCount: Int) -> [[Float]] { guard !embeddings.isEmpty, !embeddings[0].isEmpty else { return [] } let dim = embeddings[0].count - var sums = Array(repeating: Array(repeating: Float(0), count: dim), count: k) - var counts = Array(repeating: 0, count: k) + var sums = Array(repeating: Array(repeating: Float(0), count: dim), count: clusterCount) + var counts = Array(repeating: 0, count: clusterCount) for (i, assignment) in assignments.enumerated() { guard i < embeddings.count else { continue } counts[assignment] += 1 for d in 0.. 
0 else { return sums[ki] } return sums[ki].map { $0 / Float(count) } } } - private func calculateCentroids(speakerWeights: [[Float]], embeddings: [[Float]]) -> [[Float]] { + /// Mean-pools embeddings under final post-reassignment labels, keyed by cluster id. + /// Empty clusters (no surviving members after the `.trainableOnly` filter) never get a key. + func centroidsFromFinalAssignments( + assignments: [Int], + embeddings: [SpeakerEmbedding], + source: SpeakerCentroidSource, + minActiveRatio: Float + ) -> [Int: [Float]] { + var sums: [Int: [Float]] = [:] + var counts: [Int: Int] = [:] + for (index, speakerEmbedding) in embeddings.enumerated() { + guard index < assignments.count else { continue } + let assignment = assignments[index] + guard assignment >= 0 else { continue } + if source == .trainableOnly, speakerEmbedding.nonOverlappedFrameRatio <= minActiveRatio { + continue + } + let embedding = speakerEmbedding.embedding + guard !embedding.isEmpty else { continue } + if var existing = sums[assignment] { + guard existing.count == embedding.count else { continue } + for d in 0.. 0 else { continue } + result[clusterId] = sum.map { $0 / Float(count) } + } + return result + } + + func calculateCentroids(speakerWeights: [[Float]], embeddings: [[Float]]) -> [[Float]] { guard !speakerWeights.isEmpty, !embeddings.isEmpty, !embeddings[0].isEmpty else { return [] } diff --git a/Tests/SpeakerKitTests/SpeakerCentroidEmbeddingsTests.swift b/Tests/SpeakerKitTests/SpeakerCentroidEmbeddingsTests.swift new file mode 100644 index 00000000..8cb8d885 --- /dev/null +++ b/Tests/SpeakerKitTests/SpeakerCentroidEmbeddingsTests.swift @@ -0,0 +1,504 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2026 Argmax, Inc. All rights reserved. 
+ +import XCTest +import WhisperKit +@testable import SpeakerKit + +final class SpeakerCentroidEmbeddingsTests: XCTestCase { + + // MARK: - Helpers + + private func loadAudio(named name: String, extension ext: String = "wav") throws -> [Float] { + guard let url = Bundle.module.url(forResource: name, withExtension: ext) else { + throw XCTSkip("Audio file \(name).\(ext) not found in test bundle") + } + let audioBuffer = try AudioProcessor.loadAudio(fromPath: url.path) + return AudioProcessor.convertBufferToArray(buffer: audioBuffer) + } + + private func assertVectorsEqual( + _ actual: [Float], + _ expected: [Float], + accuracy: Float = 1e-5, + file: StaticString = #filePath, + line: UInt = #line + ) { + XCTAssertEqual(actual.count, expected.count, "Vector lengths differ", file: file, line: line) + for (i, (a, e)) in zip(actual, expected).enumerated() { + XCTAssertEqual(a, e, accuracy: accuracy, "Mismatch at index \(i)", file: file, line: line) + } + } + + // MARK: - Unit tests: calculateCentroids (main VBx path) + + /// With one-hot responsibility per speaker, the weighted mean collapses to the + /// arithmetic mean of the embeddings owned by that speaker. 
+ func testCalculateCentroids_uniformWeightsEqualsArithmeticMean() async { + let embeddings: [[Float]] = [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0] + ] + // speakerWeights[s][e] -- 3 speakers owning embeddings 0, 1, and (2+3) + let speakerWeights: [[Float]] = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0] + ] + + let clusterer = VBxClustering() + let centroids = await clusterer.calculateCentroids( + speakerWeights: speakerWeights, + embeddings: embeddings + ) + + XCTAssertEqual(centroids.count, 3) + assertVectorsEqual(centroids[0], embeddings[0]) + assertVectorsEqual(centroids[1], embeddings[1]) + let expectedMix = zip(embeddings[2], embeddings[3]).map { ($0 + $1) / 2 } + assertVectorsEqual(centroids[2], expectedMix) + } + + /// Fractional responsibility weights must produce sum(w_i * x_i) / sum(w_i). + func testCalculateCentroids_weightedMean() async { + let embeddings: [[Float]] = [ + [2.0, 4.0], + [6.0, 8.0] + ] + let speakerWeights: [[Float]] = [ + [0.7, 0.3], + [0.1, 0.9] + ] + + let clusterer = VBxClustering() + let centroids = await clusterer.calculateCentroids( + speakerWeights: speakerWeights, + embeddings: embeddings + ) + + XCTAssertEqual(centroids.count, 2) + // speaker 0: (0.7*2 + 0.3*6) / 1.0 = 3.2, (0.7*4 + 0.3*8) / 1.0 = 5.2 + assertVectorsEqual(centroids[0], [3.2, 5.2]) + // speaker 1: (0.1*2 + 0.9*6) / 1.0 = 5.6, (0.1*4 + 0.9*8) / 1.0 = 7.6 + assertVectorsEqual(centroids[1], [5.6, 7.6]) + } + + /// A speaker with zero total weight must stay at the helper's sentinel (zeros) and + /// must not divide by zero or crash. 
+ func testCalculateCentroids_zeroWeightSpeakerIsSkipped() async { + let embeddings: [[Float]] = [ + [1.0, 1.0, 1.0], + [2.0, 2.0, 2.0] + ] + let speakerWeights: [[Float]] = [ + [0.5, 0.5], + [0.0, 0.0] + ] + + let clusterer = VBxClustering() + let centroids = await clusterer.calculateCentroids( + speakerWeights: speakerWeights, + embeddings: embeddings + ) + + XCTAssertEqual(centroids.count, 2) + assertVectorsEqual(centroids[0], [1.5, 1.5, 1.5]) + assertVectorsEqual(centroids[1], [0.0, 0.0, 0.0]) + } + + // MARK: - Unit tests: centroidsFromAssignments (kMeans + AHC fallback paths) + + /// Arithmetic mean per cluster for a simple two-cluster partition. + func testCentroidsFromAssignments_arithmeticMean() async { + let dim = 8 + let embeddings: [[Float]] = (0..<6).map { i in + Array(repeating: Float(i), count: dim) + } + let assignments = [0, 0, 0, 1, 1, 1] + + let clusterer = VBxClustering() + let centroids = await clusterer.centroidsFromAssignments( + assignments: assignments, + embeddings: embeddings, + clusterCount: 2 + ) + + XCTAssertEqual(centroids.count, 2) + assertVectorsEqual(centroids[0], Array(repeating: 1.0, count: dim)) + assertVectorsEqual(centroids[1], Array(repeating: 4.0, count: dim)) + } + + /// A cluster with exactly one embedding must return that embedding as its centroid. + func testCentroidsFromAssignments_singletonCluster() async { + let embeddings: [[Float]] = [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0] + ] + let assignments = [0, 1, 0] + + let clusterer = VBxClustering() + let centroids = await clusterer.centroidsFromAssignments( + assignments: assignments, + embeddings: embeddings, + clusterCount: 2 + ) + + XCTAssertEqual(centroids.count, 2) + assertVectorsEqual(centroids[0], [4.0, 5.0, 6.0]) + assertVectorsEqual(centroids[1], [4.0, 5.0, 6.0]) + } + + /// A cluster id with no members must yield a zero vector, not NaN or crash. 
+ func testCentroidsFromAssignments_emptyCluster() async { + let embeddings: [[Float]] = [ + [1.0, 2.0, 3.0], + [2.0, 3.0, 4.0] + ] + let assignments = [0, 1] + + let clusterer = VBxClustering() + let centroids = await clusterer.centroidsFromAssignments( + assignments: assignments, + embeddings: embeddings, + clusterCount: 3 + ) + + XCTAssertEqual(centroids.count, 3) + assertVectorsEqual(centroids[0], [1.0, 2.0, 3.0]) + assertVectorsEqual(centroids[1], [2.0, 3.0, 4.0]) + assertVectorsEqual(centroids[2], [0.0, 0.0, 0.0]) + } + + /// `centroidSource` controls whether the surfaced centroid mean includes all final + /// assignment members or only the trainable subset used to seed clustering. + func testCentroidsFromFinalAssignmentsHonoursCentroidSource() async { + let embeddings = [ + SpeakerEmbedding( + embedding: [1.0, 0.0], + activeFrames: [1.0], + windowIndex: 0, + speakerIndex: 0, + nonOverlappedFrameRatio: 1.0 + ), + SpeakerEmbedding( + embedding: [3.0, 0.0], + activeFrames: [1.0], + windowIndex: 1, + speakerIndex: 0, + nonOverlappedFrameRatio: 0.0 + ), + SpeakerEmbedding( + embedding: [0.0, 10.0], + activeFrames: [1.0], + windowIndex: 2, + speakerIndex: 0, + nonOverlappedFrameRatio: 1.0 + ), + SpeakerEmbedding( + embedding: [0.0, 20.0], + activeFrames: [1.0], + windowIndex: 3, + speakerIndex: 0, + nonOverlappedFrameRatio: 0.0 + ) + ] + let assignments = [0, 0, 1, 1] + let clusterer = VBxClustering() + + let finalAssignmentCentroids = await clusterer.centroidsFromFinalAssignments( + assignments: assignments, + embeddings: embeddings, + source: .finalAssignment, + minActiveRatio: 0.2 + ) + XCTAssertEqual(Set(finalAssignmentCentroids.keys), [0, 1]) + assertVectorsEqual(finalAssignmentCentroids[0] ?? [], [2.0, 0.0]) + assertVectorsEqual(finalAssignmentCentroids[1] ?? 
[], [0.0, 15.0]) + let trainableOnlyCentroids = await clusterer.centroidsFromFinalAssignments( + assignments: assignments, + embeddings: embeddings, + source: .trainableOnly, + minActiveRatio: 0.2 + ) + XCTAssertEqual(Set(trainableOnlyCentroids.keys), [0, 1]) + assertVectorsEqual(trainableOnlyCentroids[0] ?? [], [1.0, 0.0]) + assertVectorsEqual(trainableOnlyCentroids[1] ?? [], [0.0, 10.0]) + } + /// `.trainableOnly` omits clusters whose members are all overlap-flagged, rather than surfacing a zero-vector centroid. + func testCentroidsFromFinalAssignments_omitsClusterWithNoTrainableMembers() async { + let embeddings = [ + SpeakerEmbedding( + embedding: [1.0, 0.0], + activeFrames: [1.0], windowIndex: 0, speakerIndex: 0, + nonOverlappedFrameRatio: 1.0 + ), + SpeakerEmbedding( + embedding: [0.0, 1.0], + activeFrames: [1.0], windowIndex: 1, speakerIndex: 0, + nonOverlappedFrameRatio: 0.0 + ), + SpeakerEmbedding( + embedding: [0.0, 2.0], + activeFrames: [1.0], windowIndex: 2, speakerIndex: 0, + nonOverlappedFrameRatio: 0.0 + ) + ] + let assignments = [0, 1, 1] // cluster 1 is entirely overlap-flagged + + let clusterer = VBxClustering() + let trainable = await clusterer.centroidsFromFinalAssignments( + assignments: assignments, embeddings: embeddings, + source: .trainableOnly, minActiveRatio: 0.2 + ) + + XCTAssertEqual(Set(trainable.keys), [0]) + assertVectorsEqual(trainable[0] ?? [], [1.0, 0.0]) + } + // MARK: - Clusterer-level invariant: value matches post-reassignment mean + + /// For any embedding input that drives VBxClustering end-to-end, the surfaced + /// `speakerCentroids[k]` must equal the arithmetic mean of the embeddings whose final + /// `clusterIndices[i]` is `k`. This holds irrespective of which internal path + /// (VBx weighted, kMeans correction, AHC fallback) produced the assignments, because the + /// final recompute is always `centroidsFromFinalAssignments(...)` over the post-reassignment + /// labels, which is a plain mean under the default `.finalAssignment` source. 
+ func testCentroidValuesMatchFinalAssignmentMean() async { + let dim = 128 + // two separated groups in raw embedder space (4 + 4), with an explicit two-speaker + // request so this fixture exercises the post-reassignment centroid recompute path. + let groupA: [[Float]] = (0..<4).map { i in + var v = Array(repeating: Float(0), count: dim) + v[0] = 1.0 + v[1] = Float(i) * 0.01 + return v + } + let groupB: [[Float]] = (0..<4).map { i in + var v = Array(repeating: Float(0), count: dim) + v[0] = -1.0 + v[1] = Float(i) * 0.01 + return v + } + let raw = groupA + groupB + let plda = raw + + let speakerEmbeddings = (0..= 2 else { + throw XCTSkip("need at least two speakers to compare centroids") + } + + for i in 0..