diff --git a/config/canary_streamatt.yaml b/config/canary_streamatt.yaml
new file mode 100644
index 0000000..65d0976
--- /dev/null
+++ b/config/canary_streamatt.yaml
@@ -0,0 +1,16 @@
+type: "simulstream.server.speech_processors.canary_streamatt.CanaryStreamAtt"
+model_name: "nvidia/canary-1b-v2"
+text_history:
+  type: "simulstream.server.speech_processors.base_streamatt.FixedWordsTextHistory"
+  history_words: 10
+speech_chunk_size: 0.960 # seconds
+detokenizer_type: "canary"
+cross_attention_layer: -2
+cutoff_frame_num: 8
+num_beams: 5
+audio_subsampling_factor: 8
+audio_history_max_duration: 160 # Maximum length for the audio buffer, in seconds
+mel_hop_samples: 160 # Number of audio samples between adjacent mel frames
+text_history_max_len: 128
+word_level_postprocess: True # Disable if character-level language
+use_raw_audio_history: True
diff --git a/pyproject.toml b/pyproject.toml
index 45a6581..3579e9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,7 @@ hf = [
 
 canary = [
     "Cython",
-    "nemo_toolkit[asr]==2.4.0",
+    "nemo_toolkit[asr]==2.8.0",
 ]
 
 vad = [
diff --git a/simulstream/server/speech_processors/base_streamatt.py b/simulstream/server/speech_processors/base_streamatt.py
index fa7ccd4..9fb65da 100644
--- a/simulstream/server/speech_processors/base_streamatt.py
+++ b/simulstream/server/speech_processors/base_streamatt.py
@@ -60,6 +60,10 @@ class BaseStreamAtt(BaseSpeechProcessor):
               context for next predictions.
             - **audio_subsampling_factor (int)**: Subsampling factor of the model, if any.
               Defaults to 1.
+            - **mel_hop_samples (int)**: Number of raw waveform samples per mel frame.
+              Defaults to 160, i.e. 10ms at 16kHz.
+            - **use_raw_audio_history (bool)**: Whether ``audio_history`` stores raw
+              waveform samples rather than processed frames. Defaults to False.
             - **text_history_max_len (int)**: The maximum length of the textual history after
               which the current content is cut. Defaults to 128.
             - **cross_attention_layer (int)**: Layer from which to extract the cross-attention from.
@@ -77,6 +81,12 @@ def __init__(self, config: SimpleNamespace):
         text_history_cls = class_load(text_history_config.type)
         self.text_history_method = text_history_cls(text_history_config)
         self.audio_subsampling_factor = getattr(self.config, "audio_subsampling_factor", 1)
+        self.mel_hop_samples = getattr(self.config, "mel_hop_samples", 160)
+        self.use_raw_audio_history = getattr(self.config, "use_raw_audio_history", False)
+        # Number of audio_history elements spanned by one cross-attention frame
+        self.frames_to_audio_history = self.audio_subsampling_factor
+        if self.use_raw_audio_history:
+            self.frames_to_audio_history *= self.mel_hop_samples
         self.text_history_max_len = getattr(self.config, "text_history_max_len", 128)
         self.cross_attn_layer = getattr(self.config, "cross_attention_layer", 3)
         self.cutoff_frame_num = getattr(self.config, "cutoff_frame_num", 2)
@@ -173,8 +183,8 @@ def _update_speech_history(self, discarded_text: int, cross_attn: torch.Tensor)
             # Only one token: use the unique most attended frame
             earliest_attended_idx = most_attended_idxs[0]
 
-        # Multiply by the subsampling factor to recover the original number of frames
-        frames_to_cut = earliest_attended_idx * self.audio_subsampling_factor
+        # Convert the attended frame index to audio_history units (feature frames or raw samples)
+        frames_to_cut = earliest_attended_idx * self.frames_to_audio_history
 
         # Cut the unattended audio features
         self.audio_history = self.audio_history[frames_to_cut:]
diff --git a/simulstream/server/speech_processors/canary_streamatt.py b/simulstream/server/speech_processors/canary_streamatt.py
new file mode 100644
index 0000000..660bf00
--- /dev/null
+++ b/simulstream/server/speech_processors/canary_streamatt.py
@@ -0,0 +1,203 @@
+# Copyright 2025 FBK
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import copy
+import math
+from types import SimpleNamespace
+from typing import List, Tuple
+
+import numpy as np
+import torch
+
+from nemo.collections.asr.models import ASRModel
+from nemo.collections.asr.models.aed_multitask_models import (
+    MultiTaskTranscriptionConfig,
+)
+from nemo.collections.asr.parts.submodules.multitask_decoding import (
+    MultiTaskDecodingConfig,
+)
+
+from simulstream.server.speech_processors import SAMPLE_RATE
+from simulstream.server.speech_processors.base_streamatt import BaseStreamAtt
+
+
+class CanaryStreamAtt(BaseStreamAtt):
+    """
+    StreamAtt policy implementation for NVIDIA's Canary-v2 model.
+
+    The audio history is kept as raw waveform samples, so ``use_raw_audio_history`` must be
+    enabled in the config for the history trimming to operate on the right unit.
+
+    Args:
+        config (SimpleNamespace): Configuration object.
+            Supported attributes:
+            - **model_name (str)**: Name of the pretrained model to load.
+            - **audio_history_max_duration (int)**: Maximum audio history in seconds.
+              Defaults to ``30``.
+            - **num_beams (int)**: Number of beams to use for beam search decoding.
+              Defaults to ``5``.
+
+    Raises:
+        ValueError: If ``use_raw_audio_history`` is disabled or ``mel_hop_samples`` does not
+            match the hop size of the loaded model's preprocessor.
+    """
+
+    def __init__(self, config: SimpleNamespace):
+        super().__init__(config)
+        self._audio_history_max_duration = getattr(self.config, "audio_history_max_duration", 30)
+
+        # _preprocess() stores raw samples in the audio history, so the attended frame index
+        # must be converted to samples when trimming it.
+        if not self.use_raw_audio_history:
+            raise ValueError(
+                "CanaryStreamAtt keeps raw waveform samples in the audio history: "
+                "use_raw_audio_history must be set to True in the config"
+            )
+
+        preprocessor_cfg = self.model.cfg.preprocessor
+        expected_mel_hop_samples = preprocessor_cfg.window_stride * preprocessor_cfg.sample_rate
+        # Explicit check rather than ``assert`` (stripped with ``python -O``); the expected
+        # value is a product that may be a float, hence the tolerant comparison.
+        if not math.isclose(self.mel_hop_samples, expected_mel_hop_samples):
+            raise ValueError(
+                f"mel_hop_samples is set to {self.mel_hop_samples} in the config, but the loaded "
+                f"model's preprocessor uses {expected_mel_hop_samples} samples per mel frame"
+            )
+
+        # Build the transcription config, which is reused for every transcribe() call.
+        self.transcription_cfg = MultiTaskTranscriptionConfig(
+            batch_size=1,
+            return_hypotheses=True,
+            enable_chunking=False,
+            verbose=False,
+        )
+
+    @property
+    def audio_max_len(self) -> int:
+        """Maximum audio history length in raw waveform samples."""
+        # int() keeps the length usable as a slice bound with fractional durations
+        return int(self._audio_history_max_duration * SAMPLE_RATE)
+
+    def set_source_language(self, language: str) -> None:
+        """Set the source language used in the prompt of the next generations."""
+        self.src_lang = language
+
+    def set_target_language(self, language: str) -> None:
+        """Set the target language used in the prompt of the next generations."""
+        self.tgt_lang = language
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        """
+        Load the pretrained model, unless one is already loaded, and configure beam search
+        decoding so that the cross-attention scores are returned with the hypotheses.
+
+        Raises:
+            ValueError: If the model sample rate differs from ``SAMPLE_RATE``.
+        """
+        if getattr(cls, "model", None) is not None:
+            return
+
+        model = ASRModel.from_pretrained(model_name=config.model_name)
+        # Validate before storing the model, so that an unsupported model is never cached
+        model_sample_rate = model.cfg.preprocessor.sample_rate
+        if model_sample_rate != SAMPLE_RATE:
+            raise ValueError(
+                f"{config.model_name} expects audio sampled at {model_sample_rate} Hz, "
+                f"while the processor works with {SAMPLE_RATE} Hz audio"
+            )
+
+        # Configure decoding strategy
+        multitask_decoding = MultiTaskDecodingConfig()
+        multitask_decoding.strategy = "beam"
+        multitask_decoding.return_xattn_scores = True
+        multitask_decoding.beam.beam_size = getattr(config, "num_beams", 5)
+        model.change_decoding_strategy(multitask_decoding)
+
+        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        cls.model = model.to(cls.device)
+
+    def _build_transcription_config(self) -> MultiTaskTranscriptionConfig:
+        """
+        Return a copy of the base transcription config whose prompt contains the model's
+        default dialog slots, overridden with the current source/target languages, and the
+        current text history as forced decoder prefix.
+        """
+        default_turns = self.model.prompt.get_default_dialog_slots()
+        default_slots = copy.deepcopy(default_turns[0]["slots"])
+        default_slots["source_lang"] = self.src_lang
+        default_slots["target_lang"] = self.tgt_lang
+
+        turns = [
+            {"role": "user", "slots": default_slots},
+            {
+                "role": "user_prefix",
+                "slots": {
+                    "prefix": self.model.tokenizer.tokens_to_text(self.text_history)
+                },
+            },
+        ]
+
+        # Copy the shared config so that the prompt of a call never leaks into the next one
+        cfg_copy = copy.deepcopy(self.transcription_cfg)
+        cfg_copy.prompt = turns
+
+        return cfg_copy
+
+    def _preprocess(self, waveform: np.ndarray) -> np.ndarray:
+        """
+        Append the incoming waveform chunk to the raw audio history and return it.
+
+        Args:
+            waveform (np.ndarray): New chunk of audio samples.
+
+        Returns:
+            np.ndarray: Accumulated raw audio history.
+        """
+        waveform = waveform.astype(np.float32)
+        if self.audio_history is None:
+            self.audio_history = waveform
+        else:
+            self.audio_history = np.concatenate([self.audio_history, waveform])
+
+        return self.audio_history
+
+    def _generate(self, speech: np.ndarray) -> Tuple[List[str], torch.Tensor]:
+        """
+        Run the model on the accumulated audio, forcing the text history as prefix.
+
+        Returns:
+            Tuple[List[str], torch.Tensor]: The tokens of the first returned hypothesis and
+            the normalized cross-attention scores of the configured layer, averaged over
+            the attention heads.
+        """
+        override_config = self._build_transcription_config()
+
+        with torch.inference_mode():
+            output = self.model.transcribe(audio=speech, override_config=override_config)
+
+        hypothesis = output[0]
+
+        token_ids = hypothesis.y_sequence.detach().cpu().tolist()
+        tokens = self.model.tokenizer.ids_to_tokens(token_ids)
+
+        xatt_raw = hypothesis.xatt_scores[self.cross_attn_layer]
+        xatt = xatt_raw.mean(dim=0).cpu()  # we average over heads
+        xatt = self.normalize_attn(xatt)
+
+        return tokens, xatt
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        """Detokenize ``tokens`` with the model tokenizer."""
+        return self.model.tokenizer.tokens_to_text(tokens)
diff --git a/uts/speech_processors/test_streamatt.py b/uts/speech_processors/test_streamatt.py
index 180c408..296b30f 100644
--- a/uts/speech_processors/test_streamatt.py
+++ b/uts/speech_processors/test_streamatt.py
@@ -14,8 +14,15 @@
 
 import unittest
 from types import SimpleNamespace
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import torch
 
-from simulstream.server.speech_processors.base_streamatt import PunctuationTextHistory
+from simulstream.server.speech_processors.base_streamatt import (
+    BaseStreamAtt,
+    PunctuationTextHistory,
+)
 
 
 class TestPunctuationTextHistory(unittest.TestCase):
@@ -60,5 +67,66 @@ def test_no_strong_punctuation(self):
         self.assertEqual(selected_history, ['回', '到', '纽', '约', '后', ',', '我'])
 
 
+class FakeStreamAtt(BaseStreamAtt):
+    """Minimal concrete ``BaseStreamAtt`` used to test the history update without a model."""
+
+    def _preprocess(self, waveform: np.ndarray) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+        raise NotImplementedError("_preprocess not implemented in FakeStreamAtt")
+
+    @classmethod
+    def load_model(cls, config: SimpleNamespace):
+        raise NotImplementedError("load_model not implemented in FakeStreamAtt")
+
+    def set_source_language(self, language: str) -> None:
+        pass
+
+    def set_target_language(self, language: str) -> None:
+        pass
+
+    def tokens_to_string(self, tokens: List[str]) -> str:
+        return " ".join(tokens)
+
+    def _generate(self, speech: torch.Tensor) -> Tuple[List[str], torch.Tensor]:
+        raise NotImplementedError("_generate not implemented in FakeStreamAtt")
+
+    @property
+    def audio_max_len(self) -> int:
+        return 10000
+
+
+class TestUpdateSpeechHistory(unittest.TestCase):
+    def _run_update_speech_history(self, use_raw_audio_history: bool) -> List[float]:
+        """Run the history update on 40 samples and return the retained audio history."""
+        config = SimpleNamespace(
+            use_raw_audio_history=use_raw_audio_history,
+            audio_subsampling_factor=2,
+            mel_hop_samples=2,
+            text_history=SimpleNamespace(
+                type="simulstream.server.speech_processors.base_streamatt.FixedWordsTextHistory",
+            ),
+        )
+        audio = np.arange(40, dtype=np.float32)
+        proc = FakeStreamAtt(config)
+        proc.text_history = ["▁hello"]
+        proc.audio_history = audio.copy()
+
+        attn = torch.zeros(2, 10)
+        attn[1, 2] = 1.0
+
+        proc._update_speech_history(discarded_text=1, cross_attn=attn)
+        return proc.audio_history.tolist()
+
+    def test_update_speech_history_trims_audio_with_raw_audio(self):
+        audio_hist = self._run_update_speech_history(use_raw_audio_history=True)
+        # 2 audio tokens are discarded: with a subsampling factor of 2 and
+        # 2 samples per mel frame, 2*2*2=8 samples are removed
+        self.assertListEqual(audio_hist, list(np.arange(8, 40, dtype=np.float32)))
+
+    def test_update_speech_history_trims_audio(self):
+        audio_hist = self._run_update_speech_history(use_raw_audio_history=False)
+        # 2 audio tokens are discarded: with a subsampling factor of 2, 2*2=4 frames are removed
+        self.assertListEqual(audio_hist, list(np.arange(4, 40, dtype=np.float32)))
+
+
 if __name__ == "__main__":
     unittest.main()