Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 144 additions & 67 deletions cosyvoice/cli/cosyvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import time
import threading
from typing import Generator, List
from typing import Generator, List, Optional
import numpy as np
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
Expand Down Expand Up @@ -299,6 +299,9 @@ def _load_llama_cpp(self, gguf_model_path):
self.task_id_token_id = self.speech_token_offset + self.task_id_speech_idx

self._llama_cpp_loaded = True
# llama.cpp context is NOT thread-safe and is shared across requests:
# serialize all reset()/eval()/sample() sequences with this lock.
self._llama_lock = threading.Lock()

def _sample_speech_token_constrained(self, logit_pos):
"""Sample next token constrained to speech tokens + EOS only.
Expand Down Expand Up @@ -346,50 +349,69 @@ def _run_llama_cpp_inference(
text_token_ids: List[int],
prompt_text_token_ids: List[int],
prompt_speech_tokens: List[int],
on_token=None,
cancel_event: Optional[threading.Event] = None,
) -> List[int]:
"""
Run llama.cpp inference to generate speech tokens.

Uses pre-tokenized IDs from the CosyVoice frontend (same as PyTorch path).
Format: [SOS] + prompt_text_ids + text_ids + [TASK_ID] + offset(prompt_speech_tokens)

on_token: optional callback invoked with each decoded speech token as soon
as it is generated (enables true streaming overlap with flow).
cancel_event: optional threading.Event; when set, generation stops early
(e.g. client disconnected mid-stream).

The whole generation is serialized with self._llama_lock because the
llama.cpp context is shared and reset() from a concurrent request would
corrupt the KV cache of an in-flight generation.
"""
all_text_ids = prompt_text_token_ids + text_token_ids
prompt_speech_ids = [self.speech_token_offset + t for t in prompt_speech_tokens]
input_ids = [self.sos_token_id] + all_text_ids + [self.task_id_token_id] + prompt_speech_ids

self.llm_gguf.reset()
self.llm_gguf.eval(input_ids)

# Track position for constrained sampling fallback
n_past = len(input_ids)

speech_tokens = []
raw_generated = []
max_new_tokens = 2048

for i in range(max_new_tokens):
# Use built-in sample() (position-aware, like FastCosyVoice)
next_token_id = self.llm_gguf.sample()
with self._llama_lock:
self.llm_gguf.reset()
self.llm_gguf.eval(input_ids)

# If built-in sample returns text token, retry with constrained sampling
if (next_token_id != self.eos_token_id and
not (self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size)):
if i == 0:
logging.info('Built-in sample() returned text token {} on step 0, switching to constrained'.format(next_token_id))
next_token_id = self._sample_speech_token_constrained(logit_pos=n_past - 1)
# Track position for constrained sampling fallback
n_past = len(input_ids)

raw_generated.append(next_token_id)
for i in range(max_new_tokens):
if cancel_event is not None and cancel_event.is_set():
logging.info('llama.cpp inference cancelled after {} tokens'.format(len(speech_tokens)))
break

if next_token_id == self.eos_token_id:
break
# Use built-in sample() (position-aware, like FastCosyVoice)
next_token_id = self.llm_gguf.sample()

if self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size:
speech_tokens.append(next_token_id - self.speech_token_offset)
else:
break
# If built-in sample returns text token, retry with constrained sampling
if (next_token_id != self.eos_token_id and
not (self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size)):
if i == 0:
logging.info('Built-in sample() returned text token {} on step 0, switching to constrained'.format(next_token_id))
next_token_id = self._sample_speech_token_constrained(logit_pos=n_past - 1)

self.llm_gguf.eval([next_token_id])
n_past += 1
raw_generated.append(next_token_id)

if next_token_id == self.eos_token_id:
break

if self.speech_token_offset <= next_token_id < self.speech_token_offset + self.base_speech_token_size:
speech_token = next_token_id - self.speech_token_offset
speech_tokens.append(speech_token)
if on_token is not None:
on_token(speech_token)
else:
break

self.llm_gguf.eval([next_token_id])
n_past += 1

return speech_tokens

Expand All @@ -401,20 +423,42 @@ def _llama_cpp_job(
tokens_list: list,
llm_end_flag: dict,
tokens_lock: threading.Lock,
tokens_cond: Optional[threading.Condition] = None,
cancel_event: Optional[threading.Event] = None,
):
"""Thread target: generate all speech tokens via llama.cpp and fill shared tokens_list."""
"""Thread target: generate speech tokens via llama.cpp.

Tokens are appended to the shared tokens_list ONE BY ONE as they are
decoded (previously the list was filled with a single extend() after the
whole generation finished, which meant flow matching could not start
until the LLM was completely done — i.e. no real streaming overlap).
"""
def _on_token(tok):
if tokens_cond is not None:
with tokens_cond:
tokens_list.append(tok)
tokens_cond.notify_all()
else:
with tokens_lock:
tokens_list.append(tok)

try:
speech_tokens = self._run_llama_cpp_inference(
self._run_llama_cpp_inference(
text_token_ids=text_token_ids,
prompt_text_token_ids=prompt_text_token_ids,
prompt_speech_tokens=prompt_speech_tokens,
on_token=_on_token,
cancel_event=cancel_event,
)
with tokens_lock:
tokens_list.extend(speech_tokens)
except Exception as e:
logging.error('llama.cpp inference error: {}'.format(e), exc_info=True)
finally:
llm_end_flag['done'] = True
if tokens_cond is not None:
with tokens_cond:
llm_end_flag['done'] = True
tokens_cond.notify_all()
else:
llm_end_flag['done'] = True

# -------------------------------------------------------------------------
# Overridden inference methods with llama.cpp support
Expand Down Expand Up @@ -451,28 +495,39 @@ def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_i
if stream:
tokens_list = []
tokens_lock = threading.Lock()
tokens_cond = threading.Condition(tokens_lock)
llm_end_flag = {'done': False}
cancel_event = threading.Event()

llm_thread = threading.Thread(
target=self._llama_cpp_job,
args=(text_ids, prompt_text_ids, prompt_speech_ids,
tokens_list, llm_end_flag, tokens_lock),
kwargs={'tokens_cond': tokens_cond, 'cancel_event': cancel_event},
daemon=True
)
llm_thread.start()

for model_output in self.model.tts_stream_external_llm(
tokens_list=tokens_list,
tokens_lock=tokens_lock,
llm_end_flag=llm_end_flag,
**{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')}
):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

llm_thread.join(timeout=5.0)
try:
for model_output in self.model.tts_stream_external_llm(
tokens_list=tokens_list,
tokens_lock=tokens_lock,
llm_end_flag=llm_end_flag,
tokens_cond=tokens_cond,
**{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')}
):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
finally:
# Гарантированно останавливаем LLM-поток, даже если клиент
# оборвал стрим (GeneratorExit): иначе следующий запрос
# сделает llm_gguf.reset() параллельно с живым eval().
cancel_event.set()
with tokens_cond:
tokens_cond.notify_all()
llm_thread.join(timeout=30.0)
else:
speech_tokens = self._run_llama_cpp_inference(
text_token_ids=text_ids,
Expand Down Expand Up @@ -506,28 +561,39 @@ def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', str
if stream:
tokens_list = []
tokens_lock = threading.Lock()
tokens_cond = threading.Condition(tokens_lock)
llm_end_flag = {'done': False}
cancel_event = threading.Event()

llm_thread = threading.Thread(
target=self._llama_cpp_job,
args=(text_ids, prompt_text_ids, prompt_speech_ids,
tokens_list, llm_end_flag, tokens_lock),
kwargs={'tokens_cond': tokens_cond, 'cancel_event': cancel_event},
daemon=True
)
llm_thread.start()

for model_output in self.model.tts_stream_external_llm(
tokens_list=tokens_list,
tokens_lock=tokens_lock,
llm_end_flag=llm_end_flag,
**{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')}
):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

llm_thread.join(timeout=5.0)
try:
for model_output in self.model.tts_stream_external_llm(
tokens_list=tokens_list,
tokens_lock=tokens_lock,
llm_end_flag=llm_end_flag,
tokens_cond=tokens_cond,
**{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')}
):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
finally:
# Гарантированно останавливаем LLM-поток, даже если клиент
# оборвал стрим (GeneratorExit): иначе следующий запрос
# сделает llm_gguf.reset() параллельно с живым eval().
cancel_event.set()
with tokens_cond:
tokens_cond.notify_all()
llm_thread.join(timeout=30.0)
else:
speech_tokens = self._run_llama_cpp_inference(
text_token_ids=text_ids,
Expand Down Expand Up @@ -561,28 +627,39 @@ def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk
if stream:
tokens_list = []
tokens_lock = threading.Lock()
tokens_cond = threading.Condition(tokens_lock)
llm_end_flag = {'done': False}
cancel_event = threading.Event()

llm_thread = threading.Thread(
target=self._llama_cpp_job,
args=(text_ids, prompt_text_ids, prompt_speech_ids,
tokens_list, llm_end_flag, tokens_lock),
kwargs={'tokens_cond': tokens_cond, 'cancel_event': cancel_event},
daemon=True
)
llm_thread.start()

for model_output in self.model.tts_stream_external_llm(
tokens_list=tokens_list,
tokens_lock=tokens_lock,
llm_end_flag=llm_end_flag,
**{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')}
):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

llm_thread.join(timeout=5.0)
try:
for model_output in self.model.tts_stream_external_llm(
tokens_list=tokens_list,
tokens_lock=tokens_lock,
llm_end_flag=llm_end_flag,
tokens_cond=tokens_cond,
**{k: v for k, v in model_input.items() if k.startswith('flow') or k.startswith('prompt_speech')}
):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
finally:
# Гарантированно останавливаем LLM-поток, даже если клиент
# оборвал стрим (GeneratorExit): иначе следующий запрос
# сделает llm_gguf.reset() параллельно с живым eval().
cancel_event.set()
with tokens_cond:
tokens_cond.notify_all()
llm_thread.join(timeout=30.0)
else:
speech_tokens = self._run_llama_cpp_inference(
text_token_ids=text_ids,
Expand All @@ -608,4 +685,4 @@ def AutoModel(**kwargs):
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
return CosyVoice3(**kwargs)
else:
raise TypeError('No valid model type found!')
raise TypeError('No valid model type found!')
Loading