diff --git a/cosyvoice/cli/cosyvoice.py b/cosyvoice/cli/cosyvoice.py index 9b1299116..1c1529601 100644 --- a/cosyvoice/cli/cosyvoice.py +++ b/cosyvoice/cli/cosyvoice.py @@ -14,7 +14,7 @@ import os import time import threading -from typing import Generator, List +from typing import Generator, List, Optional import numpy as np from tqdm import tqdm from hyperpyyaml import load_hyperpyyaml @@ -299,6 +299,9 @@ def _load_llama_cpp(self, gguf_model_path): self.task_id_token_id = self.speech_token_offset + self.task_id_speech_idx self._llama_cpp_loaded = True + # llama.cpp context is NOT thread-safe and is shared across requests: + # serialize all reset()/eval()/sample() sequences with this lock. + self._llama_lock = threading.Lock() def _sample_speech_token_constrained(self, logit_pos): """Sample next token constrained to speech tokens + EOS only. @@ -346,50 +349,69 @@ def _run_llama_cpp_inference( text_token_ids: List[int], prompt_text_token_ids: List[int], prompt_speech_tokens: List[int], + on_token=None, + cancel_event: Optional[threading.Event] = None, ) -> List[int]: """ Run llama.cpp inference to generate speech tokens. Uses pre-tokenized IDs from the CosyVoice frontend (same as PyTorch path). Format: [SOS] + prompt_text_ids + text_ids + [TASK_ID] + offset(prompt_speech_tokens) + + on_token: optional callback invoked with each decoded speech token as soon + as it is generated (enables true streaming overlap with flow). + cancel_event: optional threading.Event; when set, generation stops early + (e.g. client disconnected mid-stream). + + The whole generation is serialized with self._llama_lock because the + llama.cpp context is shared and reset() from a concurrent request would + corrupt the KV cache of an in-flight generation. 
""" all_text_ids = prompt_text_token_ids + text_token_ids prompt_speech_ids = [self.speech_token_offset + t for t in prompt_speech_tokens] input_ids = [self.sos_token_id] + all_text_ids + [self.task_id_token_id] + prompt_speech_ids - self.llm_gguf.reset() - self.llm_gguf.eval(input_ids) - - # Track position for constrained sampling fallback - n_past = len(input_ids) - speech_tokens = [] raw_generated = [] max_new_tokens = 2048 - for i in range(max_new_tokens): - # Use built-in sample() (position-aware, like FastCosyVoice) - next_token_id = self.llm_gguf.sample() + with self._llama_lock: + self.llm_gguf.reset() + self.llm_gguf.eval(input_ids) - # If built-in sample returns text token, retry with constrained sampling - if (next_token_id != self.eos_token_id and - not (self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size)): - if i == 0: - logging.info('Built-in sample() returned text token {} on step 0, switching to constrained'.format(next_token_id)) - next_token_id = self._sample_speech_token_constrained(logit_pos=n_past - 1) + # Track position for constrained sampling fallback + n_past = len(input_ids) - raw_generated.append(next_token_id) + for i in range(max_new_tokens): + if cancel_event is not None and cancel_event.is_set(): + logging.info('llama.cpp inference cancelled after {} tokens'.format(len(speech_tokens))) + break - if next_token_id == self.eos_token_id: - break + # Use built-in sample() (position-aware, like FastCosyVoice) + next_token_id = self.llm_gguf.sample() - if self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size: - speech_tokens.append(next_token_id - self.speech_token_offset) - else: - break + # If built-in sample returns text token, retry with constrained sampling + if (next_token_id != self.eos_token_id and + not (self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size)): + if i == 0: + logging.info('Built-in 
sample() returned text token {} on step 0, switching to constrained'.format(next_token_id)) + next_token_id = self._sample_speech_token_constrained(logit_pos=n_past - 1) - self.llm_gguf.eval([next_token_id]) - n_past += 1 + raw_generated.append(next_token_id) + + if next_token_id == self.eos_token_id: + break + + if self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size: + speech_token = next_token_id - self.speech_token_offset + speech_tokens.append(speech_token) + if on_token is not None: + on_token(speech_token) + else: + break + + self.llm_gguf.eval([next_token_id]) + n_past += 1 return speech_tokens @@ -401,20 +423,42 @@ def _llama_cpp_job( tokens_list: list, llm_end_flag: dict, tokens_lock: threading.Lock, + tokens_cond: Optional[threading.Condition] = None, + cancel_event: Optional[threading.Event] = None, ): - """Thread target: generate all speech tokens via llama.cpp and fill shared tokens_list.""" + """Thread target: generate speech tokens via llama.cpp. + + Tokens are appended to the shared tokens_list ONE BY ONE as they are + decoded (previously the list was filled with a single extend() after the + whole generation finished, which meant flow matching could not start + until the LLM was completely done — i.e. no real streaming overlap). 
+ """ + def _on_token(tok): + if tokens_cond is not None: + with tokens_cond: + tokens_list.append(tok) + tokens_cond.notify_all() + else: + with tokens_lock: + tokens_list.append(tok) + try: - speech_tokens = self._run_llama_cpp_inference( + self._run_llama_cpp_inference( text_token_ids=text_token_ids, prompt_text_token_ids=prompt_text_token_ids, prompt_speech_tokens=prompt_speech_tokens, + on_token=_on_token, + cancel_event=cancel_event, ) - with tokens_lock: - tokens_list.extend(speech_tokens) except Exception as e: logging.error('llama.cpp inference error: {}'.format(e), exc_info=True) finally: - llm_end_flag['done'] = True + if tokens_cond is not None: + with tokens_cond: + llm_end_flag['done'] = True + tokens_cond.notify_all() + else: + llm_end_flag['done'] = True # ------------------------------------------------------------------------- # Overridden inference methods with llama.cpp support @@ -451,28 +495,39 @@ def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_i if stream: tokens_list = [] tokens_lock = threading.Lock() + tokens_cond = threading.Condition(tokens_lock) llm_end_flag = {'done': False} + cancel_event = threading.Event() llm_thread = threading.Thread( target=self._llama_cpp_job, args=(text_ids, prompt_text_ids, prompt_speech_ids, tokens_list, llm_end_flag, tokens_lock), + kwargs={'tokens_cond': tokens_cond, 'cancel_event': cancel_event}, daemon=True ) llm_thread.start() - for model_output in self.model.tts_stream_external_llm( - tokens_list=tokens_list, - tokens_lock=tokens_lock, - llm_end_flag=llm_end_flag, - **{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')} - ): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() - - llm_thread.join(timeout=5.0) + try: + for model_output in 
self.model.tts_stream_external_llm( + tokens_list=tokens_list, + tokens_lock=tokens_lock, + llm_end_flag=llm_end_flag, + tokens_cond=tokens_cond, + **{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')} + ): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + finally: + # Make sure the LLM thread is stopped even if the client + # aborted the stream (GeneratorExit): otherwise the next request + # would call llm_gguf.reset() in parallel with a live eval(). + cancel_event.set() + with tokens_cond: + tokens_cond.notify_all() + llm_thread.join(timeout=30.0) else: speech_tokens = self._run_llama_cpp_inference( text_token_ids=text_ids, @@ -506,28 +561,39 @@ def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', str if stream: tokens_list = [] tokens_lock = threading.Lock() + tokens_cond = threading.Condition(tokens_lock) llm_end_flag = {'done': False} + cancel_event = threading.Event() llm_thread = threading.Thread( target=self._llama_cpp_job, args=(text_ids, prompt_text_ids, prompt_speech_ids, tokens_list, llm_end_flag, tokens_lock), + kwargs={'tokens_cond': tokens_cond, 'cancel_event': cancel_event}, daemon=True ) llm_thread.start() - for model_output in self.model.tts_stream_external_llm( - tokens_list=tokens_list, - tokens_lock=tokens_lock, - llm_end_flag=llm_end_flag, - **{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')} - ): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() - - llm_thread.join(timeout=5.0) + try: + for model_output in self.model.tts_stream_external_llm( + tokens_list=tokens_list, + tokens_lock=tokens_lock, + 
llm_end_flag=llm_end_flag, + tokens_cond=tokens_cond, + **{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')} + ): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + finally: + # Make sure the LLM thread is stopped even if the client + # aborted the stream (GeneratorExit): otherwise the next request + # would call llm_gguf.reset() in parallel with a live eval(). + cancel_event.set() + with tokens_cond: + tokens_cond.notify_all() + llm_thread.join(timeout=30.0) else: speech_tokens = self._run_llama_cpp_inference( text_token_ids=text_ids, @@ -561,28 +627,39 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk if stream: tokens_list = [] tokens_lock = threading.Lock() + tokens_cond = threading.Condition(tokens_lock) llm_end_flag = {'done': False} + cancel_event = threading.Event() llm_thread = threading.Thread( target=self._llama_cpp_job, args=(text_ids, prompt_text_ids, prompt_speech_ids, tokens_list, llm_end_flag, tokens_lock), + kwargs={'tokens_cond': tokens_cond, 'cancel_event': cancel_event}, daemon=True ) llm_thread.start() - for model_output in self.model.tts_stream_external_llm( - tokens_list=tokens_list, - tokens_lock=tokens_lock, - llm_end_flag=llm_end_flag, - **{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')} - ): - speech_len = model_output['tts_speech'].shape[1] / self.sample_rate - logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) - yield model_output - start_time = time.time() - - llm_thread.join(timeout=5.0) + try: + for model_output in self.model.tts_stream_external_llm( + tokens_list=tokens_list, + tokens_lock=tokens_lock, + llm_end_flag=llm_end_flag, + tokens_cond=tokens_cond, + **{k: v for k, v in model_input.items() 
if k.startswith('flow') or k.startswith('prompt_speech')} + ): + speech_len = model_output['tts_speech'].shape[1] / self.sample_rate + logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) + yield model_output + start_time = time.time() + finally: + # Make sure the LLM thread is stopped even if the client + # aborted the stream (GeneratorExit): otherwise the next request + # would call llm_gguf.reset() in parallel with a live eval(). + cancel_event.set() + with tokens_cond: + tokens_cond.notify_all() + llm_thread.join(timeout=30.0) else: speech_tokens = self._run_llama_cpp_inference( text_token_ids=text_ids, @@ -608,4 +685,4 @@ def AutoModel(**kwargs): elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])): return CosyVoice3(**kwargs) else: - raise TypeError('No valid model type found!') + raise TypeError('No valid model type found!') \ No newline at end of file diff --git a/cosyvoice/cli/model.py b/cosyvoice/cli/model.py index 94f9ae61c..9c44d29c0 100644 --- a/cosyvoice/cli/model.py +++ b/cosyvoice/cli/model.py @@ -468,19 +468,22 @@ def tts_with_external_tokens( self.hift_cache_dict[this_uuid] = None tokens_gpu = torch.tensor(tokens, dtype=torch.int32, device=self.device).unsqueeze(0) - tts_speech = self.token2wav( - token=tokens_gpu, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - token_offset=0, - uuid=this_uuid, - finalize=True, - speed=speed, - ) - - with self.lock: - self.hift_cache_dict.pop(this_uuid) + try: + tts_speech = self.token2wav( + token=tokens_gpu, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + token_offset=0, + uuid=this_uuid, + finalize=True, + speed=speed, + ) + finally: + # pop in finally: otherwise an exception in token2wav leaves the + # entry in hift_cache_dict forever (memory leak across requests) + with self.lock: + self.hift_cache_dict.pop(this_uuid, None) return {'tts_speech': 
tts_speech.cpu()} def tts_stream_external_llm( @@ -488,12 +491,19 @@ def tts_stream_external_llm( tokens_list, tokens_lock, llm_end_flag, + tokens_cond=None, flow_embedding=torch.zeros(0, 192), flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), prompt_speech_feat=torch.zeros(1, 0, 80), **kwargs ): - """Streaming TTS with external LLM providing tokens via shared list + lock.""" + """Streaming TTS with external LLM providing tokens via shared list + lock. + + tokens_cond: optional threading.Condition built on tokens_lock. If given, + the consumer wakes up immediately when the producer appends a token + instead of polling with a fixed 0.1s sleep (which added up to 100ms of + latency per chunk). + """ this_uuid = str(uuid.uuid1()) with self.lock: self.hift_cache_dict[this_uuid] = None @@ -505,37 +515,53 @@ def tts_stream_external_llm( - flow_prompt_speech_token.shape[1] ) + # The Condition and the Lock share the same primitive, so the context + # manager of either one acquires tokens_lock + sync = tokens_cond if tokens_cond is not None else tokens_lock + try: while True: - time.sleep(0.1) this_token_hop_len = ( self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len ) - with tokens_lock: - available = len(tokens_list) - token_offset - need = this_token_hop_len + self.flow.pre_lookahead_len - - if available >= need: - with tokens_lock: + need = this_token_hop_len + self.flow.pre_lookahead_len + + with sync: + # Wait until need tokens have accumulated or the producer finishes. + # done and available are read under the same lock, so there is no race + # with a stale available value from the previous iteration. 
+ while (len(tokens_list) - token_offset) < need and not llm_end_flag['done']: + if tokens_cond is not None: + tokens_cond.wait(timeout=0.5) + else: + tokens_lock.release() + try: + time.sleep(0.02) + finally: + tokens_lock.acquire() + if (len(tokens_list) - token_offset) >= need: batch = list(tokens_list[:token_offset + need]) - this_tts_speech_token = torch.tensor(batch).unsqueeze(0) - tts_speech = self.token2wav( - token=this_tts_speech_token, - prompt_token=flow_prompt_speech_token, - prompt_feat=prompt_speech_feat, - embedding=flow_embedding, - token_offset=token_offset, - uuid=this_uuid, - stream=True, - finalize=False, - ) - token_offset += this_token_hop_len - yield {'tts_speech': tts_speech.cpu()} + else: + batch = None # done; the remainder goes out as the final chunk - if llm_end_flag['done'] and available < need: + if batch is None: break + this_tts_speech_token = torch.tensor(batch).unsqueeze(0) + tts_speech = self.token2wav( + token=this_tts_speech_token, + prompt_token=flow_prompt_speech_token, + prompt_feat=prompt_speech_feat, + embedding=flow_embedding, + token_offset=token_offset, + uuid=this_uuid, + stream=True, + finalize=False, + ) + token_offset += this_token_hop_len + yield {'tts_speech': tts_speech.cpu()} + # Final batch with tokens_lock: final_tokens = list(tokens_list) @@ -555,4 +581,4 @@ def tts_stream_external_llm( with self.lock: self.hift_cache_dict.pop(this_uuid, None) if torch.cuda.is_available(): - torch.cuda.empty_cache() + torch.cuda.empty_cache() \ No newline at end of file