diff --git a/.gitignore b/.gitignore index 65fd507..b482fb1 100644 --- a/.gitignore +++ b/.gitignore @@ -87,6 +87,7 @@ target/ # pyenv .python-version venv +.venv # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e990e6..a0200c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [4.0.1] - 2025-04-30 + +### Added + +- Support RT Multichannel and channel DZ ## [3.0.4] - 2025-04-16 - Support for new parameters `prefer_current_speaker` and `speaker_sensitivity` in Speaker Diarization + ## [3.0.3] - 2025-03-03 ### Added diff --git a/Makefile b/Makefile index 525a59c..818611f 100644 --- a/Makefile +++ b/Makefile @@ -9,6 +9,11 @@ lint: black --check --diff $(SOURCES) ruff $(SOURCES) +.PHONY: lint-fix +lint-fix: + black $(SOURCES) + ruff --fix $(SOURCES) + .PHONY: format format: black $(SOURCES) diff --git a/speechmatics/adapters.py b/speechmatics/adapters.py index 6788345..a91635c 100644 --- a/speechmatics/adapters.py +++ b/speechmatics/adapters.py @@ -2,7 +2,7 @@ """ Functions for converting our JSON transcription results to other formats. """ -from typing import Any, List +from typing import Any, List, Optional def get_txt_translation(translations: List[dict]): @@ -32,8 +32,9 @@ def get_txt_translation(translations: List[dict]): def convert_to_txt( tokens: List[dict], language: str, - language_pack_info: dict = None, + language_pack_info: Optional[dict] = None, speaker_labels: bool = True, + channel: Optional[str] = None, ) -> str: """ Convert a set of transcription result tokens to a plain text format. @@ -41,6 +42,7 @@ def convert_to_txt( :param tokens: the transcription results. :param language_pack_info: information about the language pack. :param speaker_labels: whether or not to output speaker labels in the text. + :param channel: the channel name (if multichannel). :return: the plain text as a string. """ # Although we should get word_delimiter from language_pack_info, we still want sensible @@ -64,6 +66,8 @@ def convert_to_txt( texts.append(f"SPEAKER: {current_speaker}\n") texts.append(join_tokens(group, word_delimiter=word_delimiter)) texts.append("\n") + if texts and channel: + texts.insert(0, f"{channel}: ") return "".join(texts).rstrip() diff --git a/speechmatics/cli.py b/speechmatics/cli.py index 42aa7db..d367be3 100755 --- a/speechmatics/cli.py +++ b/speechmatics/cli.py @@ -111,6 +111,25 @@ def parse_word_replacements(replacement_words_filepath) -> List[Dict]: return replacement_words +def parse_multichannel_args(multichannel_args: str) -> list[str]: + """ + Parses multichannel arguments from the command line + :param multichannel_args: Multichannel arguments + :type multichannel_args: str + :return: A list of channels to be used. + :rtype: List[str] + :raises SystemExit: If the arguments are not formatted properly. + """ + channels = [] + try: + channels = multichannel_args.split(",") + except ValueError: + raise SystemExit( + f"Invalid format for multichannel arguments: '{multichannel_args}'. Expected format ," + ) + return channels + + def parse_additional_vocab(additional_vocab_filepath): """ Parses an additional vocab list from a file. @@ -329,6 +348,11 @@ def get_transcription_config( "Using additional vocab from file %s", args["additional_vocab_file"] ) + if args.get("multichannel"): + multichannel_args = parse_multichannel_args(args["multichannel"]) + config["channel_diarization_labels"] = multichannel_args + LOGGER.info(f"Using multchannel mode with channels {multichannel_args}") + if args.get("additional_vocab"): if not config.get("additional_vocab"): config["additional_vocab"] = args["additional_vocab"] @@ -518,6 +542,9 @@ def add_printing_handlers( api.add_event_handler( ServerMessageType.AudioAdded, lambda *args: print_symbol("-") ) + api.add_event_handler( + ServerMessageType.ChannelAudioAdded, lambda *args: print_symbol("=") + ) api.add_event_handler( ServerMessageType.AddPartialTranscript, lambda *args: print_symbol(".") ) @@ -525,17 +552,22 @@ def add_printing_handlers( ServerMessageType.AddTranscript, lambda *args: print_symbol("|") ) api.add_middleware(ClientMessageType.AddAudio, lambda *args: print_symbol("+")) + api.add_middleware( + ClientMessageType.AddChannelAudio, lambda *args: print_symbol("x") + ) def partial_transcript_handler(message): # "\n" does not appear in partial transcripts if print_json: print(json.dumps(message)) return + plaintext = speechmatics.adapters.convert_to_txt( message["results"], api.transcription_config.language, language_pack_info=api.get_language_pack_info(), speaker_labels=True, + channel=get_channel(message), ) if plaintext: sys.stderr.write(f"{escape_seq}{plaintext}\r") @@ -545,16 +577,24 @@ def transcript_handler(message): if print_json: print(json.dumps(message)) return + plaintext = speechmatics.adapters.convert_to_txt( message["results"], api.transcription_config.language, language_pack_info=api.get_language_pack_info(), speaker_labels=True, + channel=get_channel(message), ) if plaintext: sys.stdout.write(f"{escape_seq}{plaintext}\n") transcripts.text += plaintext + def get_channel(message): + return next( + (result["channel"] for result in message["results"] if "channel" in result), + None, + ) + def audio_event_handler(message): if print_json: print(json.dumps(message)) @@ -759,12 +799,23 @@ def rt_main(args): translation_config=transcription_config.translation_config, ) - def run(stream): + def run(stream=None, channel_stream_pairs=None): try: + # Pass in either stream or channel_stream_pairs depending on what != None + # Dynamically construct the args based on the input + args_list = [transcription_config] + if stream is not None: + args_list.append(stream) + elif channel_stream_pairs is not None: + args_list.append(None) # This skips the stream argument + args_list.append(channel_stream_pairs) + else: + raise SystemExit( + "Neither stream nor channel_stream_pairs were provided." + ) api.run_synchronously( - stream, - transcription_config, - get_audio_settings(args), + *args_list, + audio_settings=get_audio_settings(args), from_cli=True, extra_headers=extra_headers, ) @@ -773,11 +824,39 @@ def run(stream): LOGGER.warning("Keyboard interrupt received.") if args["files"][0] == "-": - run(sys.stdin.buffer) + if transcription_config.channel_diarization_labels: + raise SystemExit( + "Channel diarization is not yet supported when reading from stdin." + ) + run(stream=sys.stdin.buffer) else: - for filename in args["files"]: - with open(filename, "rb") as audio_file: - run(audio_file) + # Check we have the right diarization type + if transcription_config.channel_diarization_labels: + if ( + transcription_config.diarization != "channel" + and transcription_config.diarization != "channel_and_speaker" + ): + raise SystemExit( + "Multichannel DZ type must be 'channel' or 'channel_and_speaker'." + ) + + num_channels = len(transcription_config.channel_diarization_labels) + if len(args["files"]) != num_channels: + raise SystemExit( + f"Number of files: ({len(args['files'])}) must match number of channels: ({num_channels})." + ) + + channel_stream_pairs = {} + for i in range(num_channels): + # Here the order matters, as stream positions and diarization labels correspond to one another. + channel_name = transcription_config.channel_diarization_labels[i] + channel_stream_pairs[channel_name] = args["files"][i] + run(channel_stream_pairs=channel_stream_pairs) + + else: + for filename in args["files"]: + with open(filename, "rb") as audio_file: + run(stream=audio_file) def batch_main(args): diff --git a/speechmatics/cli_parser.py b/speechmatics/cli_parser.py index 9398899..da29eac 100644 --- a/speechmatics/cli_parser.py +++ b/speechmatics/cli_parser.py @@ -315,6 +315,15 @@ def get_arg_parser(): default=None, help=("Comma-separated list of expected languages for language identification"), ) + config_parser.add_argument( + "--multichannel", + metavar="CHANNELS", + help=( + "Enables multichannel mode and specifies channels. " + "Pass channels as a comma-separated string, e.g.: ,. " + "The number of channels specified must match the number of input files." + ), + ) # Parent parser for batch summarize argument batch_summarization_parser = argparse.ArgumentParser(add_help=False) @@ -547,7 +556,11 @@ def get_arg_parser(): rt_transcribe_command_parser.add_argument( "--diarization", - choices=["none", "speaker"], + choices=[ + "none", + "speaker", + "channel", + ], help="Which type of diarization to use.", ) diff --git a/speechmatics/client.py b/speechmatics/client.py index 842883b..4f25135 100644 --- a/speechmatics/client.py +++ b/speechmatics/client.py @@ -5,6 +5,9 @@ """ import asyncio +import base64 +from collections import defaultdict +from contextlib import AsyncExitStack import copy import json import logging @@ -20,7 +23,12 @@ ForceEndSession, TranscriptionError, ) -from speechmatics.helpers import get_version, json_utf8, read_in_chunks +from speechmatics.helpers import ( + check_tasks_exceptions, + get_version, + json_utf8, + read_in_chunks, +) from speechmatics.models import ( AudioSettings, ClientMessageType, @@ -76,18 +84,19 @@ def __init__( self.event_handlers = {x: [] for x in ServerMessageType} self.middlewares = {x: [] for x in ClientMessageType} - self.seq_no = 0 + self.seq_no: defaultdict = defaultdict(int) self.session_running = False self._language_pack_info = None self._transcription_config_needs_update = False self._session_needs_closing = False + self.channel_stream_pairs = None # The following asyncio fields are fully instantiated in # _init_synchronization_primitives - self._recognition_started = asyncio.Event + self._recognition_started: asyncio.Event # Semaphore used to ensure that we don't send too much audio data to # the server too quickly and burst any buffers downstream. - self._buffer_semaphore = asyncio.BoundedSemaphore + self._buffer_semaphore: asyncio.BoundedSemaphore async def _init_synchronization_primitives(self): """ @@ -167,6 +176,7 @@ def _start_recognition(self, audio_settings): "audio_format": audio_settings.asdict(), "transcription_config": self.transcription_config.as_config(), } + if self.transcription_config.translation_config is not None: msg[ "translation_config" @@ -187,11 +197,34 @@ def _end_of_stream(self): :py:attr:`speechmatics.models.ClientMessageType.EndOfStream` message. """ - msg = {"message": ClientMessageType.EndOfStream, "last_seq_no": self.seq_no} + assert ( + self.channel_stream_pairs is None + ), "End of stream can only be sent for a single channel" + seq_no = 0 + # if client disconnects before sending any audio, seq_no will be empty + if len(self.seq_no) == 1: + seq_no = next(iter(self.seq_no.values())) + msg = {"message": ClientMessageType.EndOfStream, "last_seq_no": seq_no} self._call_middleware(ClientMessageType.EndOfStream, msg, False) LOGGER.debug(msg) return msg + def _end_of_channel(self, channel: str) -> dict: + """ + Constructs a :py:attr:`speechmatics.models.ClientMessageType.EndOfChannel` message. + + :param channel: The name of the channel for which the end message is being constructed. + :type channel: str + """ + msg = { + "message": ClientMessageType.EndOfChannel, + "channel": channel, + "last_seq_no": self.seq_no[channel], + } + self._call_middleware(ClientMessageType.EndOfChannel, msg, False) + LOGGER.debug(msg) + return msg + def _consumer(self, message): """ Consumes messages and acts on them. @@ -205,7 +238,7 @@ def _consumer(self, message): :raises ForceEndSession: If this was raised by the user's event handler. """ - LOGGER.debug(message) + LOGGER.debug(f"{message=}") message = json.loads(message) message_type = message["message"] @@ -222,6 +255,8 @@ def _consumer(self, message): self._set_language_pack_info(message["language_pack_info"]) elif message_type == ServerMessageType.AudioAdded: self._buffer_semaphore.release() + elif message_type == ServerMessageType.ChannelAudioAdded: + self._buffer_semaphore.release() elif message_type == ServerMessageType.EndOfTranscript: raise EndOfTranscriptException() elif message_type == ServerMessageType.Warning: @@ -239,6 +274,71 @@ async def _producer(self, stream, audio_chunk_size): :param audio_chunk_size: Size of audio chunks to send. :type audio_chunk_size: int """ + if self.channel_stream_pairs is not None: + async for msg in self._process_multichannel_streams(audio_chunk_size): + yield msg + else: + async for msg in self._process_single_stream(stream, audio_chunk_size): + yield msg + + async def _stream_channel(self, channel, stream, queue, audio_chunk_size): + """ + Stream audio data for a specific channel and put messages into the queue. + """ + async for audio_chnk in read_in_chunks(stream, audio_chunk_size): + if self._session_needs_closing: + break + if self._transcription_config_needs_update: + await queue.put(self._set_recognition_config()) + self._transcription_config_needs_update = False + await asyncio.wait_for( + self._buffer_semaphore.acquire(), + timeout=self.connection_settings.semaphore_timeout_seconds, + ) + + base64_chunk = base64.b64encode(audio_chnk).decode("utf-8") + message = { + "message": "AddChannelAudio", + "channel": channel, + "data": base64_chunk, + } + + # seq_no is defaultdict is so the keys are created automatically + self.seq_no[channel] += 1 + self._call_middleware(ClientMessageType.AddChannelAudio, message) + await queue.put(message) + await queue.put(self._end_of_channel(channel)) + + async def _process_multichannel_streams(self, audio_chunk_size): + """ + Process multiple channel streams and yield messages to send to the server. + """ + assert ( + self.channel_stream_pairs is not None + ), "Channel stream pairs must be set for multichannel mode" + queue = asyncio.Queue() + tasks = [ + asyncio.create_task( + self._stream_channel(channel, channel_stream, queue, audio_chunk_size) + ) + for channel, channel_stream in self.channel_stream_pairs.items() + ] + while True: + check_tasks_exceptions(tasks) + streams_done = all(task.done() for task in tasks) + if streams_done and queue.empty(): + break + try: + message = await asyncio.wait_for(queue.get(), timeout=0.5) + yield json.dumps(message) + except asyncio.TimeoutError: + continue + + async def _process_single_stream(self, stream, audio_chunk_size): + """ + Process a single channel stream and yield messages to send to the server. + Yields binary audio chunks + """ async for audio_chunk in read_in_chunks(stream, audio_chunk_size): if self._session_needs_closing: break @@ -251,7 +351,7 @@ async def _producer(self, stream, audio_chunk_size): self._buffer_semaphore.acquire(), timeout=self.connection_settings.semaphore_timeout_seconds, ) - self.seq_no += 1 + self.seq_no["single"] += 1 self._call_middleware(ClientMessageType.AddAudio, audio_chunk, True) yield audio_chunk @@ -421,8 +521,9 @@ async def _communicate(self, stream, audio_settings): async def run( self, - stream, transcription_config: TranscriptionConfig, + stream: Optional[Any] = None, + channel_stream_pairs=None, audio_settings: AudioSettings = None, from_cli: bool = False, extra_headers: Dict = None, @@ -433,12 +534,15 @@ async def run( :py:meth:`run_synchronously` which will block until the session is finished. - :param stream: File-like object which an audio stream can be read from. - :type stream: io.IOBase - :param transcription_config: Configuration for the transcription. :type transcription_config: speechmatics.models.TranscriptionConfig + :param stream: Optional file-like object which an audio stream can be read from. + :type stream: io.IOBase + + :param channel_stream_pairs: Optional dict containing channel-name stream pairs. + :type channel_stream_pairs dict[str, io.IOBase] + :param audio_settings: Configuration for the audio stream. :type audio_settings: speechmatics.models.AudioSettings @@ -448,8 +552,19 @@ async def run( :raises Exception: Can raise any exception returned by the consumer/producer tasks. """ + if channel_stream_pairs: + opened_streams = {} + self._stream_exits = AsyncExitStack() + for channel_name, path in channel_stream_pairs.items(): + if isinstance(path, str): + file_object = await asyncio.to_thread(open, path, "rb") + else: + file_object = path + opened_streams[channel_name] = file_object + self.channel_stream_pairs = opened_streams + else: + self.channel_stream_pairs = None self.transcription_config = transcription_config - self.seq_no = 0 self._language_pack_info = None await self._init_synchronization_primitives() if extra_headers is None: diff --git a/speechmatics/helpers.py b/speechmatics/helpers.py index 6063eb4..185904c 100644 --- a/speechmatics/helpers.py +++ b/speechmatics/helpers.py @@ -127,3 +127,18 @@ def _process_status_errors(error): + "(e.g. --lang abc is invalid)." ) sys.exit(f"httpx.HTTPStatusError: {error}") + + +def check_tasks_exceptions(tasks): + for t in tasks: + if not t.done() or t.cancelled(): + continue + + exc = t.exception() + if not exc: + continue + + for other in tasks: + if other is not t: + other.cancel() + raise exc diff --git a/speechmatics/models.py b/speechmatics/models.py index 9957918..e50fa39 100644 --- a/speechmatics/models.py +++ b/speechmatics/models.py @@ -293,6 +293,9 @@ class TranscriptionConfig(_TranscriptionConfig): audio_events_config: Optional[AudioEventsConfig] = None """Optional configuration for audio events""" + channel_diarization_labels: List[str] = None + """Add your own speaker or channel labels to the transcript""" + def as_config(self): dictionary = self.asdict() dictionary.pop("translation_config", None) @@ -332,9 +335,6 @@ class BatchTranscriptionConfig(_TranscriptionConfig): speaker_diarization_config: BatchSpeakerDiarizationConfig = None """Optional parameters for speaker diarization.""" - channel_diarization_labels: List[str] = None - """Add your own speaker or channel labels to the transcript""" - summarization_config: SummarizationConfig = None """Optional configuration for transcript summarization.""" @@ -520,9 +520,17 @@ class ClientMessageType(str, Enum): """Adds more audio data to the recognition job. The server confirms receipt by sending an :py:attr:`ServerMessageType.AudioAdded` message.""" + AddChannelAudio = "AddChannelAudio" + """Adds more audio data to the recognition job for a specific channel. + The server confirms receipt by sending an :py:attr:`ServerMessageType.ChannelAudioAdded` message. + """ + EndOfStream = "EndOfStream" """Indicates that the client has no more audio to send.""" + EndOfChannel = "EndOfChannel" + """Indicates that the client has no more audio to send in particular channel.""" + SetRecognitionConfig = "SetRecognitionConfig" """Allows the client to re-configure the recognition session.""" @@ -542,6 +550,10 @@ class ServerMessageType(str, Enum): """Server response to :py:attr:`ClientMessageType.AddAudio`, indicating that audio has been added successfully.""" + ChannelAudioAdded = "ChannelAudioAdded" + """Server response to :py:attr:`ClientMessageType.AddAChanneludio`, indicating + that audio has been added successfully.""" + AddPartialTranscript = "AddPartialTranscript" """Indicates a partial transcript, which is an incomplete transcript that is immediately produced and may change as more context becomes available. diff --git a/tests/data/ch_converted.wav b/tests/data/ch_converted.wav new file mode 100644 index 0000000..2b66380 --- /dev/null +++ b/tests/data/ch_converted.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:558582cd82166d38607dbe73b8a8b9d4a297443e28f226f2df5460ed74f051f9 +size 320078 diff --git a/tests/data/short-text.mp3 b/tests/data/short-text.mp3 new file mode 100644 index 0000000..28103a4 --- /dev/null +++ b/tests/data/short-text.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c17368dcf4abf4e68613ab4055ab3d6040c5c400a47151efe1f1617892003e64 +size 31584 diff --git a/tests/data/short-text_converted.wav b/tests/data/short-text_converted.wav new file mode 100644 index 0000000..b477767 --- /dev/null +++ b/tests/data/short-text_converted.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12b778f8f469fe6703d658c3e3f6daed80b730a7fdd1cd45d49d26a5e6c34aac +size 252750 diff --git a/tests/mock_rt_server.py b/tests/mock_rt_server.py index 387c452..14da6ff 100644 --- a/tests/mock_rt_server.py +++ b/tests/mock_rt_server.py @@ -2,6 +2,7 @@ import logging import time import websockets +from collections import defaultdict class MockRealtimeLogbook: @@ -153,8 +154,12 @@ def dummy_add_transcript(): } +channel_seq = defaultdict(int) + + async def mock_server_handler(websocket, logbook): mock_server_handler.next_audio_seq_no = 1 + mock_server_handler.number_of_channels = 0 address, _ = websocket.remote_address logbook.connection_request = websocket.request.headers logbook.path = websocket.request.path @@ -203,6 +208,11 @@ def get_responses(message, is_binary=False): raise ValueError(message) if msg_name == "StartRecognition": + mock_server_handler.number_of_channels = len( + message.get("transcription_config", {}).get( + "channel_diarization_labels", [] + ) + ) responses.append( { "message": "RecognitionStarted", @@ -216,11 +226,24 @@ def get_responses(message, is_binary=False): ) elif msg_name == "EndOfStream": responses.append({"message": "EndOfTranscript"}) + elif msg_name == "EndOfChannel": + mock_server_handler.number_of_channels -= 1 + if mock_server_handler.number_of_channels == 0: + responses.append({"message": "EndOfTranscript"}) elif msg_name == "SetRecognitionConfig": pass + elif msg_name == "AddChannelAudio": + channel = message["channel"] + channel_seq[channel] += 1 + responses.append( + { + "message": "ChannelAudioAdded", + "channel": channel, + "seq_no": channel_seq[channel], + } + ) else: raise ValueError(f"Unrecognized message: {message}") - return responses def is_str(data_in): diff --git a/tests/test_cli.py b/tests/test_cli.py index 141d858..72f3459 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,6 +1,7 @@ import argparse import collections import logging +from math import ceil import os import pytest @@ -75,6 +76,21 @@ } }, ), + ( + [ + "rt", + "transcribe", + "--multichannel", + "test_channel_1,test_channel_2", + "--diarization=speaker", + ], + { + "mode": "rt", + "command": "transcribe", + "multichannel": "test_channel_1,test_channel_2", + "diarization": "speaker", + }, + ), ( ["batch", "transcribe", "--additional-vocab", "Speechmatics", "gnocchi"], {"additional_vocab": ["Speechmatics", "gnocchi"]}, @@ -385,6 +401,8 @@ ) def test_cli_arg_parse_with_file(args, values): common_transcribe_args = ["--auth-token=xyz", "--url=example", "fake_file.wav"] + if "--multichannel" in args: + common_transcribe_args.append("very_fake_file.wav") test_args = args + common_transcribe_args actual_values = vars(cli.parse_args(args=test_args)) @@ -738,6 +756,72 @@ def test_rt_main_with_all_options(mock_server, tmp_path): assert -1 <= (len(add_audio_messages) - expected_num_messages) <= 1 +def test_rt_main_with_multichannel_option(mock_server): + chunk_size = 512 + audio_path_1 = path_to_test_resource("ch_converted.wav") + audio_path_2 = path_to_test_resource("short-text_converted.wav") + + args = [ + "rt", + "transcribe", + "--ssl-mode=insecure", + "--url", + mock_server.url, + "--diarization=channel", + "--multichannel=channel_1,channel_2", + "--lang=en", + "--chunk-size", + str(chunk_size), + "--raw=pcm_s16le", + "--sample-rate=16000", + audio_path_1, + audio_path_2, + ] + + cli.main(vars(cli.parse_args(args))) + + assert mock_server.clients_connected_count == 1 + assert mock_server.clients_disconnected_count == 1 + assert mock_server.messages_received + assert mock_server.messages_sent + + # Check that the StartRecognition message contains the correct fields + msg = mock_server.find_start_recognition_message() + + # Check that audio types are preserved + assert msg["audio_format"]["type"] == "raw" + assert msg["audio_format"]["encoding"] == "pcm_s16le" + assert msg["audio_format"]["sample_rate"] == 16000 + assert msg["transcription_config"]["language"] == "en" + assert msg["transcription_config"]["diarization"] == "channel" + assert msg["transcription_config"].get("operating_point") is None + assert len(msg["transcription_config"]["channel_diarization_labels"]) == 2 + + # Check we get all channels in the add channel audio messages + eoc = mock_server.find_messages_by_type("EndOfChannel") + assert eoc + seq_no_map = {m["channel"]: m["last_seq_no"] for m in eoc} + add_channel_audio_messages = mock_server.find_messages_by_type("AddChannelAudio") + assert add_channel_audio_messages + + for channel, seq in seq_no_map.items(): + assert seq == sum( + 1 for msg in add_channel_audio_messages if msg.get("channel") == channel + ) + + assert all( + msg.get("channel") in seq_no_map.keys() for msg in add_channel_audio_messages + ), "Some messages have invalid channels!" + + # Check file sizes are respected + size_of_audio_file_1 = os.stat(audio_path_1).st_size + size_of_audio_file_2 = os.stat(audio_path_2).st_size + expected_num_messages = ceil(size_of_audio_file_1 / chunk_size) + ceil( + size_of_audio_file_2 / chunk_size + ) + assert -1 <= (len(add_channel_audio_messages) - expected_num_messages) <= 1 + + def test_rt_main_with_config_file(mock_server): audio_path = path_to_test_resource("ch.wav") config_path = path_to_test_resource("transcription_config.json") diff --git a/tests/test_client.py b/tests/test_client.py index 597051c..fc72aa3 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -94,7 +94,11 @@ def test_handlers_called(mock_server, mocker): ws_client.add_event_handler("all", all_handler) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() # Each handler should have been called once for every message @@ -134,7 +138,11 @@ def language_changing_middleware(msg, is_binary): ) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() # Each handler should have been called once for every message @@ -169,7 +177,11 @@ def session_ender(event): ws_client.add_event_handler(ServerMessageType.RecognitionStarted, session_ender) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() # Only one message should have been sent from the server @@ -193,7 +205,10 @@ def test_run_synchronously_with_timeout(mock_server): with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: with pytest.raises(asyncio.TimeoutError): ws_client.run_synchronously( - audio_stream, transcription_config, audio_settings, timeout=0.0001 + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + timeout=0.0001, ) @@ -250,7 +265,11 @@ def session_ender(event, _): ws_client.add_middleware(client_message_type, session_ender) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() assert len(mock_server.messages_received) == expect_received_count @@ -270,7 +289,11 @@ def config_updater(msg): # pylint: disable=unused-argument ws_client.add_event_handler(ServerMessageType.RecognitionStarted, config_updater) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() set_recognition_config_msgs = mock_server.find_messages_by_type( @@ -293,7 +316,11 @@ def test_start_recognition_sends_speaker_diarization_config(mock_server): ) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() start_recognition_msgs = mock_server.find_messages_by_type("StartRecognition") @@ -332,7 +359,11 @@ def stopper(msg): # pylint: disable=unused-argument ws_client.add_event_handler(ServerMessageType.RecognitionStarted, stopper) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) mock_server.wait_for_clean_disconnects() num_messages_after_stop = len(mock_server.messages_received) @@ -360,7 +391,9 @@ async def mock_connect(*_, **__): with patch.object(client.LOGGER, "error", mock_logger_error_method): try: ws_client.run_synchronously( - MagicMock(), TranscriptionConfig(language="en"), MagicMock() + transcription_config=TranscriptionConfig(language="en"), + stream=MagicMock(), + audio_settings=MagicMock(), ) except ConnectionResetError as exc: assert exc is not None @@ -381,9 +414,9 @@ def call_exit(*args, **kwargs): with patch("websockets.connect", connect_mock): try: ws_client.run_synchronously( - stream, - transcription_config, - audio_settings, + transcription_config=transcription_config, + stream=stream, + audio_settings=audio_settings, extra_headers=extra_headers, ) except Exception: @@ -478,19 +511,27 @@ async def test__producer_happy_path(mocker): if index < exp_iters - 1: assert msg == index # from range in mock_read_in_chunks exp_current_seq_no += 1 - cmp_dicts(original_state, state, exp_diffs={"seq_no": exp_current_seq_no}) + cmp_dicts( + original_state, + state, + exp_diffs={"seq_no": {"single": exp_current_seq_no}}, + ) else: assert msg == json.dumps( {"message": "EndOfStream", "last_seq_no": exp_final_seq_no} ) - cmp_dicts(original_state, state, exp_diffs={"seq_no": exp_final_seq_no}) + cmp_dicts( + original_state, + state, + exp_diffs={"seq_no": {"single": exp_current_seq_no}}, + ) assert exp_iters == len(msgs_states) cmp_dicts( original_state, deepcopy_state(ws_client), - exp_diffs={"seq_no": exp_final_seq_no}, + exp_diffs={"seq_no": {"single": exp_current_seq_no}}, ) @@ -609,7 +650,11 @@ def test_language_pack_info_is_stored(mock_server): mock_server.url ) with open(path_to_test_resource("ch.wav"), "rb") as audio_stream: - ws_client.run_synchronously(audio_stream, transcription_config, audio_settings) + ws_client.run_synchronously( + transcription_config=transcription_config, + stream=audio_stream, + audio_settings=audio_settings, + ) info = ws_client.get_language_pack_info() assert info is not None