diff --git a/anchor/utils/whisper.py b/anchor/utils/whisper.py index 312a48b..ed07752 100644 --- a/anchor/utils/whisper.py +++ b/anchor/utils/whisper.py @@ -4,6 +4,7 @@ import os import torch import whisperx +from collections import Counter from rich.console import Console from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn from .ui import make_ui_console, CaptureProgress @@ -49,6 +50,53 @@ def load_whisper_model(device, compute_type, language, model_size="large-v3"): return model +def detect_robust_language(model, audio, batch_size): + """ + Samples three 30-second snippets from the audio (15%, 50%, 85%) + and runs a quick transcription to determine the language. + Returns the most common detected language, or None if it fails. + """ + sample_rate = 16000 + duration_sec = len(audio) / sample_rate + + # If the video is shorter than 2 minutes, let Whisper handle it natively + if duration_sec < 120: + return None + + timestamps = [0.15, 0.50, 0.85] + guesses = [] + + for pct in timestamps: + start_sample = int(pct * duration_sec * sample_rate) + end_sample = start_sample + (30 * sample_rate) + + # Ensure we don't go out of bounds + if end_sample > len(audio): + end_sample = len(audio) + + audio_snippet = audio[start_sample:end_sample] + + try: + # Run transcription silently to extract language guess + result = model.transcribe( + audio_snippet, + batch_size=batch_size, + language=None, + print_progress=False, # Mute progress bar + combined_progress=False + ) + detected = result.get("language") + if detected and detected != "unknown": + guesses.append(detected) + except Exception: + continue + + if not guesses: + return None + + # Return the most frequent guess (majority voting) + return Counter(guesses).most_common(1)[0][0] + def run_whisper_transcription(video_path, device, compute_type, batch_size, model, language=None): """Transcribes audio and aligns phonemes. Returns (whisper_data, detected_lang) or (None, None) on failure.""" safe_console = Console(force_terminal=True) @@ -60,8 +108,16 @@ def run_whisper_transcription(video_path, device, compute_type, batch_size, mode result = None current_batch_size = batch_size + is_windows = (os.name != 'posix') + if language is None: + console.print("[dim]🔍 Auto-detecting language via three separate samples...[/dim]") + robust_lang = detect_robust_language(model, audio, current_batch_size) + if robust_lang: + language = robust_lang + console.print(f"[dim]🌐 Robust auto-detect determined: [bold cyan]{language.upper()}[/bold cyan][/dim]") + while current_batch_size >= 1: try: sys.stdout.flush()