diff --git a/DEVELOPER.md b/DEVELOPER.md index fcfae0c..483657b 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -11,7 +11,7 @@ - [Operations Pipeline](#operations-pipeline) - [Service Providers](#service-providers) - [Local Services](#local-services) - - [KoboldCPP Setup](#koboldcpp-setup) + - [whisper.cpp and llama.cpp Setup](#whispercpp-and-llamacpp-setup) - [MeloTTS Setup](#melotts-setup) - [RVC (Voice Conversion)](#rvc-voice-conversion) - [Cloud Services](#cloud-services) @@ -22,16 +22,15 @@ - [Speech-to-Text (STT)](#speech-to-text-stt) - [Azure](#azure) - [Fish](#fish) - - [Kobold (using Whisper)](#kobold-using-whisper) + - [whisper.cpp](#whispercpp) - [OpenAI](#openai-1) - [Text-to-Text (T2T)](#text-to-text-t2t) - - [Kobold](#kobold) + - [llama.cpp](#llamacpp) - [OpenAI](#openai-2) - [Text-to-Speech (TTS)](#text-to-speech-tts) - [MeloTTS (Recommended)](#melotts-recommended) - [Azure](#azure-1) - [Fish](#fish-1) - - [Kobold](#kobold-1) - [OpenAI](#openai-3) - [pytts](#pytts) - [Audio Filters](#audio-filters) @@ -74,10 +73,7 @@ - [Making Operations](#making-operations) - [Implementing an Operation](#implementing-an-operation) - [Connecting an Operation for Use](#connecting-an-operation-for-use) - - [Adding Managed Processes](#adding-managed-processes) - - [Implementing a Process](#implementing-a-process) - - [Connecting a Process for Use](#connecting-a-process-for-use) - - [Connecting with Operations for Management](#connecting-with-operations-for-management) + - [Local Server Subprocesses](#local-server-subprocesses) - [Adding MCP Servers](#adding-mcp-servers) - [Making Applications](#making-applications) - [Extending Configuration](#extending-configuration) @@ -164,51 +160,38 @@ operations: Run everything on your own hardware without external API calls. -#### KoboldCPP Setup +#### whisper.cpp and llama.cpp Setup -**Compatibility:** Limited (depends on model) +**Compatibility:** Depends on model and hardware **Cost:** Free (local) -**Supports:** STT, T2T, TTS +**Supports:** STT (whisper.cpp), T2T (llama.cpp) **Installation:** -Installation should already be handled by Make. However if you want to use a specific varient, follow the instructions below. - -1. **Download KoboldCPP** from [releases](https://github.com/LostRuins/koboldcpp/releases): - - **NVIDIA GPU (e.g. RTX series):** `koboldcpp.exe` for Windows or `koboldcpp-linux-x64` for Linux - - **Older NVIDIA GPU (CUDA 11):** `koboldcpp-oldpc.exe` for Windows or `koboldcpp-linux-x64-oldpc` for Linux - - **Non-NVIDIA (No CUDA):** `koboldcpp-nocuda.exe` for Windows or `koboldcpp-linux-x64-nocuda` for Linux - - Place the KoboldCPP executable in the `models/kobold/` directory. - -2. **Download models:** - - **For T2T (LLM):** Download GGUF models as described [here](https://github.com/LostRuins/koboldcpp?tab=readme-ov-file#Obtaining-a-GGUF-model). Generally, any text-generation GGUF model from HuggingFace will work as long as your hardware meets its requirements. - - **For STT (Whisper):** Download the desired `.bin` file from [koboldcpp/whisper](https://huggingface.co/koboldcpp/whisper/tree/main) - - Recommended: `base.en` or `tiny.en` for balanced performance (English only), or `small` for multilingual support. - - Place all models in `models/kobold/` - -3. **Configure KoboldCPP:** - - Run the KoboldCPP executable to open the configuration interface - - **Under Quick Launch:** - - Select the correct GPU ID from the dropdown - - Disable "Launch Browser" - - Enable "Quiet Mode" (optional, reduces console spam) - - Enable "Use FlashAttention" (improves performance) - - Set Context Size based on your available VRAM (2048-8192+ tokens) - - Click "Browse" and load your GGUF LLM model - - **Under Context (optional):** - - Enable "Quantize KV Cache" and set to 8-bit or 4-bit to reduce VRAM usage with minimal quality impact - - **Under Audio (for STT):** - - Click "Browse" and load your Whisper model (`.bin` file) - - **IMPORTANT:** Click "Save" and save the configuration as a `.kcpps` file in `models/kobold/` - -4. **Update JAIson configuration:** +`make bootstrap` downloads pinned [whisper.cpp](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.4) and [llama.cpp](https://github.com/ggml-org/llama.cpp/releases/tag/b9381) server binaries into `bin/` (alongside ffmpeg). On Windows x64 the default is a CUDA build when available, with automatic CPU fallback. + +Override variants with environment variables before bootstrap: + +- `WHISPERCPP_VARIANT=cuda|cpu` (default: `cuda` on Windows x64) +- `LLAMACPP_VARIANT=cuda|cpu|rocm|vulkan` (default: `cuda` on Windows x64) + +Linux and macOS use OS-specific llama.cpp release archives; whisper.cpp server binaries are only published for Windows in the current upstream release (set `WHISPERCPP_SKIP=1` and install manually on other OSes if needed). + +1. **Download models** (not included in bootstrap): + - **STT:** ggml Whisper `.bin` from [ggerganov/whisper.cpp on Hugging Face](https://huggingface.co/ggerganov/whisper.cpp/tree/main) (e.g. `ggml-base.en.bin`) + - **T2T:** any compatible [GGUF](https://github.com/ggml-org/llama.cpp) chat model + +2. **Reference models on each operation** in `operations` (paths must be absolute on Windows use `\\`): ```yaml - kobold_filepath: "C:\\path\\to\\models\\kobold\\koboldcpp.exe" - kcpps_filepath: "C:\\path\\to\\models\\kobold\\myconfig.kcpps" + - role: stt + id: whispercpp + model_filepath: "C:\\path\\to\\models\\whispercpp\\ggml-base.en.bin" + language: en + - role: t2t + id: llamacpp + model_filepath: "C:\\path\\to\\models\\llamacpp\\your-model.gguf" ``` - **Note:** On Windows, use double backslashes (`\\`) in file paths + Server executables are always `bin/whisper-server` and `bin/llama-server` after bootstrap. #### MeloTTS Setup @@ -355,15 +338,17 @@ Convert spoken audio into text. id: fish ``` -#### Kobold (using Whisper) -- **Service:** KoboldCPP (Local) +#### whisper.cpp +- **Service:** whisper.cpp server (local) - **Cost:** Free - **Config:** ```yaml - role: stt - id: kobold - suppress_non_speech: true - langcode: "en" + id: whispercpp + model_filepath: "C:\\path\\to\\ggml-base.en.bin" + language: en + temperature: 0.0 + response_format: json ``` #### OpenAI @@ -384,23 +369,30 @@ Convert spoken audio into text. Process and generate conversational responses using LLMs. -#### Kobold -- **Service:** KoboldCPP (Local) +#### llama.cpp +- **Service:** llama-server (local) - **Cost:** Free -- **Features:** Advanced sampler controls - **Config:** ```yaml - role: t2t - id: kobold - max_context_length: 4096 # Context length set during Kobold config - max_length: 200 # Max response length - quiet: true # Quiet mode - rep_pen: 1.1 # Repetition penalty - depends on model, but 1.1 is common - rep_pen_range: 1024 # Depends on model - temperature: 0.7 # Controls randomness: higher is more creative, lower is more deterministic - top_k: 40 # Limits the next word selection to the top X most likely candidates - top_p: 0.95 # Nucleus sampling: only considers tokens that make up the top X% probability mass - typical: 1 # Typical sampling threshold; 1 = disabled + id: llamacpp + model_filepath: "C:\\path\\to\\models\\llamacpp\\your-model.gguf" # GGUF loaded by llama-server (-m); required + n_predict: 256 # Max tokens to generate (aliases: max_tokens, max_length) + temperature: 0.8 # Sampling temperature (>= 0) + top_p: 0.95 # Nucleus sampling (0–1) + top_k: 40 # Top-k sampling (0 = disabled on server) + min_p: 0.05 # Min-p sampling relative to best token (0–1) + typical_p: 1.0 # Locally typical sampling (1.0 = disabled) + repeat_penalty: 1.1 # Penalty for repeated token sequences + repeat_last_n: 64 # Recent tokens to consider for repetition (0 off, -1 = full ctx) + presence_penalty: 0.0 # Presence penalty (0.0 = disabled) + frequency_penalty: 0.0 # Frequency penalty (0.0 = disabled) + dry_multiplier: 0.0 # DRY strength (> 0 enables; include "dry" in samplers) + dry_base: 1.75 # DRY exponential base + dry_allowed_length: 2 # Repeat length allowed before DRY penalty ramps + dry_penalty_last_n: -1 # Tokens scanned for DRY (0 off, -1 = full context) + dry_sequence_breakers: ["\n", ":", "\"", "*"] # Strings that break DRY sequences + samplers: [dry, top_k, typ_p, top_p, min_p, temperature] # Sampler chain order; also: xtc, penalties, top_n_sigma ``` #### OpenAI @@ -468,17 +460,6 @@ Convert text responses into spoken audio. latency: "normal" # "normal" or "balanced" ``` -#### Kobold -- **Service:** KoboldCPP (Local) -- **Cost:** Free -- **Note:** Basic quality, included for completeness -- **Config:** - ```yaml - - role: tts - id: kobold - voice: "default" - ``` - #### OpenAI - **Service:** OpenAI or compatible (Cloud/Local) - **Cost:** Varies @@ -603,6 +584,18 @@ Generate text embeddings for semantic operations. model: "text-embedding-3-small" ``` +#### llama.cpp +- **Service:** llama-server (local) +- **Cost:** Free +- **Config:** Uses a dedicated embedding GGUF model. Server starts with `--embeddings`; vectors come from native `POST /embedding` (not `/v1/embeddings`). + ```yaml + - role: embedding + id: llamacpp + model_filepath: "C:\\path\\to\\models\\llamacpp\\your-embedding-model.gguf" # Embedding GGUF; required + pooling: mean # mean, cls, last, rank, or none (none = per-token vectors, mean-pooled here) + embd_normalize: 2 # -1 none, 0 max-abs, 1 taxicab, 2 Euclidean/L2, >2 p-norm + ``` + --- ## REST API @@ -992,10 +985,7 @@ In case you really want to use an unsupported service, directly implement a mode - [Making Operations](#making-operations) - [Implementing an Operation](#implementing-an-operation) - [Connecting an Operation for Use](#connecting-an-operation-for-use) -- [Adding Managed Processes](#adding-managed-processes) - - [Implementing a Process](#implementing-a-process) - - [Connecting a Process for Use](#connecting-a-process-for-use) - - [Connecting with Operations for Management](#connecting-with-operations-for-management) +- [Local Server Subprocesses](#local-server-subprocesses) - [Adding MCP Servers](#adding-mcp-servers) - [Making Applications](#making-applications) - [Extending Configuration](#extending-configuration) @@ -1035,7 +1025,7 @@ There are 2 inherited attributes: There are 6 functions to note: -`__init__(self)`: must be implemented with no additional arguments. In here, you must also call `super().__init__(op_id)` where `op_id` will be the id of this operation, unique to the one's of the same type (there are multiple `kobold`, but each `kobold` operation is in a different type). You can initialize attributes in here, but this is only ran once and is synchronous. +`__init__(self)`: must be implemented with no additional arguments. In here, you must also call `super().__init__(op_id)` where `op_id` will be the id of this operation, unique to the one's of the same type (e.g. `whispercpp` for STT and `llamacpp` for T2T). You can initialize attributes in here, but this is only ran once and is synchronous. `__call__`: **DO NOT IMPLEMENT** @@ -1060,48 +1050,9 @@ All operations are accessed from the `OperationManager` located in `utils/operat You can now use your custom operation. -### Adding Managed Processes - -#### Implementing a Process - -If you have an operation that depends on another running application, you can have jaison-core automatically start and stop that application whenever that operation is in use or not. This is done for KoboldCPP, and can be done for your application as well as long as you can start and get an instance of that process in Python (see `utils/processes/processes/koboldcpp.py` for example). - -Code for managing processes can be found in `utils/processes`. Process specific code is in `utils/processes/processes`. You will need to implement `BaseProcess` found in `utils/processes/base.py`. - -You only need to implement 2 functions. All else should not be modified. Check the base implementation to know which these are. - -`__init__`: Be sure to call `super().__init__(process_id)` where `process_id` is the a unique name chose purely for logging purposes. - -`async reload(self)`: Starting logic. You will need to start the process and save it to the `process` attribute. You can also save the `port` is applicable for use in your operations. - -#### Connecting a Process for Use - -All processes are accessed through the `ProcessManager` found in `utils/processes/manager.py`. We need to add it here so it's exposed for use. - -1. Open `utils/processes/manager.py` -2. Add an entry to the `ProcessType` enum for your process. -3. Create a new case in function `load` - - Import your process in there - - Add a new instance with the enum as the key - - asynchronously call `reload` on that instance - -#### Connecting with Operations for Management - -The process does not start until an operation demands it. Likewise, it does not stop until there are no more operations that use it. To setup this relationship, we need to know 2 functions from the `ProcessManager`: - -`link(link_id, process_type)`: Link an operation to that process. This lets the process know it's being used by that operation. `link_id` is an ID unique across all operations for that specific operation. `process_type` is the enum you created for your process. - -`unlink(link_id, process_type)`: Unlink an operation to that process. This lets the process know the operation no longer needs it (because its closing or just doesn't need it). `link_id` is an ID unique across all operations for that specific operation. `process_type` is the enum you created for your process. - -When all links are gone, a process will unload itself. Once an operation links up again, the process will start up again. For examples of how this is used, see any `kobold` operation. - -There are additional helper functions you may find useful: - -`get_process(process_type)`: Get the instance of that process. Useful if you need direct access to its attributes such as `port`. - -`signal_reload(process_type)`: Have the process restart on the next clock cycle. Typically not needed for an operation and moreso for restarting a process with modified configuration. +### Local Server Subprocesses -`signal_unload(process_type)`: Have the process foribly unload on the next clock cycle. Ignores existing links and just shuts down the process. Typically not needed for an operation and moreso for jaison-core shutdown. +Operations that need a local HTTP server (e.g. `whispercpp`, `llamacpp`) start and stop their own subprocess in `async start()` / `async close()`. Each loaded operation instance owns one server process and an ephemeral port. Shared helpers live in `utils/helpers/subprocess_server.py`; see `utils/operations/stt/whispercpp.py` and `utils/operations/t2t/llamacpp.py`. ### Adding MCP Servers @@ -1129,7 +1080,7 @@ Majority of interactions are job-based. It will most likely be necessary to crea ### Extending Configuration -All configuration lives in `utils/config.py`. They are accessible all throughout the code by importing this module and fetching the singleton via `Config()`. Extending this configuration is as simple as adding a new attribute. **This attribute must have a type hint and a default value**. Now you can configure this value from your config files using the same name as the attribute. +All configuration lives in `utils/config.py`. They are accessible throughout the code via the module-level `config` instance (`from utils.config import config`). Extending this configuration is as simple as adding a new attribute. **This attribute must have a type hint and a default value**. Now you can configure this value from your config files using the same name as the attribute. ### Extending API diff --git a/Makefile b/Makefile index 84dde94..f2d1da9 100644 --- a/Makefile +++ b/Makefile @@ -9,14 +9,14 @@ help: @echo Targets: @echo make sync - uv sync, runtime deps only @echo make sync-dev - uv sync --dev - @echo make setup - uv sync then install NLTK UniDic KoboldCPP and RVC base models + @echo make setup - uv sync then install NLTK UniDic whisper/llama and RVC base models @echo make dev - sync-dev then full bootstrap scripts - @echo make bootstrap - install.py plus KoboldCPP and RVC HF assets + @echo make bootstrap - install.py plus whisper.cpp llama.cpp and RVC HF assets @echo make test - pytest, run sync-dev or dev first @echo make lint - ruff and black checks @echo make lint-fix - ruff --fix and ruff format @echo make fmt - black and ruff format in place - @echo make bootstrap-force - force re-download of KoboldCPP and RVC HF assets + @echo make bootstrap-force - force re-download of whisper.cpp llama.cpp and RVC HF assets @echo make lock - uv lock sync: diff --git a/README.md b/README.md index e874d94..5a84232 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ Select the desired operations (refer to **[Development Guide](DEVELOPER.md)**) a ``` See the **[Development Guide](DEVELOPER.md)** for detailed configuration instructions, including: -- Setting up local services (KoboldCPP, MeloTTS, RVC) +- Setting up local services (whisper.cpp, llama.cpp, MeloTTS, RVC) - Configuring cloud providers (Azure, OpenAI, Fish Audio) - Customizing prompts and operations - Choosing the right services for your use case diff --git a/api.yaml b/api.yaml deleted file mode 100644 index bcef6ef..0000000 --- a/api.yaml +++ /dev/null @@ -1,727 +0,0 @@ -openapi: 3.1.0 -info: - title: jaison-core REST API overview - description: |- - This is an overview of the REST API for jaison-core. This is only the REST API endpoints and does not cover websocket or websocket events. For that, see DEVELOPER.md - version: 1.0.0 -externalDocs: - description: Find out more in developer docs - url: https://github.com/limitcantcode/jaison-core/blob/main/DEVELOPER.md -servers: - - url: http://localhost:7272/api -tags: - - name: misc - description: General management - - name: response - description: Request running of various generation pipelines - - name: context - description: Add information to the script - - name: operation - description: Manage and use specific operations - - name: configuration - description: Save, load, and modify configuration - -paths: - # MISC - /job: - delete: - tags: - - misc - summary: Cancel a job - description: Immediately cancel a queued or already running job by job_id. Will fail if job finished or doesn't exist. Cancellation will be reported over websockets. - operationId: jobCancel - requestBody: - description: Target a job by UUID - required: True - content: - application/json: - schema: - type: object - required: - - job_id - properties: - job_id: - type: string - format: uuid - description: Job ID to cancel - reason: - type: string - format: uuid - description: Reason for cancelling - responses: - '200': - description: Successfully cancelled job - content: - application/json: - schema: - type: object - required: - - status - - message - - response - properties: - status: - type: integer - enum: [200] - message: - type: string - enum: ["Job flagged for cancellation"] - description: Description of response result - response: - type: object - description: Empty object - '400': - description: Invalid job request - content: - application/json: - schema: - type: object - required: - - status - - message - - response - properties: - status: - type: integer - enum: [400] - message: - type: string - enum: ["Job ID does not exist or already finished","Request missing job_id"] - description: Description of response result - response: - type: object - description: Empty object - '500': - $ref: '#/components/responses/InternalErrorResponse' - # RESPONSE - /response: - post: - tags: - - response - summary: Request a text/audio response - description: Add a text/audio response job to the job queue. Results will be communicated over websockets. - operationId: responseAdd - requestBody: - description: Response request arguments - required: False - content: - application/json: - schema: - type: object - properties: - include_audio: - type: boolean - description: Whether to try and generate audio - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - # CONTEXT - /context: - delete: - tags: - - context - summary: Clear all history - description: Clear cached script including conversation history, context history, etc. Status is communicated over websockets. - operationId: contextDelete - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /context/config: - put: - tags: - - context - summary: Configure prompter - description: Configure specific values within the prompter such as prompts, names, and history length - operationId: contextConfigure - requestBody: - description: Options to be configured - content: - application/json: - schema: - type: object - properties: - name_translations: - type: object - description: Request to be given to the LLM - additionalProperties: - type: string - description: "Untranslated key translates into given value" - character_name: - type: string - description: Name of the character - history_length: - type: integer - description: Line count in script to retain - instruction_prompt_filename: - type: string - description: File name of instruction prompt file - character_prompt_filename: - type: string - description: File name of character prompt file - scene_prompt_filename: - type: string - description: File name of scene prompt file - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /context/request: - post: - tags: - - context - summary: Append request in script - description: Add a request to the script for the LLM to process in conversation. Status is communicated over websockets. - operationId: responseRequestAdd - requestBody: - description: Content of the request - required: True - content: - application/json: - schema: - type: object - required: - - content - properties: - content: - type: string - description: Request to be given to the LLM - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /context/conversation/text: - post: - tags: - - context - summary: Append conversation text in script - description: Add a conversational text to the script. Status is communicated over websockets. - operationId: responseConvTextAdd - requestBody: - description: Content of the request - required: True - content: - application/json: - schema: - type: object - required: - - user - - content - properties: - user: - type: string - description: Name of user associated with content - timestamp: - type: integer - minimum: 0 - maximum: 9999999999 - description: UNIX timestamp of message - content: - type: string - description: Message from user - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /context/conversation/audio: - post: - tags: - - context - summary: Append conversation audio to script - description: Transcribe conversational audio and add to the script. Status is communicated over websockets. - operationId: responseConvAudioAdd - requestBody: - description: Content of the request - required: True - content: - application/json: - schema: - type: object - required: - - user - - audio_bytes - - sr - - sw - - ch - properties: - user: - type: string - description: Name of user associated with speech - timestamp: - type: integer - minimum: 0 - maximum: 9999999999 - description: UNIX timestamp of message - audio_bytes: - type: string - format: byte - description: PCM audio bytes containing speech - sr: - type: integer - minimum: 0 - description: Sample rate of audio - sw: - type: integer - minimum: 0 - description: Number of bytes per audio sample - ch: - type: integer - minimum: 0 - description: Number of audio channels - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /context/custom: - put: - tags: - - context - summary: Register custom context - description: Register custom context details for use in script for later use describing things outside of conversation and requests. Status is communicated over websockets. - operationId: responseCustomRegister - requestBody: - description: Details of custom context to register - required: True - content: - application/json: - schema: - type: object - required: - - context_id - - context_name - properties: - context_id: - type: string - description: Custom context id used by future requests - context_name: - type: string - description: Name of the context as will appear in the script - context_description: - type: string - description: Context description as will be described to the LLM - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - delete: - tags: - - context - summary: Unregister a custom context - description: Remove a previously registered custom context so it is no longer described or addable to the script. Status is communicated over websockets. - operationId: responseCustomRemove - requestBody: - description: Target custom context to remove - required: True - content: - application/json: - schema: - type: object - required: - - context_id - properties: - context_id: - type: string - description: Targetted context id to delete - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - post: - tags: - - context - summary: Add custom context to script - description: Add custom context to script for external descriptions alongside conversation and requests. Status is communicated over websockets. - operationId: responseCustomAppend - requestBody: - description: Content of the custom context - required: True - content: - application/json: - schema: - type: object - required: - - context_id - - context_contents - properties: - context_id: - type: string - context_contents: - type: string - timestamp: - type: integer - minimum: 0 - maximum: 9999999999 - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - # OPERATION - /operations: - get: - tags: - - operation - summary: Get all loaded operations - description: Get names of which operations are loaded for which operation role if available. - operationId: operationGet - responses: - '200': - description: Successfully got loaded operations - content: - application/json: - schema: - type: object - required: - - status - - message - - response - properties: - status: - type: integer - enum: [200] - message: - type: string - enum: ["Loaded operations gotten"] - description: Description of response result - response: - type: object - description: Mapping of operation type to loaded operation id - properties: - stt: - type: string - mcp: - type: string - t2t: - type: string - tts: - type: string - filter_audio: - type: array - items: - type: string - filter_text: - type: array - items: - type: string - /operation/config: - post: - tags: - - operation - summary: Configure operation - description: Configure a list of operations by role. Configuration differs per operation and role. - operationId: operationConfigure - requestBody: - description: Operations to load - required: True - content: - application/json: - schema: - type: object - required: - - ops - properties: - ops: - type: array - description: List of operation identifiers and configuration - items: - type: object - required: - - role - - id - properties: - role: - type: string - enum: ['stt', 'mcp', 't2t', 'tts', 'filter_audio', 'filter_text', 'embedding'] - description: Operation role to load - id: - type: string - description: Operation under specified role's type to load - additionalProperties: - type: [string, number] - - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /operations/use: - post: - tags: - - operation - summary: Use an operation - description: Use a specific operation (loaded or not). Results returned over websockets. - operationId: operationUse - requestBody: - description: Target and input for operation - required: True - content: - application/json: - schema: - type: object - required: - - type - - payload - properties: - role: - type: string - enum: ['stt', 'mcp', 't2t', 'tts', 'filter_audio', 'filter_text', 'embedding'] - description: Operation role to use - id: - type: string - description: Specific operation under role's type to use. Defaults to already loaded operation. - payload: - type: object - description: Input chunk/payload for operation to process (see DEVELOPER.md for payload details per operation) - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /operations/load: - post: - tags: - - operation - summary: Load operations for later use - description: Load a list of operations into operation manager for default use. Will unload any existing operations and load new ones. Status is communicated over websockets. - operationId: operationLoad - requestBody: - description: Operations to load - required: True - content: - application/json: - schema: - type: object - required: - - ops - properties: - ops: - type: array - description: List of operation identifiers - items: - type: object - required: - - role - - id - properties: - role: - type: string - enum: ['stt', 'mcp', 't2t', 'tts', 'filter_audio', 'filter_text', 'embedding'] - description: Operation role to load - id: - type: string - description: Operation under specified role's type to load - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /operations/reload: - post: - tags: - - operation - summary: Load all operations as configured in configuration - description: Load all operations as configured in current configuration, unloading any existing operations as necessary. Status is communicated over websockets. - operationId: operationReload - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /operations/unload: - post: - tags: - - operation - summary: Unload operations so they can no longer be used - description: Unload a list of operations from the operation manager so they no longer get used by default. Nothing will take its place until requested. Status is communicated over websockets. - operationId: operationUnload - requestBody: - description: Operations to unload - required: True - content: - application/json: - schema: - type: object - required: - - ops - properties: - ops: - type: array - description: List of operation identifiers - items: - type: object - required: - - role - properties: - role: - type: string - enum: ['stt', 'mcp', 't2t', 'tts', 'filter_audio', 'filter_text', 'embedding'] - description: Operation role to unload - id: - type: string - description: Specific operation to unload - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - # CONFIGURATION - /config: - get: - tags: - - configuration - summary: Get current configuration - description: Get all fields and values in the current configuration state. Configuration will be passed in response property. - operationId: configGet - responses: - '200': - description: Successfully got current configuration - content: - application/json: - schema: - type: object - required: - - status - - message - - response - properties: - status: - type: integer - enum: [200] - message: - type: string - enum: ["Current config gotten"] - response: - type: object - /config/load: - post: - tags: - - configuration - summary: Load a saved config - description: Load a saved config from file. Status is communicated over websockets. - operationId: configLoad - requestBody: - description: Configuration to load - required: True - content: - application/json: - schema: - type: object - required: - - config_name - properties: - config_name: - type: string - description: Name of config to load - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /config/update: - post: - tags: - - configuration - summary: Update the current config. - description: Update the current configuration without saving to file. Status is communicated over websockets. - operationId: configUpdate - requestBody: - description: Configuration fields to update. - required: True - content: - application/json: - schema: - type: object - description: JSON equivilant of YAML configuration with fields to be updated - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' - /config/save: - post: - tags: - - configuration - summary: Save current configuration to file. - description: Save the current configuration to file using the name specified. Will overwrite configurations with the same name. Status is communicated over websockets. - operationId: configSave - requestBody: - description: Config name to save to - required: True - content: - application/json: - schema: - type: object - required: - - config_name - properties: - config_name: - type: string - description: Name of config to save to - responses: - '200': - $ref: '#/components/responses/JobResponse' - '500': - $ref: '#/components/responses/InternalErrorResponse' -components: - schemas: - Job: - type: object - required: - - status - - message - - response - properties: - status: - type: integer - enum: [200] - message: - type: string - enum: ["... job created"] - description: Description of response result - response: - type: object - required: - - job_id - properties: - job_id: - type: string - format: uuid - description: Job ID of job created for this request - InternalError: - type: object - required: - - status - - message - - response - properties: - status: - type: integer - enum: [500] - message: - type: string - description: Description of response result - response: - type: object - description: Empty object - responses: - JobResponse: - description: Successfully requested job - content: - application/json: - schema: - $ref: '#/components/schemas/Job' - InternalErrorResponse: - description: Unexpected internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/InternalError' \ No newline at end of file diff --git a/configs/example.yaml b/configs/example.yaml index 25f6f3b..bb38bbf 100644 --- a/configs/example.yaml +++ b/configs/example.yaml @@ -8,9 +8,11 @@ operations: # - role: stt # id: fish # - role: stt -# id: kobold -# suppress_non_speech: true -# langcode: en +# id: whispercpp +# model_filepath: E:\\jaison-core\\models\\whispercpp\\ggml-base.en.bin +# language: en +# temperature: 0.0 +# response_format: json - role: stt id: openai # Openai STT https://platform.openai.com/docs/guides/speech-to-text base_url: https://api.openai.com/v1/ @@ -19,19 +21,15 @@ operations: # T2T # - role: t2t -# id: kobold -# max_context_length: 2048 -# max_length: 100 -# quiet: true -# rep_pen: 1.1 -# rep_pen_range: 256 -# rep_pen_slope: 1 -# temperature: 0.5 -# tfs: 1 -# top_a: 0 -# top_k: 100 -# top_p: 0.9 -# typical: 1 +# id: llamacpp +# model_filepath: E:\\jaison-core\\models\\llamacpp\\your-model.gguf +# ctx_size: 2048 +# n_predict: 256 +# temperature: 0.8 +# top_p: 0.95 +# top_k: 40 +# dry_multiplier: 0.5 +# samplers: [dry, top_k, typ_p, top_p, min_p, temperature] - role: t2t id: openai base_url: https://api.openai.com/v1/ @@ -62,28 +60,25 @@ operations: # normalize: true # latency: normal # - role: tts -# id: kobold -# voice: kobo -- role: tts - id: openai # OpenAI TTSG https://platform.openai.com/docs/guides/text-to-speech - base_url: https://api.openai.com/v1/ - voice: nova - model: tts-1 +# id: openai # OpenAI TTSG https://platform.openai.com/docs/guides/text-to-speech +# base_url: https://api.openai.com/v1/ +# voice: nova +# model: tts-1 # - role: tts # id: pytts # voice: 'HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Speech\\Voices\\Tokens\\TTS_MS_EN-US_ZIRA_11.0' # gender: female -# - role: tts -# id: melo -# config_filepath: null -# model_filepath: null -# speaker_id: EN-Default -# device: cuda -# language: EN -# sdp_ratio: 0.8 -# noise_scale: 0.6 -# noise_scale_w: 0.8 -# speed: 1.0 +- role: tts + id: melo + config_filepath: null + model_filepath: null + speaker_id: EN-Default + device: cpu + language: EN + sdp_ratio: 0.8 + noise_scale: 0.6 + noise_scale_w: 0.8 + speed: 1.0 # Audio filters # - role: filter_audio @@ -94,8 +89,8 @@ operations: # voice: my-voice-model # f0_up_key: 0 # f0_method: rmvpe -# f0_file: null -# index_file: null +# f0_filepath: null +# index_filepath: null # index_rate: 0.0 # filter_radius: 3 # resample_sr: 0 @@ -113,6 +108,11 @@ operations: # frequency_penalty: 0.2 # Embedding model +# - role: embedding +# id: llamacpp +# model_filepath: E:\\jaison-core\\models\\llamacpp\\your-embedding-model.gguf +# pooling: mean +# embd_normalize: 2 - role: embedding id: openai base_url: https://api.openai.com/v1/ @@ -136,9 +136,5 @@ prompter: "old name": "new name" history_length: 20 -# Kobold -kobold_filepath: E:\\jaison-core\\models\\kobold\\koboldcpp_cu12.exe # must be absolute -kcpps_filepath: E:\\jaison-core\\models\\kobold\\save.kcpps # must be absolute - # Spacy NLP spacy_model: en_core_web_sm diff --git a/models/kobold/.gitignore b/models/llamacpp/.gitignore similarity index 100% rename from models/kobold/.gitignore rename to models/llamacpp/.gitignore diff --git a/models/whispercpp/.gitignore b/models/whispercpp/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/models/whispercpp/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ce2513d..7355e3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,11 +13,9 @@ classifiers = [ "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] -requires-python = ">=3.10,<3.11" +requires-python = ">=3.12,<3.13" dependencies = [ "av>=17.0.1", "nltk>=3.9", @@ -25,6 +23,7 @@ dependencies = [ "en-core-web-sm", "ffmpeg-python>=0.2.0", "fish-audio-sdk>=1.3.0", + "httpx>=0.28.0", "huggingface-hub>=0.24.0", "mcp>=1.27.1", "melotts", @@ -35,7 +34,6 @@ dependencies = [ "python-dotenv>=1.2.2", "pyttsx3>=2.99", "pyyaml>=6.0.3", - "quart>=0.20.0", "requests>=2.34.2", "rvc", "soundfile<0.13.0", @@ -45,6 +43,9 @@ dependencies = [ "torch==2.5.1", "torchvision==0.20.1", "torchaudio==2.5.1", + "fastapi>=0.136.3", + "uvicorn>=0.48.0", + "rich>=13.9.4", ] [dependency-groups] @@ -59,25 +60,23 @@ dev = [ [tool.uv.sources] en-core-web-sm = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" } -rvc = { git = "https://github.com/limitcantcode/Retrieval-based-Voice-Conversion.git" } -melotts = { git = "https://github.com/myshell-ai/MeloTTS", rev = "209145371cff8fc3bd60d7be902ea69cbdb7965a" } +rvc = { git = "https://github.com/limitcantcode/Retrieval-based-Voice-Conversion.git", rev = "8f010c0f9873a13d3dddd88881e25594af2561a9" } +melotts = { git = "https://github.com/limitcantcode/MeloTTS", rev = "3a60df458cd7abb907af7f2d78fc6e2681c6d640" } [tool.uv] override-dependencies = [ - "python-multipart>=0.0.9", - "av>=17.0.1", - "transformers==4.52.4", - "librosa>=0.10.1,<0.11.0", ] [tool.black] line-length = 100 -target-version = ["py310", "py311", "py312"] +target-version = ["py312"] include = '\.pyi?$' [tool.ruff] line-length = 100 -target-version = "py38" +target-version = "py312" + +[tool.ruff.lint] select = [ "E", # pycodestyle errors "W", # pycodestyle warnings @@ -92,11 +91,11 @@ ignore = [ "B008", # do not perform function calls in argument defaults ] -[tool.ruff.isort] +[tool.ruff.lint.isort] known-first-party = ["jaison-core"] [tool.mypy] -python_version = "3.8" +python_version = "3.12" warn_return_any = true warn_unused_configs = true disallow_untyped_defs = false diff --git a/scripts/bootstrap_runtime_deps.py b/scripts/bootstrap_runtime_deps.py index 9884896..da9381d 100644 --- a/scripts/bootstrap_runtime_deps.py +++ b/scripts/bootstrap_runtime_deps.py @@ -3,25 +3,25 @@ (Hubert content encoder + RMVPE pitch; see `VC.vc_inference` / `Pipeline.get_f0`). Layout: - models/kobold/koboldcpp[.exe] + bin/whisper-server[.exe] (+ bundled DLLs on Windows) + bin/llama-server[.exe] (+ bundled libs; shared with ffmpeg-static) models/rvc/base/hubert/hubert_base.pt models/rvc/base/rmvpe/rmvpe.pt models/rvc/weights/ — place your *.pth voice checkpoints here (see weight_root) -Not downloaded (unused by VC inference path): pretrained G/D blobs, UVR5 weights, -rmvpe.onnx (only needed for DirectML / DeviceType privateuseone). - Environment (optional): - KOBOLDCPP_VARIANT cuda | nocuda | oldpc (default: cuda — NVIDIA-oriented builds where available) - KOBOLDCPP_ASSET_NAME Full GitHub asset filename override (needed for linux aarch64: no official v1.113.2 build) - KOBOLDCPP_SKIP set to 1 to skip KoboldCPP download - RVC_SKIP set to 1 to skip RVC HF download - RVC_ASSETS_REVISION Hugging Face revision for lj1995/VoiceConversionWebUI (default: main) - -Hubert + rmvpe file source: Hugging Face `lj1995/VoiceConversionWebUI` -(legacy parity with subsets of Retrieval-based-Voice-Conversion tooling). - -See: KoboldCPP https://github.com/LostRuins/koboldcpp/releases/tag/v1.113.2 + WHISPERCPP_VARIANT cuda | cpu (default: cuda on Windows x64 when available, else cpu) + WHISPERCPP_ASSET_NAME Full GitHub release asset filename override + WHISPERCPP_SKIP set to 1 to skip whisper.cpp server download + LLAMACPP_VARIANT cuda | cpu | rocm | vulkan (default: cuda on Windows x64 when available) + LLAMACPP_ASSET_NAME Full GitHub release asset filename override + LLAMACPP_SKIP set to 1 to skip llama.cpp server download + RVC_SKIP set to 1 to skip RVC HF download + RVC_ASSETS_REVISION Hugging Face revision for lj1995/VoiceConversionWebUI (default: main) + +See: + whisper.cpp https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.4 + llama.cpp https://github.com/ggml-org/llama.cpp/releases/tag/b9381 """ from __future__ import annotations @@ -32,14 +32,22 @@ import shutil import stat import sys +import tarfile import tempfile +import zipfile +from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path +from typing import Literal import requests - -KOBOLDCPP_TAG = "v1.113.2" -RELEASE_BASE = f"https://github.com/LostRuins/koboldcpp/releases/download/{KOBOLDCPP_TAG}" +WHISPERCPP_TAG = "v1.8.4" +WHISPERCPP_RELEASE_BASE = ( + f"https://github.com/ggml-org/whisper.cpp/releases/download/{WHISPERCPP_TAG}" +) +LLAMACPP_TAG = "b9381" +LLAMACPP_RELEASE_BASE = f"https://github.com/ggml-org/llama.cpp/releases/download/{LLAMACPP_TAG}" FFMPEG_STATIC_TAG = "b6.1.1" FFMPEG_STATIC_RELEASE_BASE = ( @@ -73,125 +81,322 @@ def _download_url(url: str, dest: Path, *, chunk: int = 8 * 1024 * 1024) -> None partial.replace(dest) -def _resolve_kobold_asset_name(machine: str, system: str, variant: str) -> str | None: - """Return GitHub release asset basename, or None if user must supply KOBOLDCPP_ASSET_NAME.""" - ov = os.environ.get("KOBOLDCPP_ASSET_NAME", "").strip() - if ov: - return ov +@dataclass(frozen=True) +class _ReleaseAsset: + filename: str + archive: Literal["zip", "tar.gz"] + extract_mode: Literal["flat", "release_dir"] - m = machine.lower() + +def _server_binary_name(system: str, base: str) -> str: + return f"{base}.exe" if system == "Windows" else base + + +def _normalize_variant(variant: str, allowed: set[str], default: str) -> str: v = variant.lower().strip() + if v not in allowed: + print(f"Unknown variant {variant!r}; using {default}.") + return default + return v - if system == "Windows": - is_arm_win = os.environ.get("PROCESSOR_ARCHITECTURE", "").upper() == "ARM64" or "arm" in m - if is_arm_win: - print( - " Windows ARM: using koboldcpp-nocuda.exe (no CUDA KoboldCPP for this CPU class)." - ) - return "koboldcpp-nocuda.exe" - if v not in {"cuda", "nocuda", "oldpc"}: - print(f"Unknown KOBOLDCPP_VARIANT={v!r}; using cuda.") - v = "cuda" - names = { - "cuda": "koboldcpp.exe", - "nocuda": "koboldcpp-nocuda.exe", - "oldpc": "koboldcpp-oldpc.exe", - } - return names[v] - if system == "Linux": - if m in {"aarch64", "arm64"}: - print( - "ERROR: KoboldCPP v1.113.2 has no official linux aarch64 GitHub asset.\n" - " Set KOBOLDCPP_ASSET_NAME to a release asset name you can run on this machine,\n" - " or set KOBOLDCPP_SKIP=1 and install KoboldCPP manually into models/kobold/.\n" - " Release index: " - + f"https://github.com/LostRuins/koboldcpp/releases/tag/{KOBOLDCPP_TAG}", - file=sys.stderr, - ) +def _resolve_whisper_asset(machine: str, system: str, variant: str) -> _ReleaseAsset | None: + override = os.environ.get("WHISPERCPP_ASSET_NAME", "").strip() + if override: + archive: Literal["zip", "tar.gz"] = ( + "tar.gz" if override.endswith(".tar.gz") else "zip" + ) + mode: Literal["flat", "release_dir"] = ( + "release_dir" if archive == "zip" and "bin" in override else "flat" + ) + return _ReleaseAsset(override, archive, mode) + + m = machine.lower() + v = _normalize_variant(variant, {"cuda", "cpu"}, "cuda") + + if system == "Windows": + is_arm = os.environ.get("PROCESSOR_ARCHITECTURE", "").upper() == "ARM64" or "arm" in m + if is_arm: + print(" Windows ARM: whisper.cpp CUDA builds are x64-only; using cpu.") + v = "cpu" + if m not in {"x86_64", "amd64"} and not is_arm: + print(f"ERROR: Unsupported Windows machine for whisper.cpp: {machine!r}", file=sys.stderr) return None + if v == "cuda": + return _ReleaseAsset("whisper-cublas-12.4.0-bin-x64.zip", "zip", "release_dir") + return _ReleaseAsset("whisper-bin-x64.zip", "zip", "release_dir") - if m in {"x86_64", "amd64"}: - if v not in {"cuda", "nocuda", "oldpc"}: - print(f"Unknown KOBOLDCPP_VARIANT={v!r}; using cuda.") - v = "cuda" - names = { - "cuda": "koboldcpp-linux-x64", - "nocuda": "koboldcpp-linux-x64-nocuda", - "oldpc": "koboldcpp-linux-x64-oldpc", - } - return names[v] - - print(f"ERROR: Unsupported Linux machine type: {machine!r}", file=sys.stderr) + if system == "Linux": + print( + "ERROR: whisper.cpp v1.8.4 has no official Linux server binary in GitHub releases.\n" + " Set WHISPERCPP_ASSET_NAME to a compatible asset, or WHISPERCPP_SKIP=1 and install manually.\n" + f" Release index: https://github.com/ggml-org/whisper.cpp/releases/tag/{WHISPERCPP_TAG}", + file=sys.stderr, + ) return None if system == "Darwin": - if m in {"aarch64", "arm64"}: - return "koboldcpp-mac-arm64" print( - f"ERROR: KoboldCPP v1.113.2 has no macOS Intel build; machine={machine!r}", + "ERROR: whisper.cpp v1.8.4 macOS release ships an xcframework only (no whisper-server binary).\n" + " Set WHISPERCPP_SKIP=1 and install whisper-server manually, or set WHISPERCPP_ASSET_NAME.", file=sys.stderr, ) return None - print(f"ERROR: Unsupported OS for Kobold bootstrap: {system!r}", file=sys.stderr) + print(f"ERROR: Unsupported OS for whisper.cpp bootstrap: {system!r}", file=sys.stderr) return None -def _kobold_destination_name(system: str) -> Path: - return Path("koboldcpp.exe") if system == "Windows" else Path("koboldcpp") +def _resolve_llama_asset(machine: str, system: str, variant: str) -> _ReleaseAsset | None: + override = os.environ.get("LLAMACPP_ASSET_NAME", "").strip() + if override: + archive: Literal["zip", "tar.gz"] = ( + "tar.gz" if override.endswith(".tar.gz") else "zip" + ) + return _ReleaseAsset(override, archive, "flat") + m = machine.lower() + v = _normalize_variant(variant, {"cuda", "cpu", "rocm", "vulkan"}, "cuda") + tag = LLAMACPP_TAG -def _maybe_skip_kobold(out_dir: Path, force: bool) -> bool: - marker = out_dir / ".koboldcpp-version" - dest_name = _kobold_destination_name(platform.system()) - binary = out_dir / dest_name + if system == "Windows": + is_arm = os.environ.get("PROCESSOR_ARCHITECTURE", "").upper() == "ARM64" or "arm" in m + if is_arm: + if v == "cuda": + print(" Windows ARM: llama.cpp CUDA builds are x64-only; using cpu arm64.") + v = "cpu" + return _ReleaseAsset(f"llama-{tag}-bin-win-cpu-arm64.zip", "zip", "flat") + if m not in {"x86_64", "amd64"}: + print(f"ERROR: Unsupported Windows machine for llama.cpp: {machine!r}", file=sys.stderr) + return None + if v == "cuda": + return _ReleaseAsset(f"llama-{tag}-bin-win-cuda-12.4-x64.zip", "zip", "flat") + if v == "vulkan": + return _ReleaseAsset(f"llama-{tag}-bin-win-vulkan-x64.zip", "zip", "flat") + return _ReleaseAsset(f"llama-{tag}-bin-win-cpu-x64.zip", "zip", "flat") - version_label = marker.read_text(encoding="utf-8").strip() if marker.is_file() else None - if not force and binary.is_file() and version_label == KOBOLDCPP_TAG: - print(f"KoboldCPP {KOBOLDCPP_TAG} already present at {binary}, skipping.") - return True - return False + if system == "Linux": + if m in {"x86_64", "amd64"}: + if v == "cuda": + print( + " Linux x64: no official CUDA zip in this release; using ubuntu cpu build.", + file=sys.stderr, + ) + v = "cpu" + if v == "rocm": + return _ReleaseAsset(f"llama-{tag}-bin-ubuntu-rocm-7.2-x64.tar.gz", "tar.gz", "flat") + if v == "vulkan": + return _ReleaseAsset( + f"llama-{tag}-bin-ubuntu-vulkan-x64.tar.gz", "tar.gz", "flat" + ) + return _ReleaseAsset(f"llama-{tag}-bin-ubuntu-x64.tar.gz", "tar.gz", "flat") + if m in {"aarch64", "arm64"}: + if v == "vulkan": + return _ReleaseAsset( + f"llama-{tag}-bin-ubuntu-vulkan-arm64.tar.gz", "tar.gz", "flat" + ) + return _ReleaseAsset(f"llama-{tag}-bin-ubuntu-arm64.tar.gz", "tar.gz", "flat") + print(f"ERROR: Unsupported Linux machine for llama.cpp: {machine!r}", file=sys.stderr) + return None + if system == "Darwin": + if m in {"aarch64", "arm64"}: + return _ReleaseAsset(f"llama-{tag}-bin-macos-arm64.tar.gz", "tar.gz", "flat") + if m == "x86_64": + return _ReleaseAsset(f"llama-{tag}-bin-macos-x64.tar.gz", "tar.gz", "flat") + print(f"ERROR: Unsupported macOS machine for llama.cpp: {machine!r}", file=sys.stderr) + return None -def download_koboldcpp(*, project_root: Path, force: bool) -> None: - if os.environ.get("KOBOLDCPP_SKIP") == "1": - print("KoboldCPP download skipped (KOBOLDCPP_SKIP=1).") - return + print(f"ERROR: Unsupported OS for llama.cpp bootstrap: {system!r}", file=sys.stderr) + return None - out_dir = project_root / "models" / "kobold" + +def _marker_stamp(tag: str, variant: str, asset: _ReleaseAsset) -> str: + return f"{tag}\n{variant}\n{asset.filename}\n" + + +def _installed_marker_matches(marker: Path, stamp: str, server_bin: Path, *, force: bool) -> bool: + if force or not server_bin.is_file(): + return False + if not marker.is_file(): + return False + return marker.read_text(encoding="utf-8") == stamp + + +def _chmod_executable(path: Path) -> None: + mode = path.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH + path.chmod(mode) + + +def _extract_zip(archive: Path, dest_dir: Path, *, extract_mode: Literal["flat", "release_dir"]) -> None: + prefix = "Release/" if extract_mode == "release_dir" else "" + with zipfile.ZipFile(archive) as zf: + for member in zf.namelist(): + if prefix and not member.startswith(prefix): + continue + rel = member[len(prefix) :] if prefix else member + if not rel or rel.endswith("/"): + continue + target = dest_dir / rel + target.parent.mkdir(parents=True, exist_ok=True) + with zf.open(member) as src, target.open("wb") as dst: + shutil.copyfileobj(src, dst) + + +def _extract_tar_gz(archive: Path, dest_dir: Path) -> None: + with tarfile.open(archive, "r:gz") as tf: + tf.extractall(dest_dir, filter="data") + + +def _install_release_tree(staging_dir: Path, dest_dir: Path, *, merge: bool = False) -> None: + if not merge and dest_dir.exists(): + shutil.rmtree(dest_dir) + dest_dir.mkdir(parents=True, exist_ok=True) + children = list(staging_dir.iterdir()) + if len(children) == 1 and children[0].is_dir(): + source_root = children[0] + else: + source_root = staging_dir + for item in source_root.iterdir(): + target = dest_dir / item.name + if item.is_dir(): + shutil.copytree(item, target, dirs_exist_ok=merge) + else: + shutil.copy2(item, target) + + +def _download_release_asset( + *, + label: str, + tag: str, + release_base: str, + out_dir: Path, + asset: _ReleaseAsset, + server_name: str, + variant: str, + force: bool, + resolve_fallback: Callable[[str, str, str], _ReleaseAsset | None] | None = None, +) -> None: out_dir.mkdir(parents=True, exist_ok=True) + server_bin = out_dir / server_name + marker = out_dir / f".{label}-version" + stamp = _marker_stamp(tag, variant, asset) + + if _installed_marker_matches(marker, stamp, server_bin, force=force): + print(f"{label} {tag} ({asset.filename}) already present at {server_bin}, skipping.") + return + + candidates = [asset] + if resolve_fallback and variant == "cuda": + fallback = resolve_fallback(platform.machine(), platform.system(), "cpu") + if fallback and fallback.filename != asset.filename: + candidates.append(fallback) + + last_error: Exception | None = None + for candidate in candidates: + if candidate is not asset: + print(f" Falling back to {candidate.filename} ...") + stamp = _marker_stamp(tag, "cpu", candidate) + + url = f"{release_base}/{candidate.filename}" + print(f"Downloading {label} {tag}: {candidate.filename} ...") + + with tempfile.TemporaryDirectory(prefix=f"{label}_") as tmp_s: + tmp = Path(tmp_s) + archive_path = tmp / candidate.filename + try: + _download_url(url, archive_path) + except requests.HTTPError as e: + last_error = e + if candidate is not candidates[-1]: + continue + raise + + extract_root = tmp / "extract" + extract_root.mkdir() + if candidate.archive == "zip": + _extract_zip(archive_path, extract_root, extract_mode=candidate.extract_mode) + else: + _extract_tar_gz(archive_path, extract_root) + + _install_release_tree(extract_root, out_dir, merge=out_dir.name == "bin") + + if not server_bin.is_file(): + raise SystemExit( + f"{label} install incomplete: expected {server_bin} after extracting {candidate.filename}" + ) + + if platform.system() != "Windows": + _chmod_executable(server_bin) + + marker.write_text(stamp, encoding="utf-8") + print(f" -> {server_bin}") + return + + if last_error: + raise last_error - if _maybe_skip_kobold(out_dir, force): + +def download_whispercpp(*, project_root: Path, force: bool) -> None: + if os.environ.get("WHISPERCPP_SKIP") == "1": + print("whisper.cpp download skipped (WHISPERCPP_SKIP=1).") return system = platform.system() - variant = os.environ.get("KOBOLDCPP_VARIANT", "cuda") - asset = _resolve_kobold_asset_name(platform.machine(), system, variant) + machine = platform.machine() + variant = os.environ.get("WHISPERCPP_VARIANT", "cuda") + asset = _resolve_whisper_asset(machine, system, variant) if asset is None: raise SystemExit(1) - url = f"{RELEASE_BASE}/{asset}" - dest_bin = _kobold_destination_name(system) - staging = out_dir / asset - final = out_dir / dest_bin - - print(f"Downloading KoboldCPP {KOBOLDCPP_TAG}: {asset} ...") - _download_url(url, staging) - - if final != staging: - if final.exists(): - final.unlink() + out_dir = project_root / "bin" + server_name = _server_binary_name(system, "whisper-server") + + def _cpu_fallback(m: str, s: str, _v: str) -> _ReleaseAsset | None: + return _resolve_whisper_asset(m, s, "cpu") + + _download_release_asset( + label="whispercpp", + tag=WHISPERCPP_TAG, + release_base=WHISPERCPP_RELEASE_BASE, + out_dir=out_dir, + asset=asset, + server_name=server_name, + variant=variant, + force=force, + resolve_fallback=_cpu_fallback if variant == "cuda" else None, + ) - shutil.move(staging, final) - if system != "Windows": - mode = final.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH - final.chmod(mode) +def download_llamacpp(*, project_root: Path, force: bool) -> None: + if os.environ.get("LLAMACPP_SKIP") == "1": + print("llama.cpp download skipped (LLAMACPP_SKIP=1).") + return - (out_dir / ".koboldcpp-version").write_text(KOBOLDCPP_TAG + "\n", encoding="utf-8") + system = platform.system() + machine = platform.machine() + variant = os.environ.get("LLAMACPP_VARIANT", "cuda") + asset = _resolve_llama_asset(machine, system, variant) + if asset is None: + raise SystemExit(1) - print(f" -> {final}") + out_dir = project_root / "bin" + server_name = _server_binary_name(system, "llama-server") + + def _cpu_fallback(m: str, s: str, _v: str) -> _ReleaseAsset | None: + return _resolve_llama_asset(m, s, "cpu") + + _download_release_asset( + label="llamacpp", + tag=LLAMACPP_TAG, + release_base=LLAMACPP_RELEASE_BASE, + out_dir=out_dir, + asset=asset, + server_name=server_name, + variant=variant, + force=force, + resolve_fallback=_cpu_fallback if variant == "cuda" else None, + ) def _ffmpeg_static_asset_names(system: str, machine: str) -> tuple[str, str] | None: @@ -219,7 +424,9 @@ def _ffmpeg_static_asset_names(system: str, machine: str) -> tuple[str, str] | N return ("ffmpeg-linux-arm", "ffprobe-linux-arm") if m in {"i386", "i686"}: return ("ffmpeg-linux-ia32", "ffprobe-linux-ia32") - print(f"ERROR: Unsupported Linux machine type for ffmpeg-static: {machine!r}", file=sys.stderr) + print( + f"ERROR: Unsupported Linux machine type for ffmpeg-static: {machine!r}", file=sys.stderr + ) return None if sys_norm == "Darwin": @@ -227,7 +434,9 @@ def _ffmpeg_static_asset_names(system: str, machine: str) -> tuple[str, str] | N return ("ffmpeg-darwin-arm64", "ffprobe-darwin-arm64") if m in {"x86_64"}: return ("ffmpeg-darwin-x64", "ffprobe-darwin-x64") - print(f"ERROR: Unsupported macOS machine type for ffmpeg-static: {machine!r}", file=sys.stderr) + print( + f"ERROR: Unsupported macOS machine type for ffmpeg-static: {machine!r}", file=sys.stderr + ) return None print(f"ERROR: Unsupported OS for ffmpeg-static bootstrap: {system!r}", file=sys.stderr) @@ -305,9 +514,7 @@ def download_ffmpeg_and_ffprobe(*, project_root: Path, force: bool) -> None: mode = final.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH final.chmod(mode) - (bin_dir / ".ffmpeg-static-version").write_text( - FFMPEG_STATIC_TAG + "\n", encoding="utf-8" - ) + (bin_dir / ".ffmpeg-static-version").write_text(FFMPEG_STATIC_TAG + "\n", encoding="utf-8") print(f" ffmpeg -> {ffmpeg_final}") print(f" ffprobe -> {ffprobe_final}") @@ -406,14 +613,15 @@ def download_rvc_assets(*, project_root: Path, force: bool, revision: str | None def parse_args(argv: list[str]) -> argparse.Namespace: p = argparse.ArgumentParser( - description="Download KoboldCPP, ffmpeg/ffprobe, and RVC runtime assets." + description="Download whisper.cpp, llama.cpp, ffmpeg/ffprobe, and RVC runtime assets." ) p.add_argument( "--force", action="store_true", help="Re-download even if version markers indicate an existing install.", ) - p.add_argument("--skip-kobold", action="store_true", help="Skip KoboldCPP.") + p.add_argument("--skip-whispercpp", action="store_true", help="Skip whisper.cpp server.") + p.add_argument("--skip-llamacpp", action="store_true", help="Skip llama.cpp server.") p.add_argument( "--skip-ffmpeg", action="store_true", @@ -430,13 +638,15 @@ def main(argv: list[str] | None = None) -> None: args = parse_args(sys.argv[1:] if argv is None else argv) root = _project_root() - if args.skip_kobold: - os.environ["KOBOLDCPP_SKIP"] = "1" - + if args.skip_whispercpp: + os.environ["WHISPERCPP_SKIP"] = "1" + if args.skip_llamacpp: + os.environ["LLAMACPP_SKIP"] = "1" if args.skip_rvc: os.environ["RVC_SKIP"] = "1" - download_koboldcpp(project_root=root, force=args.force) + download_whispercpp(project_root=root, force=args.force) + download_llamacpp(project_root=root, force=args.force) if not args.skip_ffmpeg: download_ffmpeg_and_ffprobe(project_root=root, force=args.force) download_rvc_assets(project_root=root, force=args.force, revision=args.rvc_revision) diff --git a/src/main.py b/src/main.py index 06747aa..2edbcc7 100644 --- a/src/main.py +++ b/src/main.py @@ -2,13 +2,14 @@ setup_logger() -from utils.args import args -from dotenv import load_dotenv +from dotenv import load_dotenv # noqa: E402 + +from utils.args import args # noqa: E402 load_dotenv(dotenv_path=args.env) -import os -from pathlib import Path +import os # noqa: E402 +from pathlib import Path # noqa: E402 # Patch path for local binaries project_root = Path(__file__).resolve().parents[1] @@ -16,7 +17,14 @@ if bin_dir.is_dir(): os.environ["PATH"] = f"{bin_dir}{os.pathsep}{os.environ.get('PATH', '')}" -import asyncio -from utils.server import start_web_server +import uvicorn # noqa: E402 + +from utils.server import app # noqa: E402 -asyncio.run(start_web_server()) +uvicorn.run( + app, + host=args.host, + port=args.port, + log_level=args.log_level.lower(), + log_config=None, +) diff --git a/src/utils/args.py b/src/utils/args.py index 0c682c2..314ee0b 100644 --- a/src/utils/args.py +++ b/src/utils/args.py @@ -2,11 +2,34 @@ import os args = argparse.ArgumentParser() -args.add_argument('-e', '--env', default=None, type=str, help='Filepath to .env if located elsewhere') -args.add_argument('-c', '--config', default=None, type=str, help='Filename to your yaml config. For example: "example" refers to configs/example.yaml') -args.add_argument('--host', default='127.0.0.1', type=str, help='IP to use as host API and websocket server on. Try 0.0.0.0 for cross machine access to API.') -args.add_argument('--port', default=7272, type=int, help='Post to host API and websocket server on.') -args.add_argument('--log_level', default='INFO', type=str, choices=['DEBUG','INFO','WARNING','ERROR','CRITICAL'], help='Level of logs to show') -args.add_argument('--log_dir', default=os.path.join(os.getcwd(), 'logs'), type=str, help='Storing folder for logs') -args.add_argument('--silent', action='store_true', help='Suppress console outputs') -args = args.parse_args() \ No newline at end of file +args.add_argument( + "-e", "--env", default=None, type=str, help="Filepath to .env if located elsewhere" +) +args.add_argument( + "-c", + "--config", + default=None, + type=str, + help='Filename to your yaml config. For example: "example" refers to configs/example.yaml', +) +args.add_argument( + "--host", + default="127.0.0.1", + type=str, + help="IP to use as host API and websocket server on. Try 0.0.0.0 for cross machine access to API.", +) +args.add_argument( + "--port", default=7272, type=int, help="Post to host API and websocket server on." +) +args.add_argument( + "--log_level", + default="INFO", + type=str, + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Level of logs to show", +) +args.add_argument( + "--log_dir", default=os.path.join(os.getcwd(), "logs"), type=str, help="Storing folder for logs" +) +args.add_argument("--silent", action="store_true", help="Suppress console outputs") +args = args.parse_args() diff --git a/src/utils/config.py b/src/utils/config.py index 83a4874..41ae77d 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -1,70 +1,74 @@ import os +from typing import get_type_hints + import yaml -from typing import get_type_hints, List, Dict -from .helpers.singleton import Singleton -from .helpers.path import portable_path + from .args import args +from .helpers.path import portable_path + class UnknownField(Exception): def __init__(self, field: str): - super().__init__("Config field {} does not exist".format(field)) + super().__init__(f"Config field {field} does not exist") + class UnknownFile(Exception): def __init__(self, filepath: str): - super().__init__("Config file {} does not exist".format(filepath)) + super().__init__(f"Config file {filepath} does not exist") -class Config(metaclass=Singleton): + +class Config: # Every attribute must be typed for validation CONFIG_DIR: str = portable_path(os.path.join(os.getcwd(), "configs")) - WORKING_DIR: str = portable_path(os.path.join(os.getcwd(),"output","temp")) + WORKING_DIR: str = portable_path(os.path.join(os.getcwd(), "output", "temp")) current_config: str = "Unsaved" - + # Defaults - operations: list = list() - + operations: list = [] + # Prompter PROMPT_DIR: str = portable_path(os.path.join(os.getcwd(), "prompts")) PROMPT_INSTRUCTION_SUBDIR: str = "instructions" PROMPT_CHARACTER_SUBDIR: str = "characters" PROMPT_SCENE_SUBDIR: str = "scenes" - - prompter: dict = dict() - history_filepath: str = portable_path(os.path.join(os.getcwd(), "output", "history.txt")) # debug + + prompter: dict = {} + history_filepath: str = portable_path( + os.path.join(os.getcwd(), "output", "history.txt") + ) # debug # MCP MCP_DIR: str = portable_path(os.path.join(os.getcwd(), "models", "mcp")) - mcp: list = list() + mcp: list = [] - # Kobold - kobold_filepath: str = None - kcpps_filepath: str = None - # Melo MELO_DIR: str = portable_path(os.path.join(os.getcwd(), "models", "melotts")) # Shared - stt_working_src: str = portable_path(os.path.join(WORKING_DIR,'stt_src.wav')) - ffmpeg_working_src: str = portable_path(os.path.join(WORKING_DIR,'ffmpeg_src.wav')) - ffmpeg_working_dest: str = portable_path(os.path.join(WORKING_DIR,'ffmpeg_dest.wav')) + stt_working_src: str = portable_path(os.path.join(WORKING_DIR, "stt_src.wav")) + ffmpeg_working_src: str = portable_path(os.path.join(WORKING_DIR, "ffmpeg_src.wav")) + ffmpeg_working_dest: str = portable_path(os.path.join(WORKING_DIR, "ffmpeg_dest.wav")) spacy_model: str = None - + def __init__(self): # Every attribute must be typed for validation - if args.config is not None: self.load_from_name(args.config) - - # Can raise: FileNotFoundError, + if args.config is not None: + self.load_from_name(args.config) + + # Can raise: FileNotFoundError, def load_from_name(self, config_name: str): - filepath = os.path.join(self.CONFIG_DIR, config_name+".yaml") - if not os.path.isfile(filepath): + filepath = os.path.join(self.CONFIG_DIR, config_name + ".yaml") + if not os.path.isfile(filepath): filepath = os.path.join(self.CONFIG_DIR, config_name) - if not os.path.isfile(filepath): raise UnknownFile(filepath) - + if not os.path.isfile(filepath): + raise UnknownFile(filepath) + with open(filepath) as f: conf_d = yaml.safe_load(f) - + self.load_from_dict(**conf_d) self.current_config = config_name - + def load_from_dict(self, **conf_d): uncommitted = dict(conf_d) config_typings = get_type_hints(Config) @@ -73,17 +77,22 @@ def load_from_dict(self, **conf_d): for field in conf_d: if field not in config_typings: raise UnknownField(field) - uncommitted[field] = config_typings[field](conf_d[field]) if conf_d[field] is not None else None # attempt cast to correct typing - + uncommitted[field] = ( + config_typings[field](conf_d[field]) if conf_d[field] is not None else None + ) # attempt cast to correct typing + # Commit config change request for field in uncommitted: setattr(self, field, uncommitted[field]) - + self.current_config = "Unsaved" - + def save(self, config_name: str): with open(portable_path(os.path.join(self.CONFIG_DIR, config_name))) as f: - yaml.dump(self.get_config_dict(),f) - + yaml.dump(self.get_config_dict(), f) + def get_config_dict(self): - return vars(self) \ No newline at end of file + return vars(self) + + +config = Config() diff --git a/src/utils/helpers/audio.py b/src/utils/helpers/audio.py index 6d26413..e713b28 100644 --- a/src/utils/helpers/audio.py +++ b/src/utils/helpers/audio.py @@ -1,31 +1,22 @@ import wave + import ffmpeg -from utils.config import Config +from utils.config import config + def pitch_audio(ab: bytes, sr: int, sw: int, ch: int, pitch_amount: int): # ffmpeg -i "input.wav" -af "rubberband=smoothing=on:pitch=2^(1/2):pitchq=quality:window=short:channels=apart:phase=independent" "output.wav" - speed_factor = 2 ** (pitch_amount/12) - - with wave.open(Config().ffmpeg_working_src, 'wb') as f: + speed_factor = 2 ** (pitch_amount / 12) + + with wave.open(config.ffmpeg_working_src, "wb") as f: f.setframerate(sr) f.setsampwidth(sw) f.setnchannels(ch) f.writeframes(ab) - - ffmpeg.input( - Config().ffmpeg_working_src - ).filter( - "atempo", - 1/speed_factor - ).filter( - "asetrate", - sr*speed_factor - ).output( - Config().ffmpeg_working_dest - ).run( - overwrite_output=True, - quiet=True - ) - with wave.open(Config().ffmpeg_working_dest, 'r') as f: + + ffmpeg.input(config.ffmpeg_working_src).filter("atempo", 1 / speed_factor).filter( + "asetrate", sr * speed_factor + ).output(config.ffmpeg_working_dest).run(overwrite_output=True, quiet=True) + with wave.open(config.ffmpeg_working_dest, "r") as f: return f.readframes(f.getnframes()), f.getframerate(), f.getsampwidth(), f.getnchannels() diff --git a/src/utils/helpers/iterable.py b/src/utils/helpers/iterable.py index 921f7fa..bd2eb51 100644 --- a/src/utils/helpers/iterable.py +++ b/src/utils/helpers/iterable.py @@ -1,14 +1,15 @@ CHUNK_SIZE = 4096 + async def list_to_agen(target_list): for item in target_list: yield item - - + + def chunk_buffer(buf): - chunks = list() + chunks = [] while len(buf) > 0: chunks.append(buf[:CHUNK_SIZE]) buf = buf[CHUNK_SIZE:] - - return chunks \ No newline at end of file + + return chunks diff --git a/src/utils/helpers/multiplexor.py b/src/utils/helpers/multiplexor.py index 9f95087..5896d71 100644 --- a/src/utils/helpers/multiplexor.py +++ b/src/utils/helpers/multiplexor.py @@ -1,44 +1,55 @@ -from typing import List, Callable, AsyncGenerator, Dict, Tuple -import logging import asyncio +from collections.abc import AsyncGenerator, Callable -async def _queue_to_generator(queue: asyncio.Queue, queue_event: asyncio.Event, finish_event: asyncio.Event): + +async def _queue_to_generator( + queue: asyncio.Queue, queue_event: asyncio.Event, finish_event: asyncio.Event +): while True: await queue_event.wait() if queue.empty(): queue_event.clear() - if finish_event.is_set(): break + if finish_event.is_set(): + break else: yield await queue.get() - -async def _multiplex(in_stream: AsyncGenerator, queue_list: List[asyncio.Queue], queue_event_list: List[asyncio.Event], finish_event: asyncio.Event): + + +async def _multiplex( + in_stream: AsyncGenerator, + queue_list: list[asyncio.Queue], + queue_event_list: list[asyncio.Event], + finish_event: asyncio.Event, +): async for in_d in in_stream: for q in queue_list: await q.put(dict(in_d)) for qe in queue_event_list: qe.set() - + for qe in queue_event_list: qe.set() finish_event.set() - + + def multiplexor( - func_d: Dict[str, Callable[[AsyncGenerator], AsyncGenerator | None]], - in_stream: AsyncGenerator -) -> Tuple[Dict[str, AsyncGenerator], asyncio.Task]: - queue_list: List[asyncio.Queue] = list() - queue_event_list: List[asyncio.Event] = list() + func_d: dict[str, Callable[[AsyncGenerator], AsyncGenerator | None]], in_stream: AsyncGenerator +) -> tuple[dict[str, AsyncGenerator], asyncio.Task]: + queue_list: list[asyncio.Queue] = [] + queue_event_list: list[asyncio.Event] = [] stream_end_event = asyncio.Event() - - result_d = dict() + + result_d = {} for fun_key in func_d: q = asyncio.Queue() q_event = asyncio.Event() - agen = func_d[fun_key](_queue_to_generator(q,q_event,stream_end_event)) + agen = func_d[fun_key](_queue_to_generator(q, q_event, stream_end_event)) result_d[fun_key] = agen queue_list.append(q) queue_event_list.append(q_event) - - multi_task = asyncio.create_task(_multiplex(in_stream, queue_list, queue_event_list, stream_end_event)) - - return result_d, multi_task \ No newline at end of file + + multi_task = asyncio.create_task( + _multiplex(in_stream, queue_list, queue_event_list, stream_end_event) + ) + + return result_d, multi_task diff --git a/src/utils/helpers/observer.py b/src/utils/helpers/observer.py index 3c9902a..9eb3f53 100644 --- a/src/utils/helpers/observer.py +++ b/src/utils/helpers/observer.py @@ -1,13 +1,15 @@ import asyncio -from typing import List, AsyncGenerator +from collections.abc import AsyncGenerator -'''ObserverClient asynchronously handles events in queue populated by an ObserverServer''' -class BaseObserverClient(): - def __init__(self, server = None): +"""ObserverClient asynchronously handles events in queue populated by an ObserverServer""" + + +class BaseObserverClient: + def __init__(self, server=None): self.server = None if server: self.listen(server) - + self.queue = asyncio.Queue() self.event_listener = None @@ -17,26 +19,29 @@ def listen(self, server): self.server = server self.server.join(self) - + self.event_listener = asyncio.create_task(self._event_listener()) def close(self): self.server.detach(self) self.server = None - + async def _event_listener(self): while True: next_event = await self.queue.get() - await self.handle_event(next_event['event'], next_event['payload']) - + await self.handle_event(next_event["event"], next_event["payload"]) + # To Be Implement async def handle_event(self, event_id: str, payload) -> None: raise NotImplementedError -'''ObserverServer adds events and payloads to all listening client queues.''' -class ObserverServer(): + +"""ObserverServer adds events and payloads to all listening client queues.""" + + +class ObserverServer: def __init__(self): - self.clients: List[ObserverClient] = [] + self.clients: list[BaseObserverClient] = [] def join(self, client): if client not in self.clients: @@ -46,17 +51,13 @@ def detach(self, client): if client in self.clients: self.clients.remove(client) - async def broadcast_event(self, event_id: str, payload: dict = {}): + async def broadcast_event(self, event_id: str, payload: dict = None): + if payload is None: + payload = {} for client in self.clients: - await client.queue.put({ - "event": event_id, - "payload": payload - }) - + await client.queue.put({"event": event_id, "payload": payload}) + async def broadcast_stream(self, event_id: str, payload_stream: AsyncGenerator): async for payload in payload_stream: for client in self.clients: - await client.queue.put({ - "event": event_id, - "payload": payload - }) \ No newline at end of file + await client.queue.put({"event": event_id, "payload": payload}) diff --git a/src/utils/helpers/path.py b/src/utils/helpers/path.py index 5d0c047..91d2477 100644 --- a/src/utils/helpers/path.py +++ b/src/utils/helpers/path.py @@ -1,2 +1,2 @@ def portable_path(target_path: str): - return target_path.encode('unicode_escape').decode() \ No newline at end of file + return target_path.encode("unicode_escape").decode() diff --git a/src/utils/helpers/singleton.py b/src/utils/helpers/singleton.py deleted file mode 100644 index 016f697..0000000 --- a/src/utils/helpers/singleton.py +++ /dev/null @@ -1,9 +0,0 @@ -class Singleton(type): - def __init__(cls, name, bases, dict): - super(Singleton, cls).__init__(name, bases, dict) - cls.instance = None - - def __call__(cls,*args,**kw): - if cls.instance is None: - cls.instance = super(Singleton, cls).__call__(*args, **kw) - return cls.instance \ No newline at end of file diff --git a/src/utils/helpers/subprocess_server.py b/src/utils/helpers/subprocess_server.py new file mode 100644 index 0000000..44ae1d6 --- /dev/null +++ b/src/utils/helpers/subprocess_server.py @@ -0,0 +1,50 @@ +"""Helpers for operations that spawn local HTTP server binaries.""" + +from __future__ import annotations + +import logging +import os +import platform +import socket +import subprocess + +import psutil + +from utils.helpers.path import portable_path + + +def bin_executable(name: str) -> str: + """Path to a server binary installed under ``bin/`` by bootstrap.""" + if platform.system() == "Windows" and not name.endswith(".exe"): + name = f"{name}.exe" + return portable_path(os.path.join(os.getcwd(), "bin", name)) + + +def allocate_port() -> int: + sock = socket.socket() + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def start_shell_process(cmd: str, *, label: str) -> subprocess.Popen: + from subprocess import DEVNULL + + logging.debug(f"Starting {label}: {cmd}") + proc = subprocess.Popen(cmd, shell=True, stdout=DEVNULL, stderr=DEVNULL) + logging.info(f"Started {label} (PID: {proc.pid})") + return proc + + +def stop_process(proc: subprocess.Popen | None, *, label: str) -> None: + if proc is None: + return + try: + ps_process = psutil.Process(proc.pid) + for child in ps_process.children(recursive=True): + child.kill() + ps_process.kill() + except psutil.NoSuchProcess: + pass + logging.info(f"Stopped {label} (PID: {proc.pid})") diff --git a/src/utils/helpers/time.py b/src/utils/helpers/time.py index 61df93d..a000c2a 100644 --- a/src/utils/helpers/time.py +++ b/src/utils/helpers/time.py @@ -1,6 +1,8 @@ import datetime + from dateutil import tz + def get_current_time(include_ms: bool = True, as_str: bool = True): time = datetime.datetime.now() @@ -9,12 +11,14 @@ def get_current_time(include_ms: bool = True, as_str: bool = True): time = time.astimezone(tz.tzlocal()) if as_str: time = time.isoformat() - + return time + def timestamp_to_str(timestamp: int, include_ms: bool = True): time = datetime.datetime.fromtimestamp(timestamp) - if include_ms: time = time.replace(microsecond=0) + if include_ms: + time = time.replace(microsecond=0) time = time.astimezone(tz.tzlocal()).isoformat() - - return time \ No newline at end of file + + return time diff --git a/src/utils/jaison.py b/src/utils/jaison.py index 00c85d0..e7d8376 100644 --- a/src/utils/jaison.py +++ b/src/utils/jaison.py @@ -1,375 +1,426 @@ -import logging import asyncio -import uuid import base64 import datetime -from typing import Dict, Coroutine, List, Any, Tuple +import logging +from collections.abc import Coroutine from enum import Enum +from typing import Any -from utils.helpers.singleton import Singleton +from utils.config import UnknownField, UnknownFile, config from utils.helpers.iterable import chunk_buffer from utils.helpers.observer import ObserverServer - -from utils.config import Config, UnknownField, UnknownFile -from utils.prompter import Prompter -from utils.prompter.message import ( - RawMessage, - RequestMessage, - ChatMessage, - MCPMessage, - CustomMessage -) -from utils.processes import ProcessManager +from utils.mcp import MCPManager from utils.operations import ( - OperationManager, - OpRoles, - Operation, - UnknownOpType, - UnknownOpRole, - UnknownOpID, + CloseInactiveError, DuplicateFilter, + Operation, + OperationManager, OperationUnloaded, + OpRoles, StartActiveError, - CloseInactiveError, - UsedInactiveError + UnknownOpID, + UnknownOpRole, + UnknownOpType, + UsedInactiveError, ) -from utils.mcp import MCPManager +from utils.prompter import Prompter +from utils.prompter.message import ( + ChatMessage, + CustomMessage, + MCPMessage, + RawMessage, + RequestMessage, +) + class NonexistantJobException(Exception): pass + class UnknownJobType(Exception): pass + class JobType(Enum): - RESPONSE = 'response' - CONTEXT_CLEAR = 'context_clear' + RESPONSE = "response" + CONTEXT_CLEAR = "context_clear" CONTEXT_CONFIGURE = "context_configure" - CONTEXT_REQUEST_ADD = 'context_request_add' - CONTEXT_CONVERSATION_ADD_TEXT = 'context_conversation_add_text' - CONTEXT_CONVERSATION_ADD_AUDIO = 'context_conversation_add_audio' - CONTEXT_CUSTOM_REGISTER = 'context_custom_register' - CONTEXT_CUSTOM_REMOVE = 'context_custom_remove' - CONTEXT_CUSTOM_ADD = 'context_custom_add' - OPERATION_LOAD = 'operation_load' + CONTEXT_REQUEST_ADD = "context_request_add" + CONTEXT_CONVERSATION_ADD_TEXT = "context_conversation_add_text" + CONTEXT_CONVERSATION_ADD_AUDIO = "context_conversation_add_audio" + CONTEXT_CUSTOM_REGISTER = "context_custom_register" + CONTEXT_CUSTOM_REMOVE = "context_custom_remove" + CONTEXT_CUSTOM_ADD = "context_custom_add" + OPERATION_LOAD = "operation_load" OPERATION_CONFIG_RELOAD = "operation_reload_from_config" - OPERATION_UNLOAD = 'operation_unload' - OPERATION_CONFIGURE = 'operation_configure' - OPERATION_USE = 'operation_use' - CONFIG_LOAD = 'config_load' - CONFIG_UPDATE = 'config_update' - CONFIG_SAVE = 'config_save' - -class JAIson(metaclass=Singleton): - def __init__(self): # attribute stubs + OPERATION_UNLOAD = "operation_unload" + OPERATION_CONFIGURE = "operation_configure" + OPERATION_USE = "operation_use" + CONFIG_LOAD = "config_load" + CONFIG_UPDATE = "config_update" + CONFIG_SAVE = "config_save" + + +class JAIson: + def __init__(self): # attribute stubs self.job_loop: asyncio.Task = None self.job_queue: asyncio.Queue = None - self.job_map: Dict[str, Tuple[JobType, Coroutine]] = None + self.job_map: dict[str, tuple[JobType, Coroutine]] = None self.job_current_id: str = None self.job_current: asyncio.Task = None self.job_skips: dict = None - + # Any asyncio.Tasks in this list will be cancelled before the next job runs - self.tasks_to_clean: List = list() - + self.tasks_to_clean: list = [] + self.event_server: ObserverServer = None - + self.prompter: Prompter = None - self.process_manager: ProcessManager = None self.op_manager: OperationManager = None self.mcp_manager: MCPManager = None - + async def start(self): logging.info("Starting JAIson application layer.") self.job_queue = asyncio.Queue() - self.job_map = dict() - self.job_skips = dict() + self.job_map = {} + self.job_skips = {} self.job_loop = asyncio.create_task(self._process_job_loop()) - + self.event_server = ObserverServer() - + self.prompter = Prompter() - await self.prompter.configure(Config().prompter) - - self.process_manager = ProcessManager() - self.op_manager = OperationManager() - self.mcp_manager = MCPManager() + await self.prompter.configure(config.prompter) + + self.op_manager = OperationManager(self.prompter) + + self.mcp_manager = MCPManager(self.op_manager) await self.mcp_manager.start() - self.prompter.add_mcp_usage_prompt(self.mcp_manager.get_tooling_prompt(), self.mcp_manager.get_response_prompt()) - await self.op_manager.load_operations_from_config() - await self.process_manager.reload() + self.prompter.add_mcp_usage_prompt( + self.mcp_manager.get_tooling_prompt(), self.mcp_manager.get_response_prompt() + ) + await self._reload_operations_from_config() logging.info("JAIson application layer has started.") - + async def stop(self): logging.info("Shutting down JAIson application layer") await self.op_manager.close_operation_all() await self.mcp_manager.close() - await self.process_manager.unload() logging.info("JAIson application layer has been shut down") - + ## Job Queueing ######################### - + # Add async task to Queue to be ran in the order it was requested - async def create_job(self, job_type: Enum, **kwargs): - new_job_id = str(uuid.uuid4()) - + async def create_job(self, job_type: Enum, job_id: str, **kwargs): job_type_enum = JobType(job_type) - + coro = None - if job_type_enum == JobType.RESPONSE: coro = self.response_pipeline(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_REQUEST_ADD: coro = self.append_request_context(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CONVERSATION_ADD_TEXT: coro = self.append_conversation_context_text(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CONVERSATION_ADD_AUDIO: coro = self.append_conversation_context_audio(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CLEAR: coro = self.clear_context(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CONFIGURE: coro = self.configure_context(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CUSTOM_REGISTER: coro = self.register_custom_context(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CUSTOM_REMOVE: coro = self.remove_custom_context(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONTEXT_CUSTOM_ADD: coro = self.add_custom_context(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.OPERATION_LOAD: coro = self.load_operations(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.OPERATION_CONFIG_RELOAD: coro = self.load_operations_from_config(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.OPERATION_UNLOAD: coro = self.unload_operations(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.OPERATION_CONFIGURE: coro = self.configure_operations(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.OPERATION_USE: coro = self.use_operation(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONFIG_LOAD: coro = self.load_config(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONFIG_UPDATE: coro = self.update_config(new_job_id, job_type_enum, **kwargs) - elif job_type_enum == JobType.CONFIG_SAVE: coro = self.save_config(new_job_id, job_type_enum, **kwargs) - self.job_map[new_job_id] = (job_type_enum, coro) - - await self.job_queue.put(new_job_id) - - logging.info("Queued new {} job {}".format(job_type_enum.value, new_job_id)) - return new_job_id - + if job_type_enum == JobType.RESPONSE: + coro = self.response_pipeline(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_REQUEST_ADD: + coro = self.append_request_context(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CONVERSATION_ADD_TEXT: + coro = self.append_conversation_context_text(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CONVERSATION_ADD_AUDIO: + coro = self.append_conversation_context_audio(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CLEAR: + coro = self.clear_context(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CONFIGURE: + coro = self.configure_context(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CUSTOM_REGISTER: + coro = self.register_custom_context(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CUSTOM_REMOVE: + coro = self.remove_custom_context(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONTEXT_CUSTOM_ADD: + coro = self.add_custom_context(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.OPERATION_LOAD: + coro = self.load_operations(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.OPERATION_CONFIG_RELOAD: + coro = self.load_operations_from_config(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.OPERATION_UNLOAD: + coro = self.unload_operations(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.OPERATION_CONFIGURE: + coro = self.configure_operations(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.OPERATION_USE: + coro = self.use_operation(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONFIG_LOAD: + coro = self.load_config(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONFIG_UPDATE: + coro = self.update_config(job_id, job_type_enum, **kwargs) + elif job_type_enum == JobType.CONFIG_SAVE: + coro = self.save_config(job_id, job_type_enum, **kwargs) + self.job_map[job_id] = (job_type_enum, coro) + + await self.job_queue.put(job_id) + + logging.info(f"Queued new {job_type_enum.value} job {job_id}") + return job_id + async def cancel_job(self, job_id: str, reason: str = None): - if job_id not in self.job_map: raise NonexistantJobException(f"Job {job_id} does not exist or already finished") - + if job_id not in self.job_map: + raise NonexistantJobException(f"Job {job_id} does not exist or already finished") + cancel_message = f"Setting job {job_id} to cancel" - if reason: cancel_message += f" because {reason}" + if reason: + cancel_message += f" because {reason}" logging.info(cancel_message) if job_id == self.job_current_id: # If job is already running self._clear_current_job(reason=cancel_message) - else: + else: # If job is still in Queue - # Simply flag to skip. Unzipping queue can potentially process a job out of order + # Simply flag to skip. Unzipping queue can potentially process a job out of order self.job_skips[job_id](cancel_message) - + def _clear_current_job(self, reason: str = None): self.job_map.pop(self.job_current_id, None) self.job_skips.pop(self.job_current_id, None) self.job_current_id = None - + for task in self.tasks_to_clean: task.cancel(reason) self.tasks_to_clean.clear() - + if self.job_current is not None: self.job_current.cancel(reason) self.job_current = None - + # Side loop responsible for processing the next job in the Queue async def _process_job_loop(self): while True: try: - await self.process_manager.reload() - await self.process_manager.unload() - self.job_current_id = await self.job_queue.get() job_type, coro = self.job_map[self.job_current_id] - + if self.job_current_id in self.job_skips: # Skip cancelled jobs reason = self.job_skips[self.job_current_id] - await self._handle_broadcast_error(self.job_current_id, job_type, asyncio.CancelledError(reason)) + await self._handle_broadcast_error( + self.job_current_id, job_type, asyncio.CancelledError(reason) + ) self._clear_current_job(reason=reason) del coro else: # Run and wait for completion self.job_current = asyncio.create_task(coro) await asyncio.wait([self.job_current]) - + # Handle finishing with error err = self.job_current.exception() if self.job_current else None if err is not None: logging.warning(f"Job was cancelled due to an error: {err}", exc_info=err) await self._handle_broadcast_error(self.job_current_id, job_type, err) - + # Cleanup self._clear_current_job() except Exception as err: logging.error("Encountered error in main job processing loop", exc_info=True) await asyncio.sleep(1) - + ## Regular Request Handlers ################### - + def get_loaded_operations(self): op_d = self.op_manager.get_operation_all() for key in op_d: if isinstance(op_d[key], Operation): op_d[key] = op_d[key].op_id elif isinstance(op_d[key], list): - op_d[key] = list(map(lambda x: x.op_id, op_d[key])) + op_d[key] = [x.op_id for x in op_d[key]] else: op_d[key] = "unknown" - + return op_d - + def get_current_config(self): - return Config().get_config_dict() - + return config.get_config_dict() + ## Async Job Handlers ######################### - - ''' + + """ Generate responses from the current contexts. This does not take an input. Context for what to repond to must be added prior to running this. - ''' - async def response_pipeline( - self, - job_id: str, - job_type: JobType, - include_audio: bool = True - ): - + """ + + async def response_pipeline(self, job_id: str, job_type: JobType, include_audio: bool = True): + # Adjust flags based on loaded ops - if not self.op_manager.get_operation(OpRoles.TTS): include_audio = False - + if not self.op_manager.get_operation(OpRoles.TTS): + include_audio = False + # Broadcast start conditions await self._handle_broadcast_start(job_id, job_type, {"include_audio": include_audio}) - + # Handle MCP stuff if self.op_manager.get_operation(OpRoles.MCP): - self.prompter.add_mcp_usage_prompt(self.mcp_manager.get_tooling_prompt(), self.mcp_manager.get_response_prompt()) - mcp_sys_prompt, mcp_user_prompt = self.prompter.generate_mcp_system_context(), self.prompter.generate_mcp_user_context() + self.prompter.add_mcp_usage_prompt( + self.mcp_manager.get_tooling_prompt(), self.mcp_manager.get_response_prompt() + ) + mcp_sys_prompt, mcp_user_prompt = ( + self.prompter.generate_mcp_system_context(), + self.prompter.generate_mcp_user_context(), + ) tooling_response = "" - async for chunk in self.op_manager.use_operation(OpRoles.MCP, {"instruction_prompt": mcp_sys_prompt, "messages": [RawMessage(mcp_user_prompt)]}): - tooling_response += chunk['content'] + async for chunk in self.op_manager.use_operation( + OpRoles.MCP, + {"instruction_prompt": mcp_sys_prompt, "messages": [RawMessage(mcp_user_prompt)]}, + ): + tooling_response += chunk["content"] ## Perform MCP tool calls tool_call_results = await self.mcp_manager.use(tooling_response) - + ## Add results and usage prompt to prompter self.prompter.add_mcp_results(tool_call_results) # Get prompts instruction_prompt, history = self.prompter.get_sys_prompt(), self.prompter.get_history() - + # Appy t2t t2t_result = "" - async for chunk_out in self.op_manager.use_operation(OpRoles.T2T, {"instruction_prompt": instruction_prompt, "messages": history}): + async for chunk_out in self.op_manager.use_operation( + OpRoles.T2T, {"instruction_prompt": instruction_prompt, "messages": history} + ): t2t_result += chunk_out["content"] - + # Broadcast raw results - await self._handle_broadcast_event(job_id, job_type, {"instruction_prompt": instruction_prompt}) - await self._handle_broadcast_event(job_id, job_type, {"history": [msg.to_dict() for msg in history]}) + await self._handle_broadcast_event( + job_id, job_type, {"instruction_prompt": instruction_prompt} + ) + await self._handle_broadcast_event( + job_id, job_type, {"history": [msg.to_dict() for msg in history]} + ) await self._handle_broadcast_event(job_id, job_type, {"raw_content": t2t_result}) # Apply text filters - async for text_chunk_out in self.op_manager.use_operation(OpRoles.FILTER_TEXT, {"content": t2t_result}): - self.prompter.add_chat(self.prompter.character_name, text_chunk_out['content']) + async for text_chunk_out in self.op_manager.use_operation( + OpRoles.FILTER_TEXT, {"content": t2t_result} + ): + self.prompter.add_chat(self.prompter.character_name, text_chunk_out["content"]) await self._handle_broadcast_event(job_id, job_type, text_chunk_out) if include_audio: # Apply tts - async for audio_chunk_out in self.op_manager.use_operation(OpRoles.TTS, text_chunk_out): + async for audio_chunk_out in self.op_manager.use_operation( + OpRoles.TTS, text_chunk_out + ): # Apply tts filters - async for final_audio_chunk_out in self.op_manager.use_operation(OpRoles.FILTER_AUDIO, audio_chunk_out): + async for final_audio_chunk_out in self.op_manager.use_operation( + OpRoles.FILTER_AUDIO, audio_chunk_out + ): # Broadcast results (only the audio data for now) - for ws_chunk in chunk_buffer(base64.b64encode(final_audio_chunk_out['audio_bytes']).decode('utf-8')): - await self._handle_broadcast_event(job_id, job_type, { - "audio_bytes": ws_chunk, - "sr": final_audio_chunk_out['sr'], - "sw": final_audio_chunk_out['sw'], - "ch": final_audio_chunk_out['ch'] - }) - + for ws_chunk in chunk_buffer( + base64.b64encode(final_audio_chunk_out["audio_bytes"]).decode("utf-8") + ): + await self._handle_broadcast_event( + job_id, + job_type, + { + "audio_bytes": ws_chunk, + "sr": final_audio_chunk_out["sr"], + "sw": final_audio_chunk_out["sw"], + "ch": final_audio_chunk_out["ch"], + }, + ) + # Broadcast completion await self._handle_broadcast_success(job_id, job_type) - # Context modification - async def clear_context( - self, - job_id: str, - job_type: JobType - ): + async def clear_context(self, job_id: str, job_type: JobType): await self._handle_broadcast_start(job_id, job_type, {}) self.prompter.clear_history() await self._handle_broadcast_success(job_id, job_type) - + async def configure_context( self, job_id: str, job_type: JobType, - name_translations: Dict[str, str] = None, + name_translations: dict[str, str] = None, character_name: str = None, history_length: int = None, instruction_prompt_filename: str = None, character_prompt_filename: str = None, - scene_prompt_filename: str = None + scene_prompt_filename: str = None, ): - await self._handle_broadcast_start(job_id, job_type, { - "name_translations": name_translations, - "character_name": character_name, - "history_length": history_length, - "instruction_prompt_filename": instruction_prompt_filename, - "character_prompt_filename": character_prompt_filename, - "scene_prompt_filename": scene_prompt_filename - }) - payload = dict() - if name_translations: payload |= {"name_translations": name_translations} - if character_name: payload |= {"character_name": character_name} - if history_length: payload |= {"history_length": history_length} - if history_length: payload |= {"history_length": history_length} - if instruction_prompt_filename: payload |= {"instruction_prompt_filename": instruction_prompt_filename} - if character_prompt_filename: payload |= {"character_prompt_filename": character_prompt_filename} - if scene_prompt_filename: payload |= {"scene_prompt_filename": scene_prompt_filename} - + await self._handle_broadcast_start( + job_id, + job_type, + { + "name_translations": name_translations, + "character_name": character_name, + "history_length": history_length, + "instruction_prompt_filename": instruction_prompt_filename, + "character_prompt_filename": character_prompt_filename, + "scene_prompt_filename": scene_prompt_filename, + }, + ) + payload = {} + if name_translations: + payload |= {"name_translations": name_translations} + if character_name: + payload |= {"character_name": character_name} + if history_length: + payload |= {"history_length": history_length} + if history_length: + payload |= {"history_length": history_length} + if instruction_prompt_filename: + payload |= {"instruction_prompt_filename": instruction_prompt_filename} + if character_prompt_filename: + payload |= {"character_prompt_filename": character_prompt_filename} + if scene_prompt_filename: + payload |= {"scene_prompt_filename": scene_prompt_filename} + self.prompter.configure(payload) - + await self._handle_broadcast_success(job_id, job_type) - async def append_request_context( - self, - job_id: str, - job_type: JobType, - content: str = None - ): + async def append_request_context(self, job_id: str, job_type: JobType, content: str = None): await self._handle_broadcast_start(job_id, job_type, {"content": content}) self.prompter.add_request(content) last_line_o = self.prompter.history[-1] - await self._handle_broadcast_event(job_id, job_type, { - "timestamp": last_line_o.time.timestamp(), - "content": last_line_o.message, - "line": last_line_o.to_line() - }) + await self._handle_broadcast_event( + job_id, + job_type, + { + "timestamp": last_line_o.time.timestamp(), + "content": last_line_o.message, + "line": last_line_o.to_line(), + }, + ) await self._handle_broadcast_success(job_id, job_type) - + async def append_conversation_context_text( - self, - job_id: str, - job_type: JobType, - user: str = None, - timestamp: int = None, - content: str = None + self, + job_id: str, + job_type: JobType, + user: str = None, + timestamp: int = None, + content: str = None, ): - await self._handle_broadcast_start(job_id, job_type, {"user": user, "timestamp": timestamp, "content": content}) + await self._handle_broadcast_start( + job_id, job_type, {"user": user, "timestamp": timestamp, "content": content} + ) self.prompter.add_chat( user, content, time=( - datetime.datetime.fromtimestamp(timestamp) \ - if not isinstance(timestamp, datetime.datetime) else timestamp - ) + datetime.datetime.fromtimestamp(timestamp) + if not isinstance(timestamp, datetime.datetime) + else timestamp + ), ) last_line_o = self.prompter.history[-1] - await self._handle_broadcast_event(job_id, job_type, { - "user": last_line_o.user, - "timestamp": last_line_o.time.timestamp(), - "content": last_line_o.message, - "line": last_line_o.to_line() - }) + await self._handle_broadcast_event( + job_id, + job_type, + { + "user": last_line_o.user, + "timestamp": last_line_o.time.timestamp(), + "content": last_line_o.message, + "line": last_line_o.to_line(), + }, + ) await self._handle_broadcast_success(job_id, job_type) - + async def append_conversation_context_audio( self, job_id: str, @@ -379,155 +430,229 @@ async def append_conversation_context_audio( audio_bytes: str = None, sr: int = None, sw: int = None, - ch: int = None + ch: int = None, ): - await self._handle_broadcast_start(job_id, job_type, {"user": user, "timestamp": timestamp, "sr": sr, "sw": sw, "ch": ch, "audio_bytes": (audio_bytes is not None)}) # Don't send full audio bytes over websocket, just flag as gotten + await self._handle_broadcast_start( + job_id, + job_type, + { + "user": user, + "timestamp": timestamp, + "sr": sr, + "sw": sw, + "ch": ch, + "audio_bytes": (audio_bytes is not None), + }, + ) # Don't send full audio bytes over websocket, just flag as gotten audio_bytes: bytes = base64.b64decode(audio_bytes) - prompt = self.prompter.get_history_text() or "You're name is {}".format(self.prompter.character_name) + prompt = ( + self.prompter.get_history_text() or f"You're name is {self.prompter.character_name}" + ) content = "" - async for out_d in self.op_manager.use_operation(OpRoles.STT, {"prompt": prompt, "audio_bytes": audio_bytes, "sr": sr, "sw": sw, "ch": ch}): - content += out_d['transcription'] - + async for out_d in self.op_manager.use_operation( + OpRoles.STT, + {"prompt": prompt, "audio_bytes": audio_bytes, "sr": sr, "sw": sw, "ch": ch}, + ): + content += out_d["transcription"] + self.prompter.add_chat( user, content, time=( - datetime.datetime.fromtimestamp(timestamp) \ - if isinstance(timestamp, int) else timestamp - ) + datetime.datetime.fromtimestamp(timestamp) + if isinstance(timestamp, int) + else timestamp + ), ) last_line_o = self.prompter.history[-1] - await self._handle_broadcast_event(job_id, job_type, { - "user": last_line_o.user, - "timestamp": last_line_o.time.timestamp(), - "content": last_line_o.message, - "line": last_line_o.to_line() - }) + await self._handle_broadcast_event( + job_id, + job_type, + { + "user": last_line_o.user, + "timestamp": last_line_o.time.timestamp(), + "content": last_line_o.message, + "line": last_line_o.to_line(), + }, + ) await self._handle_broadcast_success(job_id, job_type) - + async def register_custom_context( self, job_id: str, job_type: JobType, context_id: str = None, context_name: str = None, - context_description: str = None + context_description: str = None, ): - await self._handle_broadcast_start(job_id, job_type, {"context_id": context_id, "context_name": context_name, "context_description": context_description}) - self.prompter.register_custom_context(context_id, context_name, context_description=context_description) + await self._handle_broadcast_start( + job_id, + job_type, + { + "context_id": context_id, + "context_name": context_name, + "context_description": context_description, + }, + ) + self.prompter.register_custom_context( + context_id, context_name, context_description=context_description + ) await self._handle_broadcast_success(job_id, job_type) - - async def remove_custom_context(self, - job_id: str, - job_type: JobType, - context_id: str = None - ): + + async def remove_custom_context(self, job_id: str, job_type: JobType, context_id: str = None): await self._handle_broadcast_start(job_id, job_type, {"context_id": context_id}) self.prompter.remove_custom_context(context_id) await self._handle_broadcast_success(job_id, job_type) - + async def add_custom_context( self, job_id: str, job_type: JobType, context_id: str = None, context_contents: str = None, - timestamp: int = None + timestamp: int = None, ): - await self._handle_broadcast_start(job_id, job_type, {"context_id": context_id, "context_contents": context_contents, "timestamp": timestamp}) - if timestamp is not None: timestamp = datetime.datetime.fromtimestamp(timestamp) + await self._handle_broadcast_start( + job_id, + job_type, + { + "context_id": context_id, + "context_contents": context_contents, + "timestamp": timestamp, + }, + ) + if timestamp is not None: + timestamp = datetime.datetime.fromtimestamp(timestamp) self.prompter.add_custom_context(context_id, context_contents) last_line_o = self.prompter.history[-1] - await self._handle_broadcast_event(job_id, job_type, { - "timestamp": last_line_o.time.timestamp(), - "content": last_line_o.message, - "line": last_line_o.to_line() - }) + await self._handle_broadcast_event( + job_id, + job_type, + { + "timestamp": last_line_o.time.timestamp(), + "content": last_line_o.message, + "line": last_line_o.to_line(), + }, + ) await self._handle_broadcast_success(job_id, job_type) - - # Operation management + + # Operation management async def load_operations( - self, - job_id: str, - job_type: JobType, - ops: List[Dict[str, str]] = [] + self, job_id: str, job_type: JobType, ops: list[dict[str, str]] = None ): + if ops is None: + ops = [] await self._handle_broadcast_start(job_id, job_type, {"ops": ops}) for op_d in ops: - await self.op_manager.load_operation(OpRoles(op_d.get('role', None)), op_d.get('id', None), op_d.get('config', dict())) - await self._handle_broadcast_event(job_id, job_type, { - "role": op_d.get('role', None), - "id": op_d.get('id', None), - "loose_key": op_d.get("loose_key", None) - }) + await self.op_manager.load_operation( + OpRoles(op_d.get("role", None)), op_d.get("id", None), op_d.get("config", {}) + ) + await self._handle_broadcast_event( + job_id, + job_type, + { + "role": op_d.get("role", None), + "id": op_d.get("id", None), + "loose_key": op_d.get("loose_key", None), + }, + ) await self._handle_broadcast_success(job_id, job_type) - + + async def _reload_operations_from_config(self) -> None: + """Load, start, and save all operations listed in the active config.""" + await self.op_manager.close_operation_all() + for op_details in config.operations: + await self.op_manager.load_operation( + OpRoles(op_details["role"]), + op_details["id"], + op_details, + ) + async def load_operations_from_config( self, job_id: str, job_type: JobType, ): await self._handle_broadcast_start(job_id, job_type, {}) - await self.op_manager.load_operations_from_config() + await self._reload_operations_from_config() await self._handle_broadcast_success(job_id, job_type) - + async def unload_operations( - self, - job_id: str, - job_type: JobType, - ops: List[Dict[str, str]] = [] + self, job_id: str, job_type: JobType, ops: list[dict[str, str]] = None ): + if ops is None: + ops = [] await self._handle_broadcast_start(job_id, job_type, {"ops": ops}) for op_d in ops: - await self.op_manager.close_operation(OpRoles(op_d.get('role', None)), op_d.get('id', None)) - await self._handle_broadcast_event(job_id, job_type, { - "role": op_d.get('role', None), - "id": op_d.get('id', None) - }) + await self.op_manager.close_operation( + OpRoles(op_d.get("role", None)), op_d.get("id", None) + ) + await self._handle_broadcast_event( + job_id, job_type, {"role": op_d.get("role", None), "id": op_d.get("id", None)} + ) await self._handle_broadcast_success(job_id, job_type) - - async def configure_operations( # TODO document and add endpoint - self, - job_id: str, - job_type: JobType, - ops: List[Dict[str, str]] = [] + + async def configure_operations( # TODO document and add endpoint + self, job_id: str, job_type: JobType, ops: list[dict[str, str]] = None ): + if ops is None: + ops = [] await self._handle_broadcast_start(job_id, job_type, {"ops": ops}) for op_d in ops: - await self.op_manager.configure(OpRoles(op_d.get('role', None)), op_d, op_id=op_d.get('id', None)) + await self.op_manager.configure( + OpRoles(op_d.get("role", None)), op_d, op_id=op_d.get("id", None) + ) await self._handle_broadcast_event(job_id, job_type, op_d) await self._handle_broadcast_success(job_id, job_type) - + async def use_operation( self, job_id: str, job_type: JobType, role: str = None, id: str = None, - payload: Dict[str, Any] = None + payload: dict[str, Any] = None, ): await self._handle_broadcast_start(job_id, job_type, {"role": role, "id": id}) - - if 'audio_bytes' in payload: - payload['audio_bytes'] = base64.b64decode(payload['audio_bytes']) - - if 'messages' in payload: - msg_list = list() - for msg in payload['messages']: - assert 'type' in msg - if msg['type'] == "raw": - msg_list.append(RawMessage(msg['message'])) - elif msg['type'] == "request": - msg_list.append(RequestMessage(msg['message'], datetime.datetime.fromtimestamp(msg['time']))) - elif msg['type'] == "chat": - msg_list.append(ChatMessage(msg['user'], msg['message'], datetime.datetime.fromtimestamp(msg['time']))) - elif msg['type'] == "tool": - msg_list.append(MCPMessage(msg['tool'], msg['message'], datetime.datetime.fromtimestamp(msg['time']))) - elif msg['type'] == "custom": - msg_list.append(CustomMessage(msg['id'], msg['message'], datetime.datetime.fromtimestamp(msg['time']))) + + if "audio_bytes" in payload: + payload["audio_bytes"] = base64.b64decode(payload["audio_bytes"]) + + if "messages" in payload: + msg_list = [] + for msg in payload["messages"]: + assert "type" in msg + if msg["type"] == "raw": + msg_list.append(RawMessage(msg["message"])) + elif msg["type"] == "request": + msg_list.append( + RequestMessage(msg["message"], datetime.datetime.fromtimestamp(msg["time"])) + ) + elif msg["type"] == "chat": + msg_list.append( + ChatMessage( + msg["user"], + msg["message"], + datetime.datetime.fromtimestamp(msg["time"]), + ) + ) + elif msg["type"] == "tool": + msg_list.append( + MCPMessage( + msg["tool"], + msg["message"], + datetime.datetime.fromtimestamp(msg["time"]), + ) + ) + elif msg["type"] == "custom": + msg_list.append( + CustomMessage( + msg["id"], msg["message"], datetime.datetime.fromtimestamp(msg["time"]) + ) + ) else: raise Exception("Invalid message type") - payload['messages'] = msg_list + payload["messages"] = msg_list try: async for chunk_out in self.op_manager.use_operation(OpRoles(role), payload, op_id=id): @@ -536,80 +661,84 @@ async def use_operation( op = self.op_manager.loose_load_operation(OpRoles(role), id) await op.start() async for chunk_out in op(payload): - if "audio_bytes" in chunk_out: chunk_out["audio_bytes"] = base64.b64encode(chunk_out['audio_bytes']).decode('utf-8') + if "audio_bytes" in chunk_out: + chunk_out["audio_bytes"] = base64.b64encode(chunk_out["audio_bytes"]).decode( + "utf-8" + ) await self._handle_broadcast_event(job_id, job_type, chunk_out) await op.close() - + await self._handle_broadcast_success(job_id, job_type) - + # Configuration async def load_config(self, job_id: str, job_type: JobType, config_name: str): await self._handle_broadcast_start(job_id, job_type, {"config_name": config_name}) - Config().load_from_name(config_name) + config.load_from_name(config_name) await self._handle_broadcast_success(job_id, job_type) - + async def update_config(self, job_id: str, job_type: JobType, config_d: str): await self._handle_broadcast_start(job_id, job_type, {"config_d": config_d}) - Config().load_from_dict(config_d) + config.load_from_dict(config_d) await self._handle_broadcast_success(job_id, job_type) - + async def save_config(self, job_id: str, job_type: JobType, config_name: str): await self._handle_broadcast_start(job_id, job_type, {"config_name": config_name}) - Config().save(config_name) + config.save(config_name) await self._handle_broadcast_success(job_id, job_type) - + ## General helpers ############################### async def _handle_broadcast_start(self, job_id: str, job_type: JobType, payload: dict): - to_broadcast = { - "job_id": job_id, - "start": payload - } - logging.debug("Broadcasting start ({}) {} {:.500}".format(job_id, job_type.value, str(to_broadcast))) + to_broadcast = {"job_id": job_id, "start": payload} + logging.debug(f"Broadcasting start ({job_id}) {job_type.value} {str(to_broadcast):.500}") await self.event_server.broadcast_event(job_type.value, to_broadcast) - + async def _handle_broadcast_event(self, job_id: str, job_type: JobType, payload: dict): - to_broadcast = { - "job_id": job_id, - "finished": False, - "result": payload - } - logging.debug("Broadcasting event ({}) {} {:.500}".format(job_id, job_type.value, str(to_broadcast))) + to_broadcast = {"job_id": job_id, "finished": False, "result": payload} + logging.debug(f"Broadcasting event ({job_id}) {job_type.value} {str(to_broadcast):.500}") await self.event_server.broadcast_event(job_type.value, to_broadcast) - + async def _handle_broadcast_success(self, job_id: str, job_type: JobType): - to_broadcast = { - "job_id": job_id, - "finished": True, - "success": True - } - logging.debug("Broadcasting success ({}) {} {}".format(job_id, job_type.value, str(to_broadcast))) + to_broadcast = {"job_id": job_id, "finished": True, "success": True} + logging.debug(f"Broadcasting success ({job_id}) {job_type.value} {str(to_broadcast)}") await self.event_server.broadcast_event(job_type.value, to_broadcast) - + async def _handle_broadcast_error(self, job_id: str, job_type: JobType, err: Exception): # TODO: extend with all errors error_type = "unknown" - if isinstance(err, UnknownOpType): error_type = "operation_unknown_type" - if isinstance(err, UnknownOpRole): error_type = "operation_unknown_role" - elif isinstance(err, UnknownOpID): error_type = "operation_unknown_id" - elif isinstance(err, DuplicateFilter): error_type = "operation_duplicate" - elif isinstance(err, OperationUnloaded): error_type = "operation_unloaded" - elif isinstance(err, StartActiveError): error_type = "operation_active" - elif isinstance(err, CloseInactiveError): error_type = "operation_inactive" - elif isinstance(err, UsedInactiveError): error_type = "operation_inactive" - elif isinstance(err, UnknownField): error_type = "config_unknown_field" - elif isinstance(err, UnknownFile): error_type = "config_unknown_file" - elif isinstance(err, UnknownJobType): error_type = "job_unknown" - elif isinstance(err, asyncio.CancelledError): error_type = "job_cancelled" - + if isinstance(err, UnknownOpType): + error_type = "operation_unknown_type" + if isinstance(err, UnknownOpRole): + error_type = "operation_unknown_role" + elif isinstance(err, UnknownOpID): + error_type = "operation_unknown_id" + elif isinstance(err, DuplicateFilter): + error_type = "operation_duplicate" + elif isinstance(err, OperationUnloaded): + error_type = "operation_unloaded" + elif isinstance(err, StartActiveError): + error_type = "operation_active" + elif isinstance(err, CloseInactiveError): + error_type = "operation_inactive" + elif isinstance(err, UsedInactiveError): + error_type = "operation_inactive" + elif isinstance(err, UnknownField): + error_type = "config_unknown_field" + elif isinstance(err, UnknownFile): + error_type = "config_unknown_file" + elif isinstance(err, UnknownJobType): + error_type = "job_unknown" + elif isinstance(err, asyncio.CancelledError): + error_type = "job_cancelled" + to_broadcast = { "job_id": job_id, "finished": True, "success": False, - "result": { - "type": error_type, - "reason": str(err) - } + "result": {"type": error_type, "reason": str(err)}, } - - logging.debug("Broadcasting error ({}) {} {}".format(job_id, job_type.value, str(to_broadcast))) - await self.event_server.broadcast_event(job_type.value, to_broadcast) \ No newline at end of file + + logging.debug(f"Broadcasting error ({job_id}) {job_type.value} {str(to_broadcast)}") + await self.event_server.broadcast_event(job_type.value, to_broadcast) + + +jaison = JAIson() diff --git a/src/utils/logging.py b/src/utils/logging.py index 3765ad3..4e3d6bc 100644 --- a/src/utils/logging.py +++ b/src/utils/logging.py @@ -1,49 +1,68 @@ import logging import os import sys -from utils.helpers.time import get_current_time +import uuid + +from rich.console import Console +from rich.logging import RichHandler + from utils.args import args +from utils.helpers.time import get_current_time + +START_TIME = get_current_time(include_ms=False, as_str=False) + +_LOG_TIME_FORMAT = "[%Y-%m-%d %H:%M:%S]" + +# Uvicorn installs its own handlers/formatters unless log_config=None; route through root. +_UVICORN_LOGGERS = ("uvicorn", "uvicorn.error", "uvicorn.access", "uvicorn.asgi") + + +def _log_filename() -> str: + now = get_current_time(include_ms=False, as_str=False) + short_id = uuid.uuid4().hex[:8] + # %f is 6-digit microseconds; drop the last 3 digits for padded milliseconds. + return f"{now.strftime('%Y-%m-%d_%H-%M-%S-%f')}_{short_id}.log" + + +def _configure_uvicorn_loggers() -> None: + level = getattr(logging, args.log_level) + for name in _UVICORN_LOGGERS: + uvicorn_logger = logging.getLogger(name) + uvicorn_logger.handlers.clear() + uvicorn_logger.propagate = True + uvicorn_logger.setLevel(level) + + +def _create_rich_handler(console: Console, *, enable_link_path: bool = True) -> RichHandler: + return RichHandler( + console=console, + rich_tracebacks=True, + show_time=True, + show_level=True, + show_path=True, + enable_link_path=enable_link_path, + log_time_format=_LOG_TIME_FORMAT, + markup=False, + ) + -START_TIME = get_current_time(include_ms=False,as_str=False) - -# Setup formatters and handlers -class CustomFormatter(logging.Formatter): - # Using console color codes to style text - reset = "\x1b[0m" - base_time = "[%(asctime)s]" + reset - base_level = "[%(levelname)-5.5s]" + reset - base_func = "[%(filename)s::%(lineno)d %(funcName)s]:" + reset - base_msg = "%(message)s" + reset - - template_line = "\x1b[1m\x1b[1;34m" + base_time + " {}" + base_level + " \x1b[1m\x1b[1;33m" + base_func + " " + base_msg - - FORMATS = { - logging.DEBUG: template_line.format("\x1b[1m\x1b[1;30m\x1b[47m"), - logging.INFO: template_line.format("\x1b[1m\x1b[1;30m\x1b[42m"), - logging.WARNING: template_line.format("\x1b[1m\x1b[1;30m\x1b[43m"), - logging.ERROR: template_line.format("\x1b[1m\x1b[1;30m\x1b[41m"), - logging.CRITICAL: template_line.format("\x1b[1m\x1b[31m\x1b[45m") - } - - def format(self, record): - log_fmt = self.FORMATS.get(record.levelno) - formatter = logging.Formatter(log_fmt) - return formatter.format(record) - def setup_logger(): global START_TIME - + logger = logging.getLogger() logger.setLevel(getattr(logging, args.log_level)) - file_formatter = logging.Formatter("[%(asctime)s] [%(levelname)-5.5s] [%(filename)s::%(lineno)d %(funcName)s]: %(message)s") - file_handler = logging.FileHandler( - os.path.join(args.log_dir, "{}.log".format(get_current_time(include_ms=False,as_str=False).strftime("%Y-%m-%d")))) - file_handler.setFormatter(file_formatter) - + log_path = os.path.join(args.log_dir, _log_filename()) + log_file = open(log_path, "a", encoding="utf-8") + file_handler = _create_rich_handler( + Console(file=log_file, width=200, no_color=True, highlight=False), + enable_link_path=False, + ) + file_handler._log_file = log_file # keep handle open for process lifetime logger.addHandler(file_handler) if not args.silent: - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setFormatter(CustomFormatter()) - logger.addHandler(console_handler) \ No newline at end of file + console_handler = _create_rich_handler(Console(file=sys.stdout)) + logger.addHandler(console_handler) + + _configure_uvicorn_loggers() diff --git a/src/utils/mcp/__init__.py b/src/utils/mcp/__init__.py index c41a9cf..c0be947 100644 --- a/src/utils/mcp/__init__.py +++ b/src/utils/mcp/__init__.py @@ -1 +1 @@ -from .manager import MCPManager \ No newline at end of file +from .manager import MCPManager as MCPManager diff --git a/src/utils/mcp/manager.py b/src/utils/mcp/manager.py index b71f573..1f6df31 100644 --- a/src/utils/mcp/manager.py +++ b/src/utils/mcp/manager.py @@ -1,23 +1,23 @@ -import os -import json import datetime +import json +import logging +import os import re import urllib -import logging -from typing import List, Dict + from mcp import ClientSession, StdioServerParameters, types from mcp.client.stdio import stdio_client from mcp.types import ( - TextContent, - ImageContent, + BlobResourceContents, EmbeddedResource, + ImageContent, + TextContent, TextResourceContents, - BlobResourceContents ) +from utils.config import config +from utils.operations import OperationManager, OpRoles from utils.prompter.message import RawMessage -from utils.config import Config -from utils.operations import OperationManager, OpRoles def parse_tool_result(result): if isinstance(result, TextContent): @@ -33,145 +33,145 @@ def parse_tool_result(result): else: raise Exception("Unknown result type") + def details_to_response_prompt(details): - tools = details['tools'] - resources = details['resources'] - templates = details['templates'] - + tools = details["tools"] + resources = details["resources"] + templates = details["templates"] + prompt = "" - + for tool in tools: name = tool.name description = tool.description prompt += f"<{name}> {description}\n\n" - + for resource in resources: name = resource.name description = resource.description prompt += f"<{name}> {description}\n" - + for template in templates: name = template.name description = template.description prompt += f"<{name}> {description}\n\n" - + return prompt - + + def details_to_tool_prompt(details): - tools = details['tools'] - resources = details['resources'] - templates = details['templates'] - + tools = details["tools"] + resources = details["resources"] + templates = details["templates"] + prompt = "" - + for tool in tools: name = tool.name description = tool.description inputSchema = json.dumps(tool.inputSchema) prompt += f"<{name}> {description}\nThis is the input schema for {name}: {inputSchema}\n" - + for resource in resources: name = resource.name description = resource.description prompt += f"<{name}> {description}\n" - + for template in templates: name = template.name description = template.description uri_template = template.uriTemplate prompt += f"<{name}> {description}\nThis is the URI template: {uri_template}\n" - + return prompt + class MCPClient: - '''Managing of a single server instance''' - - def __init__(self, mcp_id: str, params: StdioServerParameters): + """Managing of a single server instance""" + + def __init__( + self, mcp_id: str, params: StdioServerParameters, op_manager: OperationManager + ): self.mcp_id = mcp_id self.params = params + self.op_manager = op_manager self.server_generator = None self.server_read = None self.server_write = None self.session = None - - self.tools = list() - self.resources = list() - self.templates = list() - - self.tool_names = list() - self.resource_names = list() - self.template_names = list() - + + self.tools = [] + self.resources = [] + self.templates = [] + + self.tool_names = [] + self.resource_names = [] + self.template_names = [] + self.tool_prompt = "" self.response_prompt = "" - + async def start(self): self.server_generator = stdio_client(self.params) logging.debug("starting context") self.server_read, self.server_write = await self.server_generator.__aenter__() - logging.debug("{} {}".format(type(self.server_read), type(self.server_write))) + logging.debug(f"{type(self.server_read)} {type(self.server_write)}") logging.debug("starting session") self.session = ClientSession( self.server_read, self.server_write, read_timeout_seconds=datetime.timedelta(seconds=10), - sampling_callback=self.handle_sampling_message + sampling_callback=self.handle_sampling_message, ) - + logging.debug("initializing session") await self.session.__aenter__() await self.session.initialize() - + details = await self.get_details() self.tool_prompt = details_to_tool_prompt(details) self.response_prompt = details_to_response_prompt(details) - + async def close(self): await self.session.__aexit__(None, None, None) await self.server_generator.__aexit__(None, None, None) - + async def get_details(self): try: self.tools = (await self.session.list_tools()).tools - except: + except Exception: pass try: self.resources = (await self.session.list_resources()).resources - except: + except Exception: pass try: self.templates = (await self.session.list_resource_templates()).resourceTemplates - except: + except Exception: pass - + for tool in self.tools: self.tool_names.append(tool.name) for resource in self.resources: self.resource_names.append(resource.name) for template in self.templates: self.template_names.append(template.name) - - return { - "tools": self.tools, - "resources": self.resources, - "templates": self.templates - } - + + return {"tools": self.tools, "resources": self.resources, "templates": self.templates} + async def handle_sampling_message( - self, - ctx, - message: types.CreateMessageRequestParams + self, ctx, message: types.CreateMessageRequestParams ) -> types.CreateMessageResult: try: - metadata = message.metadata or dict() + metadata = message.metadata or {} sample_type = metadata.get("sample_type", "t2t") if sample_type == "t2t": - response_stream = OperationManager().use_operation( + response_stream = self.op_manager.use_operation( OpRoles.MCP, { "instruction_prompt": message.systemPrompt, - "messages": [RawMessage(message.messages[0].content.text)] - } + "messages": [RawMessage(message.messages[0].content.text)], + }, ) response = "" @@ -188,11 +188,11 @@ async def handle_sampling_message( stopReason="endTurn", ) elif sample_type == "embedding": - response_stream = OperationManager().use_operation( + response_stream = self.op_manager.use_operation( OpRoles.EMBEDDING, { "content": message.systemPrompt[:10000], - } + }, ) response = "" @@ -208,10 +208,11 @@ async def handle_sampling_message( model="embedding", stopReason="endTurn", ) - except Exception as err: + except Exception: logging.error("MCP sampler encountered an issue", exc_info=True) return "" - + + class MCPManager: tooling_prompt = """ You are calling tools based on the user input to gather more information to enrich a role-playing response and to perform relevant actions. Only reply with the appropriate tool calls and nothing else. @@ -224,86 +225,85 @@ class MCPManager: Below is a list of descriptions for all available tool:\n """ -# response_prompt = """ -# You are an assistant answer a user's question. -# You will be given the user's question under the header . Answer this using the additional information. Do not hallucinate. -# You are given additional information to the actions and context retrieved prior to answering under the header . -# This additional information will each be on a new line, formated as follows: context_name: context_result -# For example, context by the name of "memories" that gave context "you are an ai" will look like "memories: you are an ai" -# Below is a list of all available contexts and their descriptions: -# """ + # response_prompt = """ + # You are an assistant answer a user's question. + # You will be given the user's question under the header . Answer this using the additional information. Do not hallucinate. + # You are given additional information to the actions and context retrieved prior to answering under the header . + # This additional information will each be on a new line, formated as follows: context_name: context_result + # For example, context by the name of "memories" that gave context "you are an ai" will look like "memories: you are an ai" + # Below is a list of all available contexts and their descriptions: + # """ pattern = re.compile(r"^<[\S]*>") - - def __init__(self): - # servers are loaded at start and at no other point - # self.client_params: List[StdioServerParameters] = list() - self.clients: Dict[str, MCPClient] = dict() - + + def __init__(self, op_manager: OperationManager): + self.op_manager = op_manager + self.clients: dict[str, MCPClient] = {} + async def start(self): - config = Config() for mcp_detail in config.mcp: await self.load_mcp(mcp_detail) - - async def load_mcp(self, mcp_detail: Dict): - # TODO validate the mcp_detail - + + async def load_mcp(self, mcp_detail: dict): + # TODO validate the mcp_detail + params = StdioServerParameters( - command=mcp_detail['command'], # Executable - args=mcp_detail['args'], # Optional command line arguments + command=mcp_detail["command"], # Executable + args=mcp_detail["args"], # Optional command line arguments env=os.environ, # Optional environment variables - cwd=mcp_detail['cwd'] + cwd=mcp_detail["cwd"], ) - client = MCPClient(mcp_detail["id"], params) + client = MCPClient(mcp_detail["id"], params, self.op_manager) await client.start() - self.clients[mcp_detail['id']] = client - + self.clients[mcp_detail["id"]] = client + async def close_mcp(self, mcp_id: str): target = self.clients.get(mcp_id, None) if target: await target.close() del self.clients[mcp_id] - + def get_tooling_prompt(self): prompt = self.tooling_prompt for client_key in self.clients: prompt += self.clients[client_key].tool_prompt - + return prompt - + def get_response_prompt(self): # prompt = self.response_prompt prompt = "" for client_key in self.clients: prompt += self.clients[client_key].response_prompt - + return prompt - + async def use(self, tooling_response: str): tool_calls = tooling_response.split("\n") - - result_list = list() - + + result_list = [] + for tool_call in tool_calls: result = None tool_name = "" try: match = self.pattern.search(tool_call) - if match is None: continue - name_token = tool_call[:match.span()[1]] + if match is None: + continue + name_token = tool_call[: match.span()[1]] tool_name = name_token.lstrip("<").rstrip(">") - if name_token == "no_op": continue - tool_call = tool_call[match.span()[1]:].rstrip(" ") - input_json = json.loads(tool_call) if len(tool_call) else dict() - - tool = { - "name": name_token, - "input": input_json - } - + if name_token == "no_op": + continue + tool_call = tool_call[match.span()[1] :].rstrip(" ") + input_json = json.loads(tool_call) if len(tool_call) else {} + + tool = {"name": name_token, "input": input_json} + for client in self.clients: if tool_name in self.clients[client].tool_names: - result = await self.clients[client].session.call_tool(tool_name, arguments=tool['input']) + result = await self.clients[client].session.call_tool( + tool_name, arguments=tool["input"] + ) result = parse_tool_result(result.content[0]) break elif tool_name in self.clients[client].resource_names: @@ -321,25 +321,29 @@ async def use(self, tooling_response: str): if templates.name == tool_name: uri_template = templates.uriTemplate break - for key in tool['input']: - if isinstance(tool['input'][key], str): - urllib.parse.quote(tool['input'][key]) - tool['input'][key] = urllib.parse.quote(tool['input'][key]) - logging.debug("Calling resource: {} {} {}".format(tool_name, tool['input'], uri_template)) - logging.debug(uri_template.format(**tool['input'])) + for key in tool["input"]: + if isinstance(tool["input"][key], str): + urllib.parse.quote(tool["input"][key]) + tool["input"][key] = urllib.parse.quote(tool["input"][key]) + logging.debug( + "Calling resource: {} {} {}".format( + tool_name, tool["input"], uri_template + ) + ) + logging.debug(uri_template.format(**tool["input"])) result = await self.clients[client].session.read_resource( - uri_template.format(**tool['input']) + uri_template.format(**tool["input"]) ) result = parse_tool_result(result.contents[0]) break except Exception as err: logging.critical("Error occured during MCP", exc_info=True) - result = "Attempt to use MCP tool failed due to {}".format(str(err)) + result = f"Attempt to use MCP tool failed due to {str(err)}" if result: result_list.append((tool_name, result)) - + return result_list async def close(self): for client_key in self.clients: - await self.clients[client_key].close() \ No newline at end of file + await self.clients[client_key].close() diff --git a/src/utils/operations/__init__.py b/src/utils/operations/__init__.py index 656cd9b..6b45289 100644 --- a/src/utils/operations/__init__.py +++ b/src/utils/operations/__init__.py @@ -1,3 +1,11 @@ -from .manager import OpRoles, OperationManager -from .base import Operation, StartActiveError, CloseInactiveError, UsedInactiveError -from .error import UnknownOpType, UnknownOpRole, UnknownOpID, DuplicateFilter, OperationUnloaded \ No newline at end of file +from .base import CloseInactiveError as CloseInactiveError +from .base import Operation as Operation +from .base import StartActiveError as StartActiveError +from .base import UsedInactiveError as UsedInactiveError +from .error import DuplicateFilter as DuplicateFilter +from .error import OperationUnloaded as OperationUnloaded +from .error import UnknownOpID as UnknownOpID +from .error import UnknownOpRole as UnknownOpRole +from .error import UnknownOpType as UnknownOpType +from .manager import OperationManager as OperationManager +from .manager import OpRoles as OpRoles diff --git a/src/utils/operations/base/__init__.py b/src/utils/operations/base/__init__.py index 643386c..1a5ec2a 100644 --- a/src/utils/operations/base/__init__.py +++ b/src/utils/operations/base/__init__.py @@ -1,2 +1,4 @@ -from .operation import Operation -from .error import StartActiveError, CloseInactiveError, UsedInactiveError \ No newline at end of file +from .error import CloseInactiveError as CloseInactiveError +from .error import StartActiveError as StartActiveError +from .error import UsedInactiveError as UsedInactiveError +from .operation import Operation as Operation diff --git a/src/utils/operations/base/error.py b/src/utils/operations/base/error.py index 67c0635..daa7b8f 100644 --- a/src/utils/operations/base/error.py +++ b/src/utils/operations/base/error.py @@ -1,11 +1,13 @@ class StartActiveError(Exception): def __init__(self, op_type: str, op_id: str): - super().__init__("Start called on already active {} operation {}".format(op_type, op_id)) - + super().__init__(f"Start called on already active {op_type} operation {op_id}") + + class CloseInactiveError(Exception): def __init__(self, op_type: str, op_id: str): - super().__init__("Close called on already inactive {} operation {}".format(op_type, op_id)) - + super().__init__(f"Close called on already inactive {op_type} operation {op_id}") + + class UsedInactiveError(Exception): def __init__(self, op_type: str, op_id: str): - super().__init__("Usage on inactive {} operation {}".format(op_type, op_id)) \ No newline at end of file + super().__init__(f"Usage on inactive {op_type} operation {op_id}") diff --git a/src/utils/operations/base/operation.py b/src/utils/operations/base/operation.py index ce609a2..c494b0b 100644 --- a/src/utils/operations/base/operation.py +++ b/src/utils/operations/base/operation.py @@ -1,56 +1,69 @@ -from typing import Dict, Any, AsyncGenerator -import time +from __future__ import annotations + import logging +import time +from collections.abc import AsyncGenerator +from typing import Any, Optional + +from .error import CloseInactiveError, StartActiveError, UsedInactiveError -from .error import StartActiveError, CloseInactiveError, UsedInactiveError class Operation: def __init__(self, op_type: str, op_id: str): self.op_type = op_type self.op_id = op_id - + self.active = False - - async def __call__(self, chunk_in: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: - '''Generates a stream of chunks similar to chunk_in but augmented with new data''' - if not self.active: raise UsedInactiveError(self.op_type, self.op_id) + self.prompter: Optional["Prompter"] = None + + def bind_runtime(self, prompter: Optional["Prompter"] = None) -> Operation: + self.prompter = prompter + return self + + async def __call__(self, chunk_in: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]: + """Generates a stream of chunks similar to chunk_in but augmented with new data""" + if not self.active: + raise UsedInactiveError(self.op_type, self.op_id) start_time = time.perf_counter() - + kwargs = await self._parse_chunk(chunk_in) - + async for chunk_out in self._generate(**kwargs): # yield chunk_in | chunk_out yield chunk_out end_time = time.perf_counter() - logging.info("{} operation {} completed in {} ms".format(self.op_type, self.op_id, (end_time-start_time)*1000)) - - + logging.info( + f"{self.op_type} operation {self.op_id} completed in {(end_time - start_time) * 1000} ms" + ) + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' - if self.active: raise StartActiveError(self.op_type, self.op_id) - logging.info("Starting {} operation {}".format(self.op_type, self.op_id)) + """General setup needed to start generated""" + if self.active: + raise StartActiveError(self.op_type, self.op_id) + logging.info(f"Starting {self.op_type} operation {self.op_id}") self.active = True - + async def close(self) -> None: - '''Clean up resources before unloading''' - if not self.active: raise CloseInactiveError(self.op_type, self.op_id) - logging.info("Closing {} operation {}".format(self.op_type, self.op_id)) + """Clean up resources before unloading""" + if not self.active: + raise CloseInactiveError(self.op_type, self.op_id) + logging.info(f"Closing {self.op_type} operation {self.op_id}") self.active = False - + ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" + raise NotImplementedError + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" raise NotImplementedError - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _generate(self, **kwargs) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - async def _generate(self, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' - raise NotImplementedError \ No newline at end of file diff --git a/src/utils/operations/embedding/base.py b/src/utils/operations/embedding/base.py index 0b64f0e..76e172b 100644 --- a/src/utils/operations/embedding/base.py +++ b/src/utils/operations/embedding/base.py @@ -1,51 +1,51 @@ -''' +""" Embedding Operations (at minimum) require the following fields for input chunks: - content: (str) text to be embedded Adds to chunk: - embedding: (str) UTF-8 string containing base64 bytes -''' +""" -from typing import Dict, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from ..base import Operation + class EmbeddingOperation(Operation): def __init__(self, op_id: str): super().__init__("EMBEDDING", op_id) - + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" assert "content" in chunk_in assert isinstance(chunk_in["content"], str) assert len(chunk_in["content"]) > 0 - - return { - "content": chunk_in["content"] - } - + + return {"content": chunk_in["content"]} + ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" raise NotImplementedError - - async def _generate(self, content: str = None, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' + + async def _generate( + self, content: str = None, **kwargs + ) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - yield { - "embedding": b"" - } \ No newline at end of file + + yield {"embedding": b""} diff --git a/src/utils/operations/embedding/llamacpp.py b/src/utils/operations/embedding/llamacpp.py new file mode 100644 index 0000000..e87cbf1 --- /dev/null +++ b/src/utils/operations/embedding/llamacpp.py @@ -0,0 +1,119 @@ +import base64 +import struct + +import httpx + +from utils.helpers.subprocess_server import ( + allocate_port, + bin_executable, + start_shell_process, + stop_process, +) + +from .base import EmbeddingOperation + + +def _embedding_to_base64(float_list: list[float]) -> str: + format_string = "<" + "f" * len(float_list) + packed_bytes = struct.pack(format_string, *float_list) + return base64.b64encode(packed_bytes).decode("utf-8") + + +def _extract_embedding_vector(body: object) -> list[float]: + if isinstance(body, dict) and "embedding" in body: + vec = body["embedding"] + elif isinstance(body, list) and body: + first = body[0] + vec = first["embedding"] if isinstance(first, dict) else first + else: + raise RuntimeError(f"Unexpected llamacpp embedding response: {body!r}") + + if not vec: + raise RuntimeError("llamacpp embedding response was empty") + + if isinstance(vec[0], (list, tuple)): + width = len(vec[0]) + return [sum(row[i] for row in vec) / len(vec) for i in range(width)] + + return [float(x) for x in vec] + + +class LlamaCPPEmbedding(EmbeddingOperation): + def __init__(self): + super().__init__("llamacpp") + self.uri = None + self.model_filepath = None + self.pooling = "mean" + self.embd_normalize = 2 + self._server_process = None + self._port: int | None = None + self._label: str = "llama.cpp embedding server" + self._http: httpx.AsyncClient | None = None + + async def start(self) -> None: + await super().start() + if self._server_process is not None: + return + + server = bin_executable("llama-server") + self._port = allocate_port() + cmd = ( + f'"{server}" -m "{self.model_filepath}" --host 127.0.0.1 --port {self._port} ' + f"--embeddings --pooling {self.pooling}" + ) + self._server_process = start_shell_process(cmd, label=self._label) + self.uri = f"http://127.0.0.1:{self._port}" + self._http = httpx.AsyncClient(base_url=self.uri, timeout=httpx.Timeout(600.0)) + + async def close(self) -> None: + if self._http is not None: + await self._http.aclose() + self._http = None + stop_process(self._server_process, label=self._label) + self._server_process = None + self._port = None + self.uri = None + await super().close() + + async def configure(self, config_d): + if "model_filepath" in config_d: + self.model_filepath = str(config_d["model_filepath"]) + if "pooling" in config_d: + self.pooling = str(config_d["pooling"]) + if "embd_normalize" in config_d: + self.embd_normalize = int(config_d["embd_normalize"]) + + assert self.model_filepath is not None and len(self.model_filepath) > 0 + assert self.pooling in {"none", "mean", "cls", "last", "rank"} + assert self.embd_normalize in {-1, 0, 1, 2} or self.embd_normalize > 2 + + async def get_configuration(self): + return { + "model_filepath": self.model_filepath, + "pooling": self.pooling, + "embd_normalize": self.embd_normalize, + } + + async def _check_health(self) -> None: + if self._http is None: + raise RuntimeError("LlamaCPPEmbedding server is not running") + try: + health_resp = await self._http.get("/health", timeout=5.0) + health_resp.raise_for_status() + except httpx.HTTPError as e: + raise RuntimeError(f"LlamaCPPEmbedding server health check failed: {e}") from e + + async def _generate(self, content: str = None, **kwargs): + await self._check_health() + + try: + response = await self._http.post( + "/embedding", + json={"content": content, "embd_normalize": self.embd_normalize}, + ) + response.raise_for_status() + except httpx.HTTPError as e: + raise Exception(f"Failed to get embedding result: {e}") from e + + float_list = _extract_embedding_vector(response.json()) + yield {"embedding": _embedding_to_base64(float_list)} diff --git a/src/utils/operations/embedding/openai.py b/src/utils/operations/embedding/openai.py index 752ecfb..55c93c1 100644 --- a/src/utils/operations/embedding/openai.py +++ b/src/utils/operations/embedding/openai.py @@ -1,58 +1,54 @@ -from openai import AsyncOpenAI -import struct import base64 +import struct + +from openai import AsyncOpenAI from .base import EmbeddingOperation + class OpenAIEmbedding(EmbeddingOperation): def __init__(self): super().__init__("openai") self.client = None - + self.base_url = "https://api.openai.com/v1/" self.model = "text-embedding-3-small" self.dimensions = 1536 - + async def start(self): await super().start() self.client = AsyncOpenAI(base_url=self.base_url) - + async def close(self): await super().close() await self.client.close() self.client = None - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "base_url" in config_d: self.base_url = str(config_d['base_url']) - if "model" in config_d: self.model = str(config_d['model']) - if "dimensions" in config_d: self.dimensions = int(config_d['dimensions']) + """Configure and validate operation-specific configuration""" + if "base_url" in config_d: + self.base_url = str(config_d["base_url"]) + if "model" in config_d: + self.model = str(config_d["model"]) + if "dimensions" in config_d: + self.dimensions = int(config_d["dimensions"]) assert self.base_url is not None and len(self.base_url) > 0 assert self.model is not None and len(self.model) > 0 assert self.dimensions in [1536] - + async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "base_url": self.base_url, - "model": self.model, - "dimensions": self.dimensions - } + """Returns values of configurable fields""" + return {"base_url": self.base_url, "model": self.model, "dimensions": self.dimensions} async def _generate(self, content: str = None, **kwargs): - response = await self.client.embeddings.create( # dimension 1536 default for small - model=self.model, - input=content, - dimensions=self.dimensions, - encoding_format="float" + response = await self.client.embeddings.create( # dimension 1536 default for small + model=self.model, input=content, dimensions=self.dimensions, encoding_format="float" ) float_list = response.data[0].embedding - format_string = '<' + 'f' * len(float_list) + format_string = "<" + "f" * len(float_list) packed_bytes = struct.pack(format_string, *float_list) - result = base64.b64encode(packed_bytes).decode('utf-8') + result = base64.b64encode(packed_bytes).decode("utf-8") - yield { - "embedding": result - } \ No newline at end of file + yield {"embedding": result} diff --git a/src/utils/operations/error.py b/src/utils/operations/error.py index 3f55efe..a8ec840 100644 --- a/src/utils/operations/error.py +++ b/src/utils/operations/error.py @@ -1,20 +1,26 @@ class UnknownOpType(Exception): def __init__(self, op_type: str): - super().__init__("No operation of type {}".format(op_type)) - + super().__init__(f"No operation of type {op_type}") + + class UnknownOpRole(Exception): def __init__(self, op_role: str): - super().__init__("No operation of role {}".format(op_role)) - + super().__init__(f"No operation of role {op_role}") + + class UnknownOpID(Exception): def __init__(self, op_type: str, op_id): - super().__init__("No operation of type {} with id {}".format(op_type, op_id)) - + super().__init__(f"No operation of type {op_type} with id {op_id}") + + class DuplicateFilter(Exception): def __init__(self, op_type: str, op_id): - super().__init__("Can not add already active {} {}".format(op_type, op_id)) - + super().__init__(f"Can not add already active {op_type} {op_id}") + + class OperationUnloaded(Exception): def __init__(self, op_type: str, op_id: str = None): - if op_id: super().__init__("No operation {} with id {} loaded".format(op_type, op_id)) - else: super().__init__("No operation of type {} loaded".format(op_type)) \ No newline at end of file + if op_id: + super().__init__(f"No operation {op_type} with id {op_id} loaded") + else: + super().__init__(f"No operation of type {op_type} loaded") diff --git a/src/utils/operations/filter_audio/base.py b/src/utils/operations/filter_audio/base.py index 65c2e2c..ee36b2c 100644 --- a/src/utils/operations/filter_audio/base.py +++ b/src/utils/operations/filter_audio/base.py @@ -1,4 +1,4 @@ -''' +""" Filter audio operations (at minimum) require the following fields for input chunks: - audio_bytes: (bytes) pcm audio data - sr: (int) sample rate @@ -10,27 +10,29 @@ - sr: (int) new sample rate - sw: (int) new sample width - ch: (int) new audio channels -''' +""" -from typing import Dict, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from ..base import Operation + class FilterAudioOperation(Operation): def __init__(self, op_id: str): super().__init__("FILTER_AUDIO", op_id) - + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" assert "audio_bytes" in chunk_in assert isinstance(chunk_in["audio_bytes"], bytes) assert len(chunk_in["audio_bytes"]) > 0 @@ -43,30 +45,27 @@ async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: assert "ch" in chunk_in assert isinstance(chunk_in["ch"], int) assert chunk_in["ch"] > 0 - + return { "audio_bytes": chunk_in["audio_bytes"], "sr": chunk_in["sr"], "sw": chunk_in["sw"], - "ch": chunk_in["ch"] + "ch": chunk_in["ch"], } - + ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" raise NotImplementedError - - async def _generate(self, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' + + async def _generate( + self, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs + ) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - yield { - "audio_bytes": b'', - "sr": 123, - "sw": 123, - "ch": 123 - } \ No newline at end of file + + yield {"audio_bytes": b"", "sr": 123, "sw": 123, "ch": 123} diff --git a/src/utils/operations/filter_audio/pitch.py b/src/utils/operations/filter_audio/pitch.py index cc4f082..94b7d52 100644 --- a/src/utils/operations/filter_audio/pitch.py +++ b/src/utils/operations/filter_audio/pitch.py @@ -1,30 +1,25 @@ -from utils.config import Config from utils.helpers.audio import pitch_audio from .base import FilterAudioOperation -class PitchFilter(FilterAudioOperation): + +class PitchFilter(FilterAudioOperation): def __init__(self): super().__init__("pitch") - + self.pitch_amount: int = 0 - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "pitch_amount" in config_d: self.pitch_amount = str(config_d["pitch_amount"]) - + """Configure and validate operation-specific configuration""" + if "pitch_amount" in config_d: + self.pitch_amount = str(config_d["pitch_amount"]) + async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "pitch_amount": self.pitch_amount - } + """Returns values of configurable fields""" + return {"pitch_amount": self.pitch_amount} - async def _generate(self, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs): + async def _generate( + self, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs + ): ab, sr, sw, ch = pitch_audio(audio_bytes, sr, sw, ch, self.pitch_amount) - yield { - "audio_bytes": ab, - "sr": sr, - "sw": sw, - "ch": ch - } - \ No newline at end of file + yield {"audio_bytes": ab, "sr": sr, "sw": sw, "ch": ch} diff --git a/src/utils/operations/filter_audio/rvc.py b/src/utils/operations/filter_audio/rvc.py index 3e1206e..c6032eb 100644 --- a/src/utils/operations/filter_audio/rvc.py +++ b/src/utils/operations/filter_audio/rvc.py @@ -1,95 +1,133 @@ +import os import wave -from rvc.modules.vc.modules import VC -import torch +from pathlib import Path + import fairseq +import torch +from rvc.modules.vc.modules import VC -from utils.config import Config +from utils.config import config from .base import FilterAudioOperation + class RVCFilter(FilterAudioOperation): TARGET_SR = 16000 TARGET_SW = 2 TARGET_CH = 1 - + DEFAULT_SPEAKER_ID = 0 + def __init__(self): super().__init__("rvc") - self.vc = None - - self.voice: str = None + self.vc: VC | None = None + self._speaker_id = self.DEFAULT_SPEAKER_ID + + self.voice: str | None = None self.f0_up_key: int = 0 self.f0_method: str = "rmvpe" - self.f0_file: str = None - self.index_file: str = None + self.f0_filepath: str | None = None + self.index_filepath: str | None = None self.index_rate: float = 0 self.filter_radius: int = 3 - self.resample_sr: int = 0 + self.resample_sr: int = 0 self.rms_mix_rate: float = 0 self.protect: float = 0.5 - + torch.serialization.add_safe_globals([fairseq.data.dictionary.Dictionary]) - + + def _model_id(self) -> str: + if not self.voice: + raise RuntimeError("RVC voice model not configured") + return self.voice if self.voice.endswith(".pth") else f"{self.voice}.pth" + async def start(self): await super().start() self.vc = VC() - model_name = self.voice if self.voice.endswith('.pth') else f"{self.voice}.pth" - self.vc.get_vc(model_name) - + _, _, default_index = self.vc.get_vc(self._model_id()) + if not self.index_filepath and default_index: + self.index_filepath = default_index + + async def close(self) -> None: + if self.vc is not None: + try: + self.vc.get_vc("") + except Exception: + pass + self.vc = None + await super().close() + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "voice" in config_d: self.voice = str(config_d["voice"]) - if "f0_up_key" in config_d: self.f0_up_key = int(config_d["f0_up_key"]) - if "f0_method" in config_d: self.f0_method = str(config_d["f0_method"]) - if "f0_file" in config_d: self.f0_file = str(config_d["f0_file"]) - if "index_file" in config_d: self.index_file = str(config_d["index_file"]) - if "index_rate" in config_d: self.index_rate = float(config_d["index_rate"]) - if "filter_radius" in config_d: self.filter_radius = int(config_d["filter_radius"]) - if "resample_sr" in config_d: self.resample_sr = int(config_d["resample_sr"]) - if "rms_mix_rate" in config_d: self.rms_mix_rate = float(config_d["rms_mix_rate"]) - if "protect" in config_d: self.protect = float(config_d["protect"]) - + """Configure and validate operation-specific configuration""" + if "voice" in config_d: + self.voice = str(config_d["voice"]) + if "f0_up_key" in config_d: + self.f0_up_key = int(config_d["f0_up_key"]) + if "f0_method" in config_d: + self.f0_method = str(config_d["f0_method"]) + if "f0_filepath" in config_d: + self.f0_filepath = str(config_d["f0_filepath"]) + if "index_filepath" in config_d: + self.index_filepath = str(config_d["index_filepath"]) + if "index_rate" in config_d: + self.index_rate = float(config_d["index_rate"]) + if "filter_radius" in config_d: + self.filter_radius = int(config_d["filter_radius"]) + if "resample_sr" in config_d: + self.resample_sr = int(config_d["resample_sr"]) + if "rms_mix_rate" in config_d: + self.rms_mix_rate = float(config_d["rms_mix_rate"]) + if "protect" in config_d: + self.protect = float(config_d["protect"]) + # TODO check assertions - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return { "voice": self.voice, "f0_up_key": self.f0_up_key, "f0_method": self.f0_method, - "f0_file": self.f0_file, - "index_file": self.index_file, + "f0_filepath": self.f0_filepath, + "index_filepath": self.index_filepath, "index_rate": self.index_rate, "filter_radius": self.filter_radius, "resample_sr": self.resample_sr, "rms_mix_rate": self.rms_mix_rate, - "protect": self.protect + "protect": self.protect, } - async def _generate(self, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs): - with wave.open(Config().ffmpeg_working_src, 'wb') as f: + async def _generate( + self, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs + ): + if self.vc is None: + raise RuntimeError("RVC filter is not started") + + input_path = Path(config.ffmpeg_working_src) + with wave.open(str(input_path), "wb") as f: f.setframerate(sr) f.setsampwidth(sw) f.setnchannels(ch) f.writeframes(audio_bytes) - - tgt_sr, audio_opt, times, _ = self.vc.vc_single( - 1, - Config().ffmpeg_working_src, + + tgt_sr, audio_opt, _, info = self.vc.vc_inference( + self._speaker_id, + input_path, f0_up_key=self.f0_up_key, f0_method=self.f0_method, - f0_file=self.f0_file, - index_file=self.index_file, + f0_file=self.f0_filepath, + index_file=self.index_filepath, index_rate=self.index_rate, filter_radius=self.filter_radius, resample_sr=self.resample_sr, rms_mix_rate=self.rms_mix_rate, protect=self.protect, ) - + if audio_opt is None: + raise RuntimeError(info or "RVC inference failed") + yield { "audio_bytes": audio_opt.tobytes(), "sr": tgt_sr, "sw": self.TARGET_SW, - "ch": self.TARGET_CH + "ch": self.TARGET_CH, } - \ No newline at end of file diff --git a/src/utils/operations/filter_text/base.py b/src/utils/operations/filter_text/base.py index e96d23a..615230c 100644 --- a/src/utils/operations/filter_text/base.py +++ b/src/utils/operations/filter_text/base.py @@ -1,51 +1,51 @@ -''' +""" Filter text operations (at minimum) require the following fields for input chunks: - content: (text) text to apply filter on Overwrites in chunk (OR augments chunk): - content: (str) text after filter application -''' +""" -from typing import Dict, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from ..base import Operation + class FilterTextOperation(Operation): def __init__(self, op_id: str): super().__init__("FILTER_TEXT", op_id) - + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" assert "content" in chunk_in assert isinstance(chunk_in["content"], str) assert len(chunk_in["content"]) > 0 - - return { - "content": chunk_in["content"] - } - + + return {"content": chunk_in["content"]} + ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" raise NotImplementedError - - async def _generate(self, content: str = None, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' + + async def _generate( + self, content: str = None, **kwargs + ) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - yield { - "content": "example response text" - } \ No newline at end of file + + yield {"content": "example response text"} diff --git a/src/utils/operations/filter_text/chunker_sentence.py b/src/utils/operations/filter_text/chunker_sentence.py index a8723de..b82a157 100644 --- a/src/utils/operations/filter_text/chunker_sentence.py +++ b/src/utils/operations/filter_text/chunker_sentence.py @@ -1,36 +1,34 @@ import spacy -from utils.config import Config +from utils.config import config from .base import FilterTextOperation + class SentenceChunkerFilter(FilterTextOperation): def __init__(self): super().__init__("chunker_sentence") self.nlp = None - + async def start(self): await super().start() - self.nlp = spacy.load(Config().spacy_model) - + self.nlp = spacy.load(config.spacy_model) + async def close(self): await super().close() self.nlp = None - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' + """Configure and validate operation-specific configuration""" return - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return {} async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" sentences = [sent.text.strip() for sent in self.nlp(content).sents] - - for s in sentences: - yield { - "content": s - } + for s in sentences: + yield {"content": s} diff --git a/src/utils/operations/filter_text/emotion_roberta.py b/src/utils/operations/filter_text/emotion_roberta.py index b177e1f..ee26017 100644 --- a/src/utils/operations/filter_text/emotion_roberta.py +++ b/src/utils/operations/filter_text/emotion_roberta.py @@ -1,30 +1,36 @@ - -from transformers import pipeline import torch +from transformers import pipeline from .base import FilterTextOperation + class RobertaEmotionFilter(FilterTextOperation): def __init__(self): super().__init__("emotion_roberta") self.classifier = None - + async def start(self): await super().start() - self.classifier = pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=1, device=('cuda' if torch.cuda.is_available() else 'cpu')) + self.classifier = pipeline( + task="text-classification", + model="SamLowe/roberta-base-go_emotions", + top_k=1, + device=("cuda" if torch.cuda.is_available() else "cpu"), + ) async def close(self): await super().close() del self.classifier - if torch.cuda.is_available(): torch.cuda.empty_cache() # clean cache on cuda - + if torch.cuda.is_available(): + torch.cuda.empty_cache() # clean cache on cuda + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' + """Configure and validate operation-specific configuration""" return - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return {} async def _generate(self, content: str = None, **kwargs): - yield {"content": content, "emotion": self.classifier(content)[0][0]['label']} \ No newline at end of file + yield {"content": content, "emotion": self.classifier(content)[0][0]["label"]} diff --git a/src/utils/operations/filter_text/filter_clean.py b/src/utils/operations/filter_text/filter_clean.py index 320ffa6..0febcdc 100644 --- a/src/utils/operations/filter_text/filter_clean.py +++ b/src/utils/operations/filter_text/filter_clean.py @@ -2,38 +2,36 @@ from .base import FilterTextOperation + class ResponseCleaningFilter(FilterTextOperation): def __init__(self): super().__init__("filter_clean") self.pattern = None - + async def start(self): await super().start() self.pattern = re.compile(r"\[[^\[\]]+\]:\s*") - + async def close(self): await super().close() - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' + """Configure and validate operation-specific configuration""" return - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return {} async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" while True: match = self.pattern.search(content) if match: - tmp = content[:match.span()[0]] - tmp += content[match.span()[1]:] + tmp = content[: match.span()[0]] + tmp += content[match.span()[1] :] content = tmp else: break - - yield { - "content": content - } + yield {"content": content} diff --git a/src/utils/operations/filter_text/mod_koala.py b/src/utils/operations/filter_text/mod_koala.py index 2462ce8..cfa4c52 100644 --- a/src/utils/operations/filter_text/mod_koala.py +++ b/src/utils/operations/filter_text/mod_koala.py @@ -1,35 +1,40 @@ -from transformers import AutoModelForSequenceClassification, AutoTokenizer import torch +from transformers import AutoModelForSequenceClassification, AutoTokenizer + from .base import FilterTextOperation + class KoalaModerationFilter(FilterTextOperation): GOOD_LABEL = "OK" - + def __init__(self): super().__init__("mod_koala") self.model, self.tokenizer = None, None - + async def start(self): await super().start() - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.model = AutoModelForSequenceClassification.from_pretrained("KoalaAI/Text-Moderation").to(self.device) + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = AutoModelForSequenceClassification.from_pretrained( + "KoalaAI/Text-Moderation" + ).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained("KoalaAI/Text-Moderation") - + async def close(self): await super().close() del self.model, self.tokenizer - if torch.cuda.is_available(): torch.cuda.empty_cache() # clean cache on cuda - + if torch.cuda.is_available(): + torch.cuda.empty_cache() # clean cache on cuda + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' + """Configure and validate operation-specific configuration""" return - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return {} async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" # Classify content inputs = self.tokenizer(content, return_tensors="pt").to(self.device) outputs = self.model(**inputs) @@ -43,8 +48,8 @@ async def _generate(self, content: str = None, **kwargs): labels = [id2label[idx] for idx in range(len(probabilities))] # Sort labels by score - label_prob_pairs = list(zip(labels, probabilities)) - label_prob_pairs.sort(key=lambda item: item[1], reverse=True) + label_prob_pairs = list(zip(labels, probabilities, strict=False)) + label_prob_pairs.sort(key=lambda item: item[1], reverse=True) # Handle top classification top_label, _ = label_prob_pairs[0] diff --git a/src/utils/operations/manager.py b/src/utils/operations/manager.py index 3618fa4..f24c275 100644 --- a/src/utils/operations/manager.py +++ b/src/utils/operations/manager.py @@ -1,10 +1,38 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator, Awaitable, Callable +from dataclasses import dataclass from enum import Enum -from typing import Dict, List, AsyncGenerator, Any +from typing import Any, TypeVar -from .error import UnknownOpType, UnknownOpRole, UnknownOpID, DuplicateFilter, OperationUnloaded from .base import Operation -from utils.helpers.singleton import Singleton -from utils.config import Config +from .embedding.base import EmbeddingOperation +from .embedding.llamacpp import LlamaCPPEmbedding +from .embedding.openai import OpenAIEmbedding +from .error import DuplicateFilter, OperationUnloaded, UnknownOpID, UnknownOpRole, UnknownOpType +from .filter_audio.base import FilterAudioOperation +from .filter_audio.pitch import PitchFilter +from .filter_audio.rvc import RVCFilter +from .filter_text.base import FilterTextOperation +from .filter_text.chunker_sentence import SentenceChunkerFilter +from .filter_text.emotion_roberta import RobertaEmotionFilter +from .filter_text.filter_clean import ResponseCleaningFilter +from .filter_text.mod_koala import KoalaModerationFilter +from .stt.azure import AzureSTT +from .stt.base import STTOperation +from .stt.fish import FishSTT +from .stt.openai import OpenAISTT +from .stt.whispercpp import WhisperCPPSTT +from .t2t.base import T2TOperation +from .t2t.llamacpp import LlamaCPPT2T +from .t2t.openai import OpenAIT2T +from .tts.azure import AzureTTS +from .tts.base import TTSOperation +from .tts.fish import FishTTS +from .tts.melo import MeloTTS +from .tts.openai import OpenAITTS +from .tts.pytts import PyttsTTS + class OpTypes(Enum): STT = "stt" @@ -13,7 +41,8 @@ class OpTypes(Enum): FILTER_AUDIO = "filter_audio" FILTER_TEXT = "filter_text" EMBEDDING = "embedding" - + + class OpRoles(Enum): STT = "stt" MCP = "mcp" @@ -22,477 +51,242 @@ class OpRoles(Enum): FILTER_AUDIO = "filter_audio" FILTER_TEXT = "filter_text" EMBEDDING = "embedding" - + + +ROLE_TO_TYPE: dict[OpRoles, OpTypes] = { + OpRoles.STT: OpTypes.STT, + OpRoles.MCP: OpTypes.T2T, + OpRoles.T2T: OpTypes.T2T, + OpRoles.TTS: OpTypes.TTS, + OpRoles.FILTER_AUDIO: OpTypes.FILTER_AUDIO, + OpRoles.FILTER_TEXT: OpTypes.FILTER_TEXT, + OpRoles.EMBEDDING: OpTypes.EMBEDDING, +} + +OP_CLASSES: dict[OpTypes, dict[str, type[Operation]]] = { + OpTypes.STT: { + "fish": FishSTT, + "azure": AzureSTT, + "openai": OpenAISTT, + "whispercpp": WhisperCPPSTT, + }, + OpTypes.T2T: { + "openai": OpenAIT2T, + "llamacpp": LlamaCPPT2T, + }, + OpTypes.TTS: { + "azure": AzureTTS, + "fish": FishTTS, + "openai": OpenAITTS, + "melo": MeloTTS, + "pytts": PyttsTTS, + }, + OpTypes.FILTER_AUDIO: { + "rvc": RVCFilter, + "pitch": PitchFilter, + }, + OpTypes.FILTER_TEXT: { + "chunker_sentence": SentenceChunkerFilter, + "emotion_roberta": RobertaEmotionFilter, + "mod_koala": KoalaModerationFilter, + "filter_clean": ResponseCleaningFilter, + }, + OpTypes.EMBEDDING: { + "openai": OpenAIEmbedding, + "llamacpp": LlamaCPPEmbedding, + }, +} + + def role_to_type(op_role: OpRoles) -> OpTypes: - match op_role: - case OpRoles.STT: - return OpTypes.STT - case OpRoles.MCP: - return OpTypes.T2T - case OpRoles.T2T: - return OpTypes.T2T - case OpRoles.TTS: - return OpTypes.TTS - case OpRoles.FILTER_AUDIO: - return OpTypes.FILTER_AUDIO - case OpRoles.FILTER_TEXT: - return OpTypes.FILTER_TEXT - case OpRoles.EMBEDDING: - return OpTypes.EMBEDDING - case _: - raise UnknownOpRole(op_role) - - -def load_op(op_type: OpTypes, op_id: str): - ''' + try: + return ROLE_TO_TYPE[op_role] + except KeyError: + raise UnknownOpRole(op_role) from None + + +def load_op(op_type: OpTypes, op_id: str) -> Operation: + """ Return an operation, but do not saved to OperationManager - + Starting, usage and eventual closing of this operation is deferred to the caller. This is mainly used for temporarily loading an operation to be used, such as a filter used as a one-time preview and not intended to last whole session - ''' - match op_type: - case OpTypes.STT: - if op_id == "fish": - from .stt.fish import FishSTT - return FishSTT() - elif op_id == "azure": - from .stt.azure import AzureSTT - return AzureSTT() - elif op_id == "openai": - from .stt.openai import OpenAISTT - return OpenAISTT() - elif op_id == "kobold": - from .stt.kobold import KoboldSTT - return KoboldSTT() - else: - raise UnknownOpID("STT", op_id) - case OpTypes.T2T: - if op_id == "openai": - from .t2t.openai import OpenAIT2T - return OpenAIT2T() - elif op_id == "kobold": - from .t2t.kobold import KoboldT2T - return KoboldT2T() - else: - raise UnknownOpID("T2T", op_id) - case OpTypes.TTS: - if op_id == "azure": - from .tts.azure import AzureTTS - return AzureTTS() - elif op_id == "fish": - from .tts.fish import FishTTS - return FishTTS() - elif op_id == "openai": - from .tts.openai import OpenAITTS - return OpenAITTS() - elif op_id == "kobold": - from .tts.kobold import KoboldTTS - return KoboldTTS() - elif op_id == "melo": - from .tts.melo import MeloTTS - return MeloTTS() - elif op_id == "pytts": - from .tts.pytts import PyttsTTS - return PyttsTTS() - else: - raise UnknownOpID("TTS", op_id) - case OpTypes.FILTER_AUDIO: - if op_id == "rvc": - from .filter_audio.rvc import RVCFilter - return RVCFilter() - elif op_id == "pitch": - from .filter_audio.pitch import PitchFilter - return PitchFilter() - else: - raise UnknownOpID("FILTER_AUDIO", op_id) - case OpTypes.FILTER_TEXT: - if op_id == "chunker_sentence": - from .filter_text.chunker_sentence import SentenceChunkerFilter - return SentenceChunkerFilter() - elif op_id == "emotion_roberta": - from .filter_text.emotion_roberta import RobertaEmotionFilter - return RobertaEmotionFilter() - elif op_id == "mod_koala": - from .filter_text.mod_koala import KoalaModerationFilter - return KoalaModerationFilter() - elif op_id == "filter_clean": - from .filter_text.filter_clean import ResponseCleaningFilter - return ResponseCleaningFilter() - else: - raise UnknownOpID("FILTER_TEXT", op_id) - case OpTypes.EMBEDDING: - if op_id == "openai": - from .embedding.openai import OpenAIEmbedding - return OpenAIEmbedding() - else: - raise UnknownOpID("EMBEDDING", op_id) - case _: - # Should never get here if op_role is indeed OpRole - raise UnknownOpRole(op_type) - -class OperationManager(metaclass=Singleton): - def __init__(self): - self.stt = None - self.mcp = None - self.t2t = None - self.tts = None - self.filter_audio = list() - self.filter_text = list() - self.embedding = None - - def get_operation(self, op_role: OpRoles) -> Operation: - match op_role: - case OpRoles.STT: - return self.stt - case OpRoles.MCP: - return self.mcp - case OpRoles.T2T: - return self.t2t - case OpRoles.TTS: - return self.tts - case OpRoles.FILTER_AUDIO: - return self.filter_audio - case OpRoles.FILTER_TEXT: - return self.filter_text - case OpRoles.EMBEDDING: - return self.embedding - case _: - # Should never get here if op_role is indeed OpRoles - raise UnknownOpRole(op_role) - - def get_operation_all(self) -> Dict[str, Operation | List[Operation]]: - return { - "stt": self.get_operation(OpRoles.STT), - "mcp": self.get_operation(OpRoles.MCP), - "t2t": self.get_operation(OpRoles.T2T), - "tts": self.get_operation(OpRoles.TTS), - "filter_audio": self.get_operation(OpRoles.FILTER_AUDIO), - "filter_text": self.get_operation(OpRoles.FILTER_TEXT), - "embedding": self.get_operation(OpRoles.EMBEDDING), - } - - async def get_configuration( - self, - op_role: OpRoles, - op_id: str = None - ): - '''Get configuration for a loaded operation''' - match op_role: - case OpRoles.STT: - if not self.stt: - raise OperationUnloaded("STT") - elif op_id and self.stt and self.stt.op_id != op_id: - raise OperationUnloaded("STT", op_id=op_id) - - return await self.stt.get_configuration() - case OpRoles.MCP: - if not self.mcp: - raise OperationUnloaded("MCP") - elif op_id and self.mcp and self.mcp.op_id != op_id: - raise OperationUnloaded("MCP", op_id=op_id) - - return await self.mcp.get_configuration() - case OpRoles.T2T: - if not self.t2t: - raise OperationUnloaded("T2T") - elif op_id and self.t2t and self.t2t.op_id != op_id: - raise OperationUnloaded("T2T", op_id=op_id) - - return await self.t2t.get_configuration() - case OpRoles.TTS: - if not self.tts: - raise OperationUnloaded("TTS") - elif op_id and self.tts and self.tts.op_id != op_id: - raise OperationUnloaded("TTS", op_id=op_id) - - return await self.tts.get_configuration() - case OpRoles.FILTER_AUDIO: - assert op_id is not None - - for op in self.filter_audio: - if op.op_id == op_id: - return await op.get_configuration() - raise OperationUnloaded("FILTER_AUDIO", op_id=op_id) - case OpRoles.FILTER_TEXT: + """ + try: + op_class = OP_CLASSES[op_type][op_id] + except KeyError: + raise UnknownOpID(op_type.name, op_id) from None + return op_class() + + +_T = TypeVar("_T") + + +@dataclass(frozen=True) +class OpRoleSlot: + attr: str + multi: bool = False + + +ROLE_SLOTS: dict[OpRoles, OpRoleSlot] = { + OpRoles.STT: OpRoleSlot("stt"), + OpRoles.MCP: OpRoleSlot("mcp"), + OpRoles.T2T: OpRoleSlot("t2t"), + OpRoles.TTS: OpRoleSlot("tts"), + OpRoles.FILTER_AUDIO: OpRoleSlot("filter_audio", multi=True), + OpRoles.FILTER_TEXT: OpRoleSlot("filter_text", multi=True), + OpRoles.EMBEDDING: OpRoleSlot("embedding"), +} + + +class OperationManager: + def __init__(self, prompter: "Prompter"): + self.prompter = prompter + self.stt: STTOperation | None = None + self.mcp: T2TOperation | None = None + self.t2t: T2TOperation | None = None + self.tts: TTSOperation | None = None + self.filter_audio: list[FilterAudioOperation] = [] + self.filter_text: list[FilterTextOperation] = [] + self.embedding: EmbeddingOperation | None = None + + @staticmethod + def _slot(op_role: OpRoles) -> OpRoleSlot: + try: + return ROLE_SLOTS[op_role] + except KeyError: + raise UnknownOpRole(op_role) from None + + def _storage(self, op_role: OpRoles) -> Operation | list[Operation] | None: + return getattr(self, self._slot(op_role).attr) + + def _resolve_operation( + self, op_role: OpRoles, op_id: str | None = None, *, require_id: bool = False + ) -> Operation: + slot = self._slot(op_role) + label = op_role.name + + if slot.multi: + if require_id: assert op_id is not None - - for op in self.filter_text: - if op.op_id == op_id: - return await op.get_configuration() - raise OperationUnloaded("FILTER_AUDIO", op_id=op_id) - case OpRoles.EMBEDDING: - if not self.embedding: - raise OperationUnloaded("EMBEDDING") - elif op_id and self.embedding and self.embedding.op_id != op_id: - raise OperationUnloaded("EMBEDDING", op_id=op_id) - - return await self.embedding.get_configuration() - case _: - # Should never get here if op_role is indeed OpRoles - raise UnknownOpRole(op_role) - - async def load_operation(self, op_role: OpRoles, op_id: str, op_details: Dict[str, Any]) -> None: - '''Load, start, and save an Operation in the OperationManager''' - if op_role == OpRoles.FILTER_AUDIO: - for op in self.filter_audio: - if op.op_id == op_id: raise DuplicateFilter("FILTER_AUDIO", op_id) - if op_role == OpRoles.FILTER_TEXT: - for op in self.filter_text: - if op.op_id == op_id: raise DuplicateFilter("FILTER_TEXT", op_id) - - new_op = load_op(role_to_type(op_role), op_id) + if op_id is None: + raise OperationUnloaded(label) + for op in getattr(self, slot.attr): + if op.op_id == op_id: + return op + raise OperationUnloaded(label, op_id=op_id) + + op: Operation | None = getattr(self, slot.attr) + if not op: + raise OperationUnloaded(label) + if op_id is not None and op.op_id != op_id: + raise OperationUnloaded(label, op_id=op_id) + return op + + async def _act_on_loaded_operation( + self, op_role: OpRoles, op_id: str, action: Callable[[Operation], Awaitable[_T]] + ) -> _T: + op = self._resolve_operation(op_role, op_id, require_id=self._slot(op_role).multi) + return await action(op) + + def get_operation(self, op_role: OpRoles) -> Operation | list[Operation] | None: + return self._storage(op_role) + + def get_operation_all(self) -> dict[str, Operation | list[Operation] | None]: + return {role.value: self.get_operation(role) for role in OpRoles} + + async def get_configuration(self, op_role: OpRoles, op_id: str = None): + """Get configuration for a loaded operation""" + return await self._act_on_loaded_operation( + op_role, op_id, lambda op: op.get_configuration() + ) + + def _bind_runtime(self, op: Operation) -> Operation: + return op.bind_runtime(prompter=self.prompter) + + def loose_load_operation(self, op_role: OpRoles, op_id: str) -> Operation: + return self._bind_runtime(load_op(role_to_type(op_role), op_id)) + + async def load_operation( + self, op_role: OpRoles, op_id: str, op_details: dict[str, Any] + ) -> None: + """Load, start, and save an Operation in the OperationManager""" + slot = self._slot(op_role) + if slot.multi: + for op in getattr(self, slot.attr): + if op.op_id == op_id: + raise DuplicateFilter(op_role.name, op_id) + + new_op = self._bind_runtime(load_op(role_to_type(op_role), op_id)) await new_op.configure(op_details) await new_op.start() - - match op_role: - case OpRoles.STT: - if self.stt: await self.stt.close() - self.stt = new_op - case OpRoles.MCP: - if self.mcp: await self.mcp.close() - self.mcp = new_op - case OpRoles.T2T: - if self.t2t: await self.t2t.close() - self.t2t = new_op - case OpRoles.TTS: - if self.tts: await self.tts.close() - self.tts = new_op - case OpRoles.FILTER_AUDIO: - self.filter_audio.append(new_op) - case OpRoles.FILTER_TEXT: - self.filter_text.append(new_op) - case OpRoles.EMBEDDING: - if self.embedding: await self.embedding.close() - self.embedding = new_op - case _: - # Should never get here if op_role is indeed OpRoles - raise UnknownOpRole(op_role) - - async def load_operations_from_config(self) -> None: - '''Load, start, and save all operations specified in config in the OperationManager''' - config = Config() - - await self.close_operation_all() - - operations = Config().operations - for op_details in operations: - op_role = OpRoles(op_details['role']) - op_id = op_details['id'] - await self.load_operation(op_role, op_id, op_details) - + + if slot.multi: + getattr(self, slot.attr).append(new_op) + return + + current: Operation | None = getattr(self, slot.attr) + if current: + await current.close() + setattr(self, slot.attr, new_op) + async def close_operation(self, op_role: OpRoles, op_id: str = None) -> None: - match op_role: - case OpRoles.STT: - if not self.stt: - raise OperationUnloaded("STT") - elif op_id and self.stt and self.stt.op_id != op_id: - raise OperationUnloaded("STT", op_id=op_id) - - await self.stt.close() - self.stt = None - case OpRoles.T2T: - if not self.mcp: - raise OperationUnloaded("MCP") - elif op_id and self.mcp and self.mcp.op_id != op_id: - raise OperationUnloaded("MCP", op_id=op_id) - - await self.mcp.close() - self.mcp = None - case OpRoles.T2T: - if not self.t2t: - raise OperationUnloaded("T2T") - elif op_id and self.t2t and self.t2t.op_id != op_id: - raise OperationUnloaded("T2T", op_id=op_id) - - await self.t2t.close() - self.t2t = None - case OpRoles.TTS: - if not self.tts: - raise OperationUnloaded("TTS") - elif op_id and self.tts and self.tts.op_id != op_id: - raise OperationUnloaded("TTS", op_id=op_id) - - await self.tts.close() - self.tts = None - case OpRoles.FILTER_AUDIO: - for op in self.filter_audio: - if op.op_id == op_id: - await op.close() - self.filter_audio.remove(op) - return - raise OperationUnloaded("FILTER_AUDIO", op_id=op_id) - case OpRoles.FILTER_TEXT: - for op in self.filter_text: - if op.op_id == op_id: - await op.close() - self.filter_text.remove(op) - return - raise OperationUnloaded("FILTER_TEXT", op_id=op_id) - case OpRoles.EMBEDDING: - if not self.embedding: - raise OperationUnloaded("EMBEDDING") - elif op_id and self.embedding and self.embedding.op_id != op_id: - raise OperationUnloaded("EMBEDDING", op_id=op_id) - - await self.embedding.close() - self.embedding = None - case _: - # Should never get here if op_role is indeed OpRoles - raise UnknownOpRole(op_role) - - async def close_operation_all(self): - if self.stt: - await self.stt.close() - self.stt = None - if self.mcp: - await self.mcp.close() - self.mcp = None - if self.t2t: - await self.t2t.close() - self.t2t = None - if self.tts: - await self.tts.close() - self.tts = None - for op in self.filter_audio: - await op.close() - self.filter_audio.clear() - for op in self.filter_text: - await op.close() - self.filter_text.clear() - if self.embedding: - await self.embedding.close() - self.embedding = None - - async def configure(self, - op_role: OpRoles, - config_d: Dict[str, Any], - op_id: str = None + slot = self._slot(op_role) + label = op_role.name + + if slot.multi: + ops: list[Operation] = getattr(self, slot.attr) + for op in ops: + if op.op_id == op_id: + await op.close() + ops.remove(op) + return + raise OperationUnloaded(label, op_id=op_id) + + op: Operation | None = getattr(self, slot.attr) + if not op: + raise OperationUnloaded(label) + if op_id is not None and op.op_id != op_id: + raise OperationUnloaded(label, op_id=op_id) + await op.close() + setattr(self, slot.attr, None) + + async def close_operation_all(self) -> None: + for slot in ROLE_SLOTS.values(): + storage = getattr(self, slot.attr) + if slot.multi: + for op in storage: + await op.close() + storage.clear() + elif storage is not None: + await storage.close() + setattr(self, slot.attr, None) + + async def configure(self, op_role: OpRoles, config_d: dict[str, Any], op_id: str = None): + """Configure an operation that has already been loaded prior""" + return await self._act_on_loaded_operation( + op_role, op_id, lambda op: op.configure(config_d) + ) + + async def _use_filter( + self, filter_list: list[Operation], filter_idx: int, chunk_in: dict[str, Any] ): - '''Configure an operation that has already been loaded prior''' - match op_role: - case OpRoles.STT: - if not self.stt: - raise OperationUnloaded("STT") - elif op_id and self.stt and self.stt.op_id != op_id: - raise OperationUnloaded("STT", op_id=op_id) - - return await self.stt.configure(config_d) - case OpRoles.MCP: - if not self.mcp: - raise OperationUnloaded("MCP") - elif op_id and self.mcp and self.mcp.op_id != op_id: - raise OperationUnloaded("MCP", op_id=op_id) - - return await self.mcp.configure(config_d) - case OpRoles.T2T: - if not self.t2t: - raise OperationUnloaded("T2T") - elif op_id and self.t2t and self.t2t.op_id != op_id: - raise OperationUnloaded("T2T", op_id=op_id) - - return await self.t2t.configure(config_d) - case OpRoles.TTS: - if not self.tts: - raise OperationUnloaded("TTS") - elif op_id and self.tts and self.tts.op_id != op_id: - raise OperationUnloaded("TTS", op_id=op_id) - - return await self.tts.configure(config_d) - case OpRoles.FILTER_AUDIO: - assert op_id is not None - - for op in self.filter_audio: - if op.op_id == op_id: - return await op.configure(config_d) - raise OperationUnloaded("FILTER_AUDIO", op_id=op_id) - case OpRoles.FILTER_TEXT: - assert op_id is not None - - for op in self.filter_text: - if op.op_id == op_id: - return await op.configure(config_d) - raise OperationUnloaded("FILTER_TEXT", op_id=op_id) - case OpRoles.EMBEDDING: - if not self.embedding: - raise OperationUnloaded("EMBEDDING") - elif op_id and self.embedding and self.embedding.op_id != op_id: - raise OperationUnloaded("EMBEDDING", op_id=op_id) - - return await self.embedding.configure(config_d) - case _: - # Should never get here if op_role is indeed OpRoles - raise UnknownOpRole(op_role) - - async def _use_filter(self, filter_list: List[Operation], filter_idx: int, chunk_in: Dict[str, Any]): - if filter_idx == len(filter_list): yield chunk_in - elif filter_idx < len(filter_list)-1: # Not last filter + if filter_idx >= len(filter_list): + yield chunk_in + elif filter_idx < len(filter_list) - 1: # Not last filter async for result_chunk in filter_list[filter_idx](chunk_in): - async for chunk_out in self._use_filter(filter_list, filter_idx+1, result_chunk): + async for chunk_out in self._use_filter(filter_list, filter_idx + 1, result_chunk): yield chunk_out - else: # Is last filter + else: # Is last filter async for chunk_out in filter_list[filter_idx](chunk_in): yield chunk_out - + def use_operation( - self, - op_role: OpRoles, - chunk_in: Dict[str, Any], - op_id: str = None - ) -> AsyncGenerator[Dict[str, Any], None]: - '''Use an operation that has already been loaded prior''' - match op_role: - case OpRoles.STT: - if not self.stt: - raise OperationUnloaded("STT") - elif op_id and self.stt and self.stt.op_id != op_id: - raise OperationUnloaded("STT", op_id=op_id) - - return self.stt(chunk_in) - case OpRoles.MCP: - if not self.mcp: - raise OperationUnloaded("MCP") - elif op_id and self.mcp and self.mcp.op_id != op_id: - raise OperationUnloaded("MCP", op_id=op_id) - - return self.mcp(chunk_in) - case OpRoles.T2T: - if not self.t2t: - raise OperationUnloaded("T2T") - elif op_id and self.t2t and self.t2t.op_id != op_id: - raise OperationUnloaded("T2T", op_id=op_id) - - return self.t2t(chunk_in) - case OpRoles.TTS: - if not self.tts: - raise OperationUnloaded("TTS") - elif op_id and self.tts and self.tts.op_id != op_id: - raise OperationUnloaded("TTS", op_id=op_id) - - return self.tts(chunk_in) - case OpRoles.FILTER_AUDIO: - if op_id: - for op in self.filter_audio: - if op.op_id == op_id: - return op(chunk_in) - raise OperationUnloaded("FILTER_AUDIO", op_id=op_id) - else: - return self._use_filter(self.filter_audio, 0, chunk_in) - case OpRoles.FILTER_TEXT: - if op_id: - for op in self.filter_text: - if op.op_id == op_id: - return op(chunk_in) - raise OperationUnloaded("FILTER_TEXT", op_id=op_id) - else: - return self._use_filter(self.filter_text, 0, chunk_in) - case OpRoles.EMBEDDING: - if not self.embedding: - raise OperationUnloaded("EMBEDDING") - elif op_id and self.embedding and self.embedding.op_id != op_id: - raise OperationUnloaded("EMBEDDING", op_id=op_id) - - return self.embedding(chunk_in) - case _: - # Should never get here if op_role is indeed OpRoles - raise UnknownOpType(op_role) + self, op_role: OpRoles, chunk_in: dict[str, Any], op_id: str = None + ) -> AsyncGenerator[dict[str, Any], None]: + """Use an operation that has already been loaded prior""" + slot = self._slot(op_role) + if slot.multi and op_id is None: + return self._use_filter(getattr(self, slot.attr), 0, chunk_in) + return self._resolve_operation(op_role, op_id)(chunk_in) diff --git a/src/utils/operations/stt/azure.py b/src/utils/operations/stt/azure.py index feae84a..a44b5d3 100644 --- a/src/utils/operations/stt/azure.py +++ b/src/utils/operations/stt/azure.py @@ -1,67 +1,73 @@ -import os import asyncio -import azure.cognitiveservices.speech as speechsdk -import logging +import os -from utils.config import Config +import azure.cognitiveservices.speech as speechsdk from .base import STTOperation + class AzureSTT(STTOperation): def __init__(self): super().__init__("azure") self.client = None - + self.language: str = "en-US" - + async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + self.speech_config = speechsdk.SpeechConfig( - region=os.environ.get('AZURE_REGION'), - subscription=os.getenv("AZURE_API_KEY") + region=os.environ.get("AZURE_REGION"), subscription=os.getenv("AZURE_API_KEY") ) - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "language" in config_d: self.model_id = str(config_d["language"]) + """Configure and validate operation-specific configuration""" + if "language" in config_d: + self.model_id = str(config_d["language"]) assert self.language is not None and len(self.language) > 0 - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return {"language": self.language} - async def _generate(self, prompt: str = None, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs): - '''Generate a output stream''' + async def _generate( + self, + prompt: str = None, + audio_bytes: bytes = None, + sr: int = None, + sw: int = None, + ch: int = None, + **kwargs, + ): + """Generate a output stream""" # Setup transcriber with audio metadata wave_format = speechsdk.audio.AudioStreamFormat( samples_per_second=sr, - bits_per_sample=sw*8, + bits_per_sample=sw * 8, channels=ch, - wave_stream_format=speechsdk.audio.AudioStreamWaveFormat.PCM + wave_stream_format=speechsdk.audio.AudioStreamWaveFormat.PCM, ) stream = speechsdk.audio.PushAudioInputStream(stream_format=wave_format) audio_config = speechsdk.audio.AudioConfig(stream=stream) transcriber = speechsdk.transcription.ConversationTranscriber( - speech_config=self.speech_config, - audio_config=audio_config, - language=self.language + speech_config=self.speech_config, audio_config=audio_config, language=self.language ) # Setup event callbacks transcription = "" done = asyncio.Event() done.clear() + def transcribed_cb(evt): nonlocal transcription if evt.result.reason == speechsdk.ResultReason.RecognizedSpeech: transcription += str(evt.result) - + def stop_cb(evt: speechsdk.SessionEventArgs): done.set() - + transcriber.transcribed.connect(transcribed_cb) transcriber.session_stopped.connect(stop_cb) transcriber.canceled.connect(stop_cb) @@ -73,4 +79,4 @@ def stop_cb(evt: speechsdk.SessionEventArgs): await done.wait() transcriber.stop_transcribing_async() - yield {"transcription": transcription} \ No newline at end of file + yield {"transcription": transcription} diff --git a/src/utils/operations/stt/base.py b/src/utils/operations/stt/base.py index 5225369..76ef2b9 100644 --- a/src/utils/operations/stt/base.py +++ b/src/utils/operations/stt/base.py @@ -1,4 +1,4 @@ -''' +""" STT Operations (at minimum) require the following fields for input chunks: - prompt: (str) initial words to help with transcription (Optional) - audio_bytes: (bytes) pcm audio data @@ -8,27 +8,29 @@ Adds to chunk: - transcription: (str) transcribed audio -''' +""" -from typing import Dict, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from ..base import Operation + class STTOperation(Operation): def __init__(self, op_id: str): super().__init__("STT", op_id) - + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" assert "audio_bytes" in chunk_in assert isinstance(chunk_in["audio_bytes"], bytes) assert len(chunk_in["audio_bytes"]) > 0 @@ -41,28 +43,34 @@ async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: assert "ch" in chunk_in assert isinstance(chunk_in["ch"], int) assert chunk_in["ch"] > 0 - + return { "prompt": chunk_in.get("prompt", ""), "audio_bytes": chunk_in["audio_bytes"], "sr": chunk_in["sr"], "sw": chunk_in["sw"], - "ch": chunk_in["ch"] + "ch": chunk_in["ch"], } - + ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" raise NotImplementedError - - async def _generate(self, prompt: str = None, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' + + async def _generate( + self, + prompt: str = None, + audio_bytes: bytes = None, + sr: int = None, + sw: int = None, + ch: int = None, + **kwargs, + ) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - yield { - "transcription": "example transcribed text" - } \ No newline at end of file + + yield {"transcription": "example transcribed text"} diff --git a/src/utils/operations/stt/fish.py b/src/utils/operations/stt/fish.py index c8c7fce..7934044 100644 --- a/src/utils/operations/stt/fish.py +++ b/src/utils/operations/stt/fish.py @@ -1,45 +1,55 @@ import io -import wave import os +import wave -from fish_audio_sdk import Session, ASRRequest +from fish_audio_sdk import ASRRequest, Session from .base import STTOperation + class FishSTT(STTOperation): def __init__(self): super().__init__("fish") self.session = None - + async def start(self): await super().start() self.session = Session(os.getenv("FISH_API_KEY")) - + async def unload(self): await super().close() await self.session.close() self.session = None - - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' + """Configure and validate operation-specific configuration""" return - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return {} - async def _generate(self, prompt: str = None, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs): - '''Generate a output stream''' + async def _generate( + self, + prompt: str = None, + audio_bytes: bytes = None, + sr: int = None, + sw: int = None, + ch: int = None, + **kwargs, + ): + """Generate a output stream""" audio_data = io.BytesIO() - with wave.open(audio_data, 'wb') as f: + with wave.open(audio_data, "wb") as f: f.setframerate(sr) f.setsampwidth(sw) f.setnchannels(ch) f.writeframes(audio_bytes) audio_data.seek(0) - response = self.session.asr(ASRRequest(audio=audio_data.read(), language="en", ignore_timestamps=False)) + response = self.session.asr( + ASRRequest(audio=audio_data.read(), language="en", ignore_timestamps=False) + ) result = response.text - yield {"transcription": result} \ No newline at end of file + yield {"transcription": result} diff --git a/src/utils/operations/stt/kobold.py b/src/utils/operations/stt/kobold.py deleted file mode 100644 index 3127c91..0000000 --- a/src/utils/operations/stt/kobold.py +++ /dev/null @@ -1,70 +0,0 @@ -from io import BytesIO -import wave -import requests -import base64 - -from utils.config import Config -from utils.processes import ProcessManager, ProcessType - -from .base import STTOperation - -class KoboldSTT(STTOperation): - KOBOLD_LINK_ID = "kobold_stt" - - def __init__(self): - super().__init__("kobold") - self.uri = None - - self.suppress_non_speech: bool = True - self.langcode: str = "en" - - async def start(self) -> None: - '''General setup needed to start generated''' - await super().start() - await ProcessManager().link(self.KOBOLD_LINK_ID, ProcessType.KOBOLD) - self.uri = "http://127.0.0.1:{}".format(ProcessManager().get_process(ProcessType.KOBOLD).port) - - async def close(self) -> None: - '''Clean up resources before unloading''' - await super().close() - await ProcessManager().unlink(self.KOBOLD_LINK_ID, ProcessType.KOBOLD) - - async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "suppress_non_speech" in config_d: self.suppress_non_speech = bool(config_d['suppress_non_speech']) - if "langcode" in config_d: self.langcode = str(config_d['langcode']) - - assert self.langcode is not None and len(self.langcode) > 0 - - async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "suppress_non_speech": self.suppress_non_speech, - "langcode": self.langcode - } - - async def _generate(self, prompt: str = None, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs): - '''Generate a output stream''' - audio_data = BytesIO() - with wave.open(audio_data, 'wb') as f: - f.setframerate(sr) - f.setsampwidth(sw) - f.setnchannels(ch) - f.writeframes(audio_bytes) - audio_data.seek(0) - - response = requests.post( - "{}/api/extra/transcribe".format(self.uri), - json={ - "prompt": prompt, - "suppress_non_speech": self.suppress_non_speech, - "langcode": self.langcode, - "audio_data": base64.b64encode(audio_data.read()).decode('utf-8') - }, - ) - - if response.status_code == 200: - result = response.json()['text'] - yield {"transcription": result} - else: - raise Exception(f"Failed to get STT result: {response.status_code} {response.reason}") \ No newline at end of file diff --git a/src/utils/operations/stt/openai.py b/src/utils/operations/stt/openai.py index 6626dab..534ee10 100644 --- a/src/utils/operations/stt/openai.py +++ b/src/utils/operations/stt/openai.py @@ -1,63 +1,72 @@ import wave -from openai import AsyncOpenAI from pathlib import Path -from utils.config import Config +from openai import AsyncOpenAI + +from utils.config import config from .base import STTOperation + class OpenAISTT(STTOperation): def __init__(self): super().__init__("openai") self.client = None - + self.base_url: str = "https://api.openai.com/v1/" self.model: str = "gpt-4o" self.language: str = "en" - + async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() self.client = AsyncOpenAI(base_url=self.base_url) - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - self.client.close() + await self.client.close() self.client = None - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "base_url" in config_d: self.base_url = str(config_d['base_url']) - if "model" in config_d: self.model = str(config_d['model']) - if "language" in config_d: self.language = str(config_d['language']) - + """Configure and validate operation-specific configuration""" + if "base_url" in config_d: + self.base_url = str(config_d["base_url"]) + if "model" in config_d: + self.model = str(config_d["model"]) + if "language" in config_d: + self.language = str(config_d["language"]) + assert self.base_url is not None and len(self.base_url) > 0 assert self.model is not None and len(self.model) > 0 assert self.language is not None and len(self.language) > 0 - + async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "base_url": self.base_url, - "model": self.model, - "language": self.language - } - - async def _generate(self, prompt: str = None, audio_bytes: bytes = None, sr: int = None, sw: int = None, ch: int = None, **kwargs): - '''Generate a output stream''' - with wave.open(Config().stt_working_src, 'w') as f: + """Returns values of configurable fields""" + return {"base_url": self.base_url, "model": self.model, "language": self.language} + + async def _generate( + self, + prompt: str = None, + audio_bytes: bytes = None, + sr: int = None, + sw: int = None, + ch: int = None, + **kwargs, + ): + """Generate a output stream""" + with wave.open(config.stt_working_src, "w") as f: f.setframerate(sr) f.setsampwidth(sw) f.setnchannels(ch) f.writeframes(audio_bytes) transcription = await self.client.audio.transcriptions.create( - file=Path(Config().stt_working_src), + file=Path(config.stt_working_src), model=self.model, response_format="text", language=self.language, - prompt=prompt + prompt=prompt, ) - - yield {"transcription": transcription} \ No newline at end of file + + yield {"transcription": transcription} diff --git a/src/utils/operations/stt/whispercpp.py b/src/utils/operations/stt/whispercpp.py new file mode 100644 index 0000000..0d7abd8 --- /dev/null +++ b/src/utils/operations/stt/whispercpp.py @@ -0,0 +1,125 @@ +import wave +from io import BytesIO + +import httpx + +from utils.helpers.subprocess_server import ( + allocate_port, + bin_executable, + start_shell_process, + stop_process, +) + +from .base import STTOperation + + +class WhisperCPPSTT(STTOperation): + def __init__(self): + super().__init__("whispercpp") + self.uri = None + self.model_filepath = None + self.language = "en" + self.temperature = 0.0 + self.response_format = "json" + self._server_process = None + self._port: int | None = None + self._label: str = "whisper.cpp server" + self._http: httpx.AsyncClient | None = None + + async def start(self) -> None: + await super().start() + if self._server_process is not None: + return + + server = bin_executable("whisper-server") + self._port = allocate_port() + cmd = ( + f'"{server}" -m "{self.model_filepath}" --host 127.0.0.1 --port {self._port} ' + f"-l {self.language}" + ) + self._server_process = start_shell_process(cmd, label=self._label) + self.uri = f"http://127.0.0.1:{self._port}" + self._http = httpx.AsyncClient(base_url=self.uri, timeout=httpx.Timeout(600.0)) + + async def close(self) -> None: + if self._http is not None: + await self._http.aclose() + self._http = None + stop_process(self._server_process, label=self._label) + self._server_process = None + self._port = None + self.uri = None + await super().close() + + async def configure(self, config_d): + if "model_filepath" in config_d: + self.model_filepath = str(config_d["model_filepath"]) + if "language" in config_d: + self.language = str(config_d["language"]) + if "temperature" in config_d: + self.temperature = float(config_d["temperature"]) + if "response_format" in config_d: + self.response_format = str(config_d["response_format"]) + + assert self.model_filepath is not None and len(self.model_filepath) > 0 + assert self.language is not None and len(self.language) > 0 + assert self.response_format in {"json", "text", "verbose_json", "srt", "vtt"} + + async def get_configuration(self): + return { + "model_filepath": self.model_filepath, + "language": self.language, + "temperature": self.temperature, + "response_format": self.response_format, + } + + async def _check_health(self) -> None: + if self._http is None: + raise RuntimeError("WhisperCPPSTT server is not running") + try: + health_resp = await self._http.get("/v1/health", timeout=5.0) + health_resp.raise_for_status() + except httpx.HTTPError as e: + raise RuntimeError(f"WhisperCPPSTT server health check failed: {e}") from e + + async def _generate( + self, + prompt: str = None, + audio_bytes: bytes = None, + sr: int = None, + sw: int = None, + ch: int = None, + **kwargs, + ): + await self._check_health() + + audio_data = BytesIO() + with wave.open(audio_data, "wb") as f: + f.setframerate(sr) + f.setsampwidth(sw) + f.setnchannels(ch) + f.writeframes(audio_bytes) + audio_data.seek(0) + + try: + response = await self._http.post( + "/inference", + files={"file": ("audio.wav", audio_data.read(), "audio/wav")}, + data={ + "temperature": str(self.temperature), + "response_format": self.response_format, + }, + ) + response.raise_for_status() + except httpx.HTTPError as e: + raise Exception(f"Failed to get STT result: {e}") from e + + if self.response_format == "text": + text = response.text + else: + body = response.json() + text = body.get("text") or body.get("transcription") or "" + if not text and "segments" in body: + text = " ".join(seg.get("text", "") for seg in body["segments"]).strip() + + yield {"transcription": text} diff --git a/src/utils/operations/t2t/base.py b/src/utils/operations/t2t/base.py index cb514e8..64f9375 100644 --- a/src/utils/operations/t2t/base.py +++ b/src/utils/operations/t2t/base.py @@ -1,39 +1,42 @@ -''' +""" T2T Operations (at minimum) require the following fields for input chunks: - system_prompt: (str) System prompt text - user_prompt: (str) User prompt text Adds to chunk: - content: (str) Generated text -''' +""" -from typing import Dict, List, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from ..base import Operation -from utils.prompter.message import Message + class T2TOperation(Operation): def __init__(self, op_id: str): super().__init__("T2T", op_id) - + # Set by the MCP client at runtime (not user YAML); constrains MCP tool-call output. + self.mcp_json_schema: dict[str, Any] | None = None + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" # assert "system_prompt" in chunk_in # assert isinstance(chunk_in["system_prompt"], str) # assert len(chunk_in["system_prompt"]) > 0 # assert "user_prompt" in chunk_in # assert isinstance(chunk_in["user_prompt"], str) # assert len(chunk_in["user_prompt"]) > 0 - + # return { # "system_prompt": chunk_in["system_prompt"], # "user_prompt": chunk_in["user_prompt"], @@ -49,20 +52,25 @@ async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: "messages": chunk_in["messages"], } - ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" raise NotImplementedError - - async def _generate(self, instruction_prompt: str = None, messages: list = None, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' + + async def _generate( + self, instruction_prompt: str = None, messages: list = None, **kwargs + ) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - yield { - "content": "example generated text" - } \ No newline at end of file + + yield {"content": "example generated text"} + + def set_mcp_json_schema(self, schema: dict[str, Any] | None) -> None: + """Apply a JSON schema for MCP structured output (provided by the MCP client).""" + if schema is not None: + assert isinstance(schema, dict) + self.mcp_json_schema = schema diff --git a/src/utils/operations/t2t/kobold.py b/src/utils/operations/t2t/kobold.py deleted file mode 100644 index 13ca1e3..0000000 --- a/src/utils/operations/t2t/kobold.py +++ /dev/null @@ -1,116 +0,0 @@ -import requests - -from utils.processes import ProcessManager, ProcessType - -from .base import T2TOperation -from utils.prompter.message import ChatMessage -from utils.prompter import Prompter - -class KoboldT2T(T2TOperation): - KOBOLD_LINK_ID = "kobold_t2t" - - def __init__(self): - super().__init__("kobold") - self.uri = None - - self.max_context_length: int = 2048 - self.max_length: int = 100 - self.rep_pen: float = 1.1 - self.rep_pen_range: int = 256 - self.rep_pen_slope: int = 1 - self.temperature: float = 0.5 - self.tfs: int = 1 - self.top_a: int = 0 - self.top_k: int = 100 - self.top_p: float = 0.9 - self.typical: int = 1 - - async def start(self) -> None: - '''General setup needed to start generated''' - await super().start() - await ProcessManager().link(self.KOBOLD_LINK_ID, ProcessType.KOBOLD) - self.uri = "http://127.0.0.1:{}".format(ProcessManager().get_process(ProcessType.KOBOLD).port) - - async def close(self) -> None: - '''Clean up resources before unloading''' - await super().close() - await ProcessManager().unlink(self.KOBOLD_LINK_ID, ProcessType.KOBOLD) - - async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "max_context_length" in config_d: self.max_context_length = config_d["max_context_length"] - if "max_length" in config_d: self.max_length = config_d["max_length"] - if "rep_pen" in config_d: self.rep_pen = config_d["rep_pen"] - if "rep_pen_range" in config_d: self.rep_pen_range = config_d["rep_pen_range"] - if "rep_pen_slope" in config_d: self.rep_pen_slope = config_d["rep_pen_slope"] - if "temperature " in config_d: self.temperature = config_d["temperature "] - if "tfs" in config_d: self.tfs = config_d["tfs"] - if "top_a" in config_d: self.top_a = config_d["top_a"] - if "top_k" in config_d: self.top_k = config_d["top_k"] - if "top_p" in config_d: self.top_p = config_d["top_p"] - if "typical" in config_d: self.typical = config_d["typical"] - - assert self.max_context_length > 0 - assert self.max_length > 0 - assert self.rep_pen > 0 # TODO check the limits - assert self.rep_pen_range > 0 - assert self.rep_pen_slope > 0 - assert self.temperature > 0 - assert self.tfs >= 0 - assert self.top_a >= 0 - assert self.top_k >= 0 - assert 0 < self.top_p <= 1 - assert self.typical >= 0 - - async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "max_context_length": self.max_context_length, - "max_length": self.max_length, - "rep_pen": self.rep_pen, - "rep_pen_range": self.rep_pen_range, - "rep_pen_range": self.rep_pen_range, - "rep_pen_slope": self.rep_pen_slope, - "temperature": self.temperature, - "tfs": self.tfs, - "top_a": self.top_a, - "top_k": self.top_k, - "top_p": self.top_p, - "typical": self.typical, - } - - async def _generate(self, instruction_prompt: str = None, messages: list = None, **kwargs): - history = [{ "role": "system", "content": instruction_prompt }] - for msg in messages: - next_hist = None - if isinstance(msg, ChatMessage) and msg.user == Prompter().character_name: - next_hist = { "role": "assistant", "content": msg.message } - else: - next_hist = { "role": "user", "content": msg.to_line() } - history.append(next_hist) - - response = requests.post( - "{}/v1/chat/completions".format(self.uri), - json={ - "model": "kcpp", - "messages": history, - "max_context_length": self.max_context_length, - "max_length": self.max_length, - "quiet": True, - "rep_pen": self.rep_pen, - "rep_pen_range": self.rep_pen_range, - "rep_pen_slope": self.rep_pen_slope, - "temperature": self.temperature, - "tfs": self.tfs, - "top_a": self.top_a, - "top_k": self.top_k, - "top_p": self.top_p, - "typical": self.typical - }, - ) - - if response.status_code == 200: - result = response.json()['choices'][0]['message']['content'] - yield {"content": result} - else: - raise Exception(f"Failed to get T2T result: {response.status_code} {response.reason}") \ No newline at end of file diff --git a/src/utils/operations/t2t/llamacpp.py b/src/utils/operations/t2t/llamacpp.py new file mode 100644 index 0000000..a3f69b4 --- /dev/null +++ b/src/utils/operations/t2t/llamacpp.py @@ -0,0 +1,245 @@ +import json + +import httpx + +from utils.helpers.subprocess_server import ( + allocate_port, + bin_executable, + start_shell_process, + stop_process, +) +from utils.prompter.message import ChatMessage + +from .base import T2TOperation + + +class LlamaCPPT2T(T2TOperation): + def __init__(self): + super().__init__("llamacpp") + self.uri = None + self.model_filepath = None + self.ctx_size = 0 # 0 = use model default (--ctx-size 0) + self.n_predict = 256 + self.temperature = 0.8 + self.top_p = 0.95 + self.top_k = 40 + self.min_p = 0.05 + self.typical_p = 1.0 + self.repeat_penalty = 1.1 + self.repeat_last_n = 64 + self.presence_penalty = 0.0 + self.frequency_penalty = 0.0 + self.dry_multiplier = 0.0 + self.dry_base = 1.75 + self.dry_allowed_length = 2 + self.dry_penalty_last_n = -1 + self.dry_sequence_breakers = ["\n", ":", '"', "*"] + self.samplers = ["dry", "top_k", "typ_p", "top_p", "min_p", "temperature"] + self._server_process = None + self._port: int | None = None + self._label: str = "llama.cpp server" + self._http: httpx.AsyncClient | None = None + + async def _start_server(self) -> None: + server = bin_executable("llama-server") + self._port = allocate_port() + cmd = f'"{server}" -m "{self.model_filepath}" --host 127.0.0.1 --port {self._port}' + if self.ctx_size > 0: + cmd += f" --ctx-size {self.ctx_size}" + self._server_process = start_shell_process(cmd, label=self._label) + self.uri = f"http://127.0.0.1:{self._port}" + self._http = httpx.AsyncClient(base_url=self.uri, timeout=httpx.Timeout(600.0)) + + async def _stop_server(self) -> None: + if self._http is not None: + await self._http.aclose() + self._http = None + stop_process(self._server_process, label=self._label) + self._server_process = None + self._port = None + self.uri = None + + async def _restart_server(self) -> None: + await self._stop_server() + await self._start_server() + + async def start(self) -> None: + await super().start() + if self._server_process is not None: + return + await self._start_server() + + async def close(self) -> None: + await self._stop_server() + await super().close() + + async def configure(self, config_d): + restart_for_ctx = False + if "model_filepath" in config_d: + self.model_filepath = str(config_d["model_filepath"]) + if "ctx_size" in config_d: + new_ctx_size = int(config_d["ctx_size"]) + if new_ctx_size != self.ctx_size: + self.ctx_size = new_ctx_size + restart_for_ctx = self._server_process is not None + if "n_predict" in config_d: + self.n_predict = int(config_d["n_predict"]) + if "max_tokens" in config_d: + self.n_predict = int(config_d["max_tokens"]) + if "max_length" in config_d: + self.n_predict = int(config_d["max_length"]) + if "temperature" in config_d: + self.temperature = float(config_d["temperature"]) + if "top_p" in config_d: + self.top_p = float(config_d["top_p"]) + if "top_k" in config_d: + self.top_k = int(config_d["top_k"]) + if "min_p" in config_d: + self.min_p = float(config_d["min_p"]) + if "typical_p" in config_d: + self.typical_p = float(config_d["typical_p"]) + if "repeat_penalty" in config_d: + self.repeat_penalty = float(config_d["repeat_penalty"]) + if "repeat_last_n" in config_d: + self.repeat_last_n = int(config_d["repeat_last_n"]) + if "presence_penalty" in config_d: + self.presence_penalty = float(config_d["presence_penalty"]) + if "frequency_penalty" in config_d: + self.frequency_penalty = float(config_d["frequency_penalty"]) + if "dry_multiplier" in config_d: + self.dry_multiplier = float(config_d["dry_multiplier"]) + if "dry_base" in config_d: + self.dry_base = float(config_d["dry_base"]) + if "dry_allowed_length" in config_d: + self.dry_allowed_length = int(config_d["dry_allowed_length"]) + if "dry_penalty_last_n" in config_d: + self.dry_penalty_last_n = int(config_d["dry_penalty_last_n"]) + if "dry_sequence_breakers" in config_d: + self.dry_sequence_breakers = list(config_d["dry_sequence_breakers"]) + if "samplers" in config_d: + self.samplers = list(config_d["samplers"]) + + assert self.model_filepath is not None and len(self.model_filepath) > 0 + assert self.ctx_size >= 0 + assert self.n_predict > 0 + assert self.temperature >= 0 + assert self.top_k >= 0 + assert 0 <= self.top_p <= 1 + assert 0 <= self.min_p <= 1 + assert self.typical_p > 0 + assert self.dry_multiplier >= 0 + assert self.dry_base > 0 + assert self.dry_allowed_length >= 0 + assert len(self.samplers) > 0 + + if restart_for_ctx: + await self._restart_server() + + async def get_configuration(self): + return { + "model_filepath": self.model_filepath, + "ctx_size": self.ctx_size, + "n_predict": self.n_predict, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "typical_p": self.typical_p, + "repeat_penalty": self.repeat_penalty, + "repeat_last_n": self.repeat_last_n, + "presence_penalty": self.presence_penalty, + "frequency_penalty": self.frequency_penalty, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, + "samplers": self.samplers, + } + + async def _check_health(self) -> None: + if self._http is None: + raise RuntimeError("LlamaCPPT2T server is not running") + try: + health_resp = await self._http.get("/health", timeout=5.0) + health_resp.raise_for_status() + except httpx.HTTPError as e: + raise RuntimeError(f"LlamaCPPT2T server health check failed: {e}") from e + + def _chat_messages(self, instruction_prompt: str, messages: list) -> list[dict[str, str]]: + history = [{"role": "system", "content": instruction_prompt}] + for msg in messages: + if isinstance(msg, ChatMessage) and msg.user == self.prompter.character_name: + history.append({"role": "assistant", "content": msg.message}) + else: + history.append({"role": "user", "content": msg.to_line()}) + return history + + async def _apply_chat_template(self, messages: list[dict[str, str]]) -> str: + response = await self._http.post("/apply-template", json={"messages": messages}) + response.raise_for_status() + prompt = response.json().get("prompt") + if not prompt: + raise RuntimeError("LlamaCPPT2T /apply-template returned no prompt") + return prompt + + def _completion_payload(self, prompt: str) -> dict: + payload = { + "prompt": prompt, + "stream": True, + "n_predict": self.n_predict, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "min_p": self.min_p, + "typical_p": self.typical_p, + "repeat_penalty": self.repeat_penalty, + "repeat_last_n": self.repeat_last_n, + "presence_penalty": self.presence_penalty, + "frequency_penalty": self.frequency_penalty, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, + "samplers": self.samplers, + } + if self.mcp_json_schema is not None: + payload["json_schema"] = self.mcp_json_schema + return payload + + async def _iter_completion_stream(self, prompt: str): + payload = self._completion_payload(prompt) + async with self._http.stream("POST", "/completion", json=payload) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + if not line.startswith("data:"): + continue + data = line.removeprefix("data:").strip() + if not data or data == "[DONE]": + if data == "[DONE]": + break + continue + try: + event = json.loads(data) + except json.JSONDecodeError: + continue + if event.get("stop"): + break + content = event.get("content") + if content: + yield content + + async def _generate(self, instruction_prompt: str = None, messages: list = None, **kwargs): + if self.prompter is None: + raise RuntimeError("LlamaCPPT2T missing runtime dependency: prompter") + + await self._check_health() + + chat_messages = self._chat_messages(instruction_prompt, messages) + try: + prompt = await self._apply_chat_template(chat_messages) + async for content_chunk in self._iter_completion_stream(prompt): + yield {"content": content_chunk} + except httpx.HTTPError as e: + raise Exception(f"Failed to get T2T result: {e}") from e diff --git a/src/utils/operations/t2t/openai.py b/src/utils/operations/t2t/openai.py index bee71a7..3bf19ce 100644 --- a/src/utils/operations/t2t/openai.py +++ b/src/utils/operations/t2t/openai.py @@ -1,49 +1,57 @@ from openai import AsyncOpenAI +from typing import Any -from .base import T2TOperation from utils.prompter.message import ChatMessage -from utils.prompter import Prompter + +from .base import T2TOperation + class OpenAIT2T(T2TOperation): def __init__(self): super().__init__("openai") self.client = None - + self.base_url = "https://api.openai.com/v1/" self.model = "gpt-4o" self.temperature = 1 self.top_p = 0.9 self.presence_penalty = 0 self.frequency_penalty = 0 - + async def start(self): await super().start() self.client = AsyncOpenAI(base_url=self.base_url) - + async def close(self): await super().close() await self.client.close() self.client = None - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "base_url" in config_d: self.base_url = str(config_d['base_url']) - if "model" in config_d: self.model = str(config_d['model']) - - if "temperature" in config_d: self.temperature = float(config_d['temperature']) - if "top_p" in config_d: self.top_p = float(config_d['top_p']) - if "presence_penalty" in config_d: self.presence_penalty = float(config_d['presence_penalty']) - if "frequency_penalty" in config_d: self.frequency_penalty = float(config_d['frequency_penalty']) - + """Configure and validate operation-specific configuration""" + if "base_url" in config_d: + self.base_url = str(config_d["base_url"]) + if "model" in config_d: + self.model = str(config_d["model"]) + + if "temperature" in config_d: + self.temperature = float(config_d["temperature"]) + if "top_p" in config_d: + self.top_p = float(config_d["top_p"]) + if "presence_penalty" in config_d: + self.presence_penalty = float(config_d["presence_penalty"]) + if "frequency_penalty" in config_d: + self.frequency_penalty = float(config_d["frequency_penalty"]) + assert self.base_url is not None and len(self.base_url) > 0 assert self.model is not None and len(self.model) > 0 assert self.temperature >= 0 and self.temperature <= 2 assert self.top_p >= 0 and self.top_p <= 1 assert self.presence_penalty >= 0 and self.presence_penalty <= 1 assert self.frequency_penalty >= 0 and self.frequency_penalty <= 1 - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return { "base_url": self.base_url, "model": self.model, @@ -53,25 +61,45 @@ async def get_configuration(self): "frequency_penalty": self.frequency_penalty, } + def _openai_mcp_response_format(self) -> dict[str, Any] | None: + """OpenAI `response_format` for ``mcp_json_schema``, or None if unset.""" + if self.mcp_json_schema is None: + return None + return { + "type": "json_schema", + "json_schema": { + "name": "mcp_output", + "strict": True, + "schema": self.mcp_json_schema, + }, + } + async def _generate(self, instruction_prompt: str = None, messages: list = None, **kwargs): - history = [{ "role": "system", "content": instruction_prompt }] + if self.prompter is None: + raise RuntimeError("OpenAIT2T missing runtime dependency: prompter") + history = [{"role": "system", "content": instruction_prompt}] for msg in messages: next_hist = None - if isinstance(msg, ChatMessage) and msg.user == Prompter().character_name: - next_hist = { "role": "assistant", "content": msg.message } + if isinstance(msg, ChatMessage) and msg.user == self.prompter.character_name: + next_hist = {"role": "assistant", "content": msg.message} else: - next_hist = { "role": "user", "content": msg.to_line() } + next_hist = {"role": "user", "content": msg.to_line()} history.append(next_hist) - stream = await self.client.chat.completions.create( - messages=history, - model=self.model, - stream=True, - temperature=self.temperature, - top_p=self.top_p, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty - ) + create_kwargs = { + "messages": history, + "model": self.model, + "stream": True, + "temperature": self.temperature, + "top_p": self.top_p, + "presence_penalty": self.presence_penalty, + "frequency_penalty": self.frequency_penalty, + } + response_format = self.openai_mcp_response_format() + if response_format is not None: + create_kwargs["response_format"] = response_format + + stream = await self.client.chat.completions.create(**create_kwargs) full_response = "" async for chunk in stream: diff --git a/src/utils/operations/tts/azure.py b/src/utils/operations/tts/azure.py index cd3eec2..f91e37e 100644 --- a/src/utils/operations/tts/azure.py +++ b/src/utils/operations/tts/azure.py @@ -1,63 +1,64 @@ import os import wave from io import BytesIO -import azure.cognitiveservices.speech as speechsdk -from utils.config import Config +import azure.cognitiveservices.speech as speechsdk from .base import TTSOperation + class AzureTTS(TTSOperation): def __init__(self): super().__init__("azure") self.client = None - + self.voice: str = "en-US-AshleyNeural" - + async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + self.speech_config = speechsdk.SpeechConfig( - region=os.environ.get('AZURE_REGION'), - subscription=os.getenv("AZURE_API_KEY") + region=os.environ.get("AZURE_REGION"), subscription=os.getenv("AZURE_API_KEY") ) self.speech_config.speech_synthesis_voice_name = self.voice - self.speech_config.set_speech_synthesis_output_format(speechsdk.SpeechSynthesisOutputFormat.Riff48Khz16BitMonoPcm) + self.speech_config.set_speech_synthesis_output_format( + speechsdk.SpeechSynthesisOutputFormat.Riff48Khz16BitMonoPcm + ) # set timeout value to bigger ones to avoid sdk cancel the request when GPT latency too high - self.speech_config.set_property(speechsdk.PropertyId.SpeechSynthesis_FrameTimeoutInterval, "100000000") - self.speech_config.set_property(speechsdk.PropertyId.SpeechSynthesis_RtfTimeoutThreshold, "10") - - self.speech_synthesizer = speechsdk.SpeechSynthesizer(speech_config=self.speech_config, audio_config=None) - + self.speech_config.set_property( + speechsdk.PropertyId.SpeechSynthesis_FrameTimeoutInterval, "100000000" + ) + self.speech_config.set_property( + speechsdk.PropertyId.SpeechSynthesis_RtfTimeoutThreshold, "10" + ) + + self.speech_synthesizer = speechsdk.SpeechSynthesizer( + speech_config=self.speech_config, audio_config=None + ) + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "voice" in config_d: self.voice = str(config_d['voice']) - + """Configure and validate operation-specific configuration""" + if "voice" in config_d: + self.voice = str(config_d["voice"]) + assert self.voice is not None and len(self.voice) > 0 - + async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "voice": self.voice - } + """Returns values of configurable fields""" + return {"voice": self.voice} async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" # create request with TextStream input type result = self.speech_synthesizer.speak_text_async(content).get() - + output_b = BytesIO(result.audio_data) - + with wave.open(output_b, "r") as f: sr = f.getframerate() sw = f.getsampwidth() ch = f.getnchannels() ab = f.readframes(f.getnframes()) - - yield { - "audio_bytes": ab, - "sr": sr, - "sw": sw, - "ch": ch - } \ No newline at end of file + + yield {"audio_bytes": ab, "sr": sr, "sw": sw, "ch": ch} diff --git a/src/utils/operations/tts/base.py b/src/utils/operations/tts/base.py index 6b7ca68..582956d 100644 --- a/src/utils/operations/tts/base.py +++ b/src/utils/operations/tts/base.py @@ -1,4 +1,4 @@ -''' +""" TTS Operations (at minimum) require the following fields for input chunks: - content: (str) Text to generate speech for @@ -7,51 +7,48 @@ - sr: (int) sample rate - sw: (int) sample width - ch: (int) audio channels -''' +""" -from typing import Dict, Any, AsyncGenerator +from collections.abc import AsyncGenerator +from typing import Any from ..base import Operation + class TTSOperation(Operation): def __init__(self, op_id: str): super().__init__("TTS", op_id) - + ## TO BE OVERRIDEN #### async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() - - async def _parse_chunk(self, chunk_in: Dict[str, Any]) -> Dict[str, Any]: - '''Extract information from input for use in _generate''' + + async def _parse_chunk(self, chunk_in: dict[str, Any]) -> dict[str, Any]: + """Extract information from input for use in _generate""" assert "content" in chunk_in assert isinstance(chunk_in["content"], str) assert len(chunk_in["content"]) > 0 - - return { - "content": chunk_in["content"] - } - + + return {"content": chunk_in["content"]} + ## TO BE IMPLEMENTED #### - async def configure(self, config_d: Dict[str, Any]): - '''Configure and validate operation-specific configuration''' + async def configure(self, config_d: dict[str, Any]): + """Configure and validate operation-specific configuration""" raise NotImplementedError - - async def get_configuration(self) -> Dict[str, Any]: - '''Returns values of configurable fields''' + + async def get_configuration(self) -> dict[str, Any]: + """Returns values of configurable fields""" raise NotImplementedError - - async def _generate(self, content: str = None, **kwargs) -> AsyncGenerator[Dict[str, Any], None]: - '''Generate a output stream''' + + async def _generate( + self, content: str = None, **kwargs + ) -> AsyncGenerator[dict[str, Any], None]: + """Generate a output stream""" raise NotImplementedError - - yield { - "audio_bytes": b'', - "sr": 123, - "sw": 123, - "ch": 123 - } \ No newline at end of file + + yield {"audio_bytes": b"", "sr": 123, "sw": 123, "ch": 123} diff --git a/src/utils/operations/tts/fish.py b/src/utils/operations/tts/fish.py index 33ebde5..dfc6597 100644 --- a/src/utils/operations/tts/fish.py +++ b/src/utils/operations/tts/fish.py @@ -1,44 +1,48 @@ -from fish_audio_sdk import AsyncWebSocketSession, TTSRequest import os -from utils.config import Config +from fish_audio_sdk import AsyncWebSocketSession, TTSRequest from .base import TTSOperation + class FishTTS(TTSOperation): def __init__(self): super().__init__("fish") self.session = None - + self.model_id = "c9198512a4164a18b11a3bf96e5c668f" self.backend = "speech-1.6" self.normalize = False self.latency = "normal" - + async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() self.session = AsyncWebSocketSession(os.getenv("FISH_API_KEY")) - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() await self.session.close() self.session = None - + async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "model_id" in config_d: self.model_id = str(config_d["model_id"]) - if "backend" in config_d: self.backend = str(config_d["backend"]) - if "normalize" in config_d: self.normalize = bool(config_d["normalize"]) - if "latency" in config_d: self.latency = str(config_d["latency"]) - + """Configure and validate operation-specific configuration""" + if "model_id" in config_d: + self.model_id = str(config_d["model_id"]) + if "backend" in config_d: + self.backend = str(config_d["backend"]) + if "normalize" in config_d: + self.normalize = bool(config_d["normalize"]) + if "latency" in config_d: + self.latency = str(config_d["latency"]) + assert self.model_id is not None and len(self.model_id) > 0 assert self.backend is not None and len(self.backend) > 0 - assert self.latency in ['normal', 'balanced'] - + assert self.latency in ["normal", "balanced"] + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return { "model_id": self.model_id, "backend": self.backend, @@ -47,23 +51,19 @@ async def get_configuration(self): } async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" tts_request = TTSRequest( text=content, format="pcm", normalize=self.normalize, latency=self.latency, - reference_id=self.model_id + reference_id=self.model_id, ) - b = b'' - async for chunk in self.session.tts( - tts_request, - self._stream(), - backend=self.backend - ): + b = b"" + async for chunk in self.session.tts(tts_request, self._stream(), backend=self.backend): b += chunk - + yield {"audio_bytes": b, "sr": 44100, "sw": 2, "ch": 1} - + async def _stream(self): - yield "" \ No newline at end of file + yield "" diff --git a/src/utils/operations/tts/kobold.py b/src/utils/operations/tts/kobold.py deleted file mode 100644 index 2ac5dca..0000000 --- a/src/utils/operations/tts/kobold.py +++ /dev/null @@ -1,58 +0,0 @@ -import requests -from io import BytesIO -import wave - -from utils.config import Config -from utils.processes import ProcessManager, ProcessType - -from .base import TTSOperation - -class KoboldTTS(TTSOperation): - KOBOLD_LINK_ID = "kobold_tts" - - def __init__(self): - super().__init__("kobold") - self.uri = None - - self.voice = "kobo" - - async def start(self) -> None: - '''General setup needed to start generated''' - await super().start() - await ProcessManager().link(self.KOBOLD_LINK_ID, ProcessType.KOBOLD) - self.uri = "http://127.0.0.1:{}".format(ProcessManager().get_process(ProcessType.KOBOLD).port) - - async def close(self) -> None: - '''Clean up resources before unloading''' - await super().close() - await ProcessManager().unlink(self.KOBOLD_LINK_ID, ProcessType.KOBOLD) - - async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "voice" in config_d: self.voice = str(config_d["voice"]) - - assert self.voice is not None and len(self.voice) > 0 - - async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "voice": self.voice - } - - async def _generate(self, content: str = None, **kwargs): - response = requests.post( - "{}/api/extra/tts".format(self.uri), - json={ - "input": content, - "voice": self.voice, - "speaker_json": "" - }, - ) - - if response.status_code == 200: - result = response.content - audio = BytesIO(result) - with wave.open(audio, 'r') as f: - yield {"audio_bytes": f.readframes(f.getnframes()), "sr": f.getframerate(), "sw": f.getsampwidth(), "ch": f.getnchannels()} - else: - raise Exception(f"Failed to get T2T result: {response.status_code} {response.reason}") \ No newline at end of file diff --git a/src/utils/operations/tts/melo.py b/src/utils/operations/tts/melo.py index be5c8c8..49892f3 100644 --- a/src/utils/operations/tts/melo.py +++ b/src/utils/operations/tts/melo.py @@ -1,67 +1,74 @@ import wave from io import BytesIO -from melo.api import TTS -import logging -import numpy as np -import torch -import soundfile -from utils.config import Config +import soundfile +import torch +from melo.api import TTS from .base import TTSOperation + class MeloTTS(TTSOperation): SAMPLE_RATE = 44100 SAMPLE_WIDTH = 2 CHANNELS = 1 - + def __init__(self): super().__init__("melo") self.model = None - self.speaker_ids = dict() - + self.speaker_ids = {} + self.config_filepath = None self.model_filepath = None self.speaker_id = None self.device = "cpu" self.language = "EN" - + self.sdp_ratio = 0.2 self.noise_scale = 0.6 self.noise_scale_w = 0.8 self.speed = 1.0 - + async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() self.model = TTS( language=self.language, device=self.device, config_path=self.config_filepath, - ckpt_path=self.model_filepath + ckpt_path=self.model_filepath, ) self.speaker_ids = self.model.hps.data.spk2id - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() del self.model self.model = None - self.speaker_ids = dict() + self.speaker_ids = {} async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if config_d.get("config_filepath", None): self.config_filepath = str(config_d['config_filepath']) - if config_d.get("model_filepath", None): self.model_filepath = str(config_d['model_filepath']) - if "speaker_id" in config_d: self.speaker_id = str(config_d['speaker_id']) - if "device" in config_d: self.device = str(config_d['device']) - if "language" in config_d: self.language = str(config_d['language']) - - if "sdp_ratio" in config_d: self.sdp_ratio = float(config_d["sdp_ratio"]) - if "noise_scale" in config_d: self.noise_scale = float(config_d["noise_scale"]) - if "noise_scale_w" in config_d: self.noise_scale_w = float(config_d["noise_scale_w"]) - if "speed" in config_d: self.speed = float(config_d["speed"]) - + """Configure and validate operation-specific configuration""" + if config_d.get("config_filepath", None): + self.config_filepath = str(config_d["config_filepath"]) + if config_d.get("model_filepath", None): + self.model_filepath = str(config_d["model_filepath"]) + if "speaker_id" in config_d: + self.speaker_id = str(config_d["speaker_id"]) + if "device" in config_d: + self.device = str(config_d["device"]) + if "language" in config_d: + self.language = str(config_d["language"]) + + if "sdp_ratio" in config_d: + self.sdp_ratio = float(config_d["sdp_ratio"]) + if "noise_scale" in config_d: + self.noise_scale = float(config_d["noise_scale"]) + if "noise_scale_w" in config_d: + self.noise_scale_w = float(config_d["noise_scale_w"]) + if "speed" in config_d: + self.speed = float(config_d["speed"]) + assert self.speaker_id is not None and len(self.speaker_id) > 0 assert self.device is not None and len(self.device) > 0 assert self.language is not None and len(self.language) > 0 @@ -69,19 +76,19 @@ async def configure(self, config_d): assert self.noise_scale < 1.25 and self.noise_scale >= 0 assert self.noise_scale_w < 1.25 and self.noise_scale_w >= 0 assert self.speed > 0 - + async def get_configuration(self): - '''Returns values of configurable fields''' + """Returns values of configurable fields""" return { "config_filepath": self.config_filepath, "model_filepath": self.model_filepath, "speaker_id": self.speaker_id, "device": self.device, - "language": self.language + "language": self.language, } async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" ab_np = self.model.tts_to_file( content, self.speaker_ids[self.speaker_id], @@ -90,16 +97,16 @@ async def _generate(self, content: str = None, **kwargs): noise_scale=self.noise_scale, noise_scale_w=self.noise_scale_w, speed=self.speed, - quiet=True + quiet=True, ) ab = torch.from_numpy(ab_np).float() audio_buffer = BytesIO() - soundfile.write(audio_buffer, ab, self.SAMPLE_RATE, format='WAV', subtype='PCM_16') + soundfile.write(audio_buffer, ab, self.SAMPLE_RATE, format="WAV", subtype="PCM_16") audio_buffer.seek(0) - with wave.open(audio_buffer, 'r') as f: + with wave.open(audio_buffer, "r") as f: yield { "audio_bytes": f.readframes(f.getnframes()), "sr": self.SAMPLE_RATE, "sw": self.SAMPLE_WIDTH, - "ch": self.CHANNELS - } \ No newline at end of file + "ch": self.CHANNELS, + } diff --git a/src/utils/operations/tts/openai.py b/src/utils/operations/tts/openai.py index 224ee66..349f399 100644 --- a/src/utils/operations/tts/openai.py +++ b/src/utils/operations/tts/openai.py @@ -1,51 +1,50 @@ import wave from io import BytesIO -from openai import AsyncOpenAI -from utils.config import Config +from openai import AsyncOpenAI from .base import TTSOperation + class OpenAITTS(TTSOperation): def __init__(self): super().__init__("openai") self.client = None - + self.base_url = "https://api.openai.com/v1/" self.voice = "nova" self.model = "tts-1" - + async def start(self) -> None: - '''General setup needed to start generated''' + """General setup needed to start generated""" await super().start() self.client = AsyncOpenAI(base_url=self.base_url) - + async def close(self) -> None: - '''Clean up resources before unloading''' + """Clean up resources before unloading""" await super().close() await self.client.close() self.client = None async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "base_url" in config_d: self.base_url = str(config_d["base_url"]) - if "voice" in config_d: self.voice = str(config_d["voice"]) - if "model" in config_d: self.model = str(config_d["model"]) - + """Configure and validate operation-specific configuration""" + if "base_url" in config_d: + self.base_url = str(config_d["base_url"]) + if "voice" in config_d: + self.voice = str(config_d["voice"]) + if "model" in config_d: + self.model = str(config_d["model"]) + assert self.base_url is not None and len(self.base_url) > 0 assert self.voice is not None and len(self.voice) > 0 assert self.model is not None and len(self.model) > 0 - + async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "base_url": self.base_url, - "voice": self.voice, - "model": self.model - } + """Returns values of configurable fields""" + return {"base_url": self.base_url, "voice": self.voice, "model": self.model} async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" async with self.client.audio.speech.with_streaming_response.create( model=self.model, voice=self.voice, @@ -53,16 +52,11 @@ async def _generate(self, content: str = None, **kwargs): response_format="wav", ) as response: output_b = BytesIO(await response.read()) - + with wave.open(output_b, "r") as f: sr = f.getframerate() sw = f.getsampwidth() ch = f.getnchannels() ab = f.readframes(f.getnframes()) - - yield { - "audio_bytes": ab, - "sr": sr, - "sw": sw, - "ch": ch - } \ No newline at end of file + + yield {"audio_bytes": ab, "sr": sr, "sw": sw, "ch": ch} diff --git a/src/utils/operations/tts/pytts.py b/src/utils/operations/tts/pytts.py index 553dcdd..d101182 100644 --- a/src/utils/operations/tts/pytts.py +++ b/src/utils/operations/tts/pytts.py @@ -1,38 +1,42 @@ -''' +""" Class implementing TTS generation using old-school speech synthesis. -This version runs entirely offline.This may require espeak for Linux. -Voices available will differ between OS, and available voices for your +This version runs entirely offline.This may require espeak for Linux. +Voices available will differ between OS, and available voices for your OS can be found using get_available_voices -''' +""" import logging -import pyttsx3 -import wave import os +import wave +import pyttsx3 + +from utils.config import config from utils.helpers.path import portable_path -from utils.config import Config from .base import TTSOperation + class PyttsTTS(TTSOperation): def __init__(self): super().__init__("pytts") self.engine = None - + self.voice: str = None - self.gender: str = 'female' - self.working_file: str = portable_path(os.path.join(Config().WORKING_DIR,'ttsg-synth-out.wav')) + self.gender: str = "female" + self.working_file: str = portable_path( + os.path.join(config.WORKING_DIR, "ttsg-synth-out.wav") + ) async def start(self): await super().start() - + self.engine = pyttsx3.init() - voices = self.engine.getProperty('voices') - logging.info("Operation {}: Available voices are: {}".format(self.op_id, list(map(lambda x: x.id, voices)))) - - self.engine.setProperty('voice', self.voice) - self.engine.setProperty('gender', self.gender) + voices = self.engine.getProperty("voices") + logging.info(f"Operation {self.op_id}: Available voices are: {[x.id for x in voices]}") + + self.engine.setProperty("voice", self.voice) + self.engine.setProperty("gender", self.gender) async def close(self): await super().close() @@ -40,31 +44,30 @@ async def close(self): self.engine = None async def configure(self, config_d): - '''Configure and validate operation-specific configuration''' - if "voice" in config_d: self.voice = str(config_d['voice']) - if "gender" in config_d: self.gender = str(config_d['gender']) - if "working_file" in config_d: self.working_file = str(config_d['working_file']) - + """Configure and validate operation-specific configuration""" + if "voice" in config_d: + self.voice = str(config_d["voice"]) + if "gender" in config_d: + self.gender = str(config_d["gender"]) + if "working_file" in config_d: + self.working_file = str(config_d["working_file"]) + assert self.voice is not None and len(self.voice) > 0 assert self.working_file is not None and len(self.working_file) > 0 - + async def get_configuration(self): - '''Returns values of configurable fields''' - return { - "voice": self.voice, - "gender": self.gender, - "working_file": self.working_file - } + """Returns values of configurable fields""" + return {"voice": self.voice, "gender": self.gender, "working_file": self.working_file} async def _generate(self, content: str = None, **kwargs): - '''Generate a output stream''' + """Generate a output stream""" self.engine.save_to_file(content, self.working_file) self.engine.runAndWait() - - with wave.open(self.working_file, 'r') as f: + + with wave.open(self.working_file, "r") as f: yield { "audio_bytes": f.readframes(f.getnframes()), "sr": f.getframerate(), "sw": f.getsampwidth(), - "ch": f.getnchannels() - } \ No newline at end of file + "ch": f.getnchannels(), + } diff --git a/src/utils/processes/__init__.py b/src/utils/processes/__init__.py deleted file mode 100644 index e05866c..0000000 --- a/src/utils/processes/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .error import * -from .manager import ProcessManager, ProcessType \ No newline at end of file diff --git a/src/utils/processes/base.py b/src/utils/processes/base.py deleted file mode 100644 index 083f4bb..0000000 --- a/src/utils/processes/base.py +++ /dev/null @@ -1,59 +0,0 @@ -import subprocess -import psutil -import logging - -from .error import DuplicateLink, MissingLink - -class BaseProcess(): # Be sure to make it a singleton (metaclass=Singleton) - id: str = None - process: subprocess.Popen = None - port: int = None - links: set = set() - - reload_signal: bool = False - unload_signal: bool = False - - def __init__(self, id): - self.id = id - - async def reload(self): - # This needs to be implemented - self.reload_signal = False - - async def unload(self): - self.unload_signal = False - - if len(self.links): - logging.warning(f"Unloading process that still has the {len(self.links)} links") - logging.warning(f"Links: {self.links}") - - if self.process: - ps_process = psutil.Process(self.process.pid) - for child in ps_process.children(recursive=True): - child.kill() - ps_process.kill() - self.process = None - self.port = None - logging.info(f"Unloaded process {self.id}") - else: - logging.warning(f"Attempted to unload process {self.id} when it is already unloaded") - - async def link(self, link_id): - logging.debug("Adding link {} to process {}".format(link_id, self.id)) - if link_id in self.links: - raise DuplicateLink(link_id, self.id) - - if self.process is None: - await self.reload() - - self.links.add(link_id) # Add to links after loading process to ensure link established - - async def unlink(self, link_id): - logging.debug("Removing link {} from process {}".format(link_id, self.id)) - if link_id not in self.links: - raise MissingLink(link_id, self.id) - self.links.remove(link_id) - - if len(self.links): - logging.info(f"No more links to process {self.id}. Unloading...") - await self.unload() diff --git a/src/utils/processes/error.py b/src/utils/processes/error.py deleted file mode 100644 index 7e8abec..0000000 --- a/src/utils/processes/error.py +++ /dev/null @@ -1,15 +0,0 @@ -class UnknownProcessError(Exception): - def __init__(self, process): - super().__init__("No process {} exists".format(process)) - -class UnloadedProcessError(Exception): - def __init__(self, process): - super().__init__("Process {} is not loaded".format(process)) - -class DuplicateLink(Exception): - def __init__(self, link_id, process): - super().__init__("Link ID {} already linked to process {}".format(link_id, process)) - -class MissingLink(Exception): - def __init__(self, link_id, process): - super().__init__("Link ID {} is not linked to process {}".format(link_id, process)) \ No newline at end of file diff --git a/src/utils/processes/manager.py b/src/utils/processes/manager.py deleted file mode 100644 index 387a7cf..0000000 --- a/src/utils/processes/manager.py +++ /dev/null @@ -1,74 +0,0 @@ -''' -Global processes manager - -Enables expensive processes used in one place to be reused elsewhere -For example: Kobold server shared between STT and T2T operation implementation -''' - -import logging -from enum import Enum - -from utils.helpers.singleton import Singleton - -from .error import UnknownProcessError, UnloadedProcessError - -class ProcessType(Enum): - KOBOLD = "kobold" - -class ProcessManager(metaclass=Singleton): - loaded_processes = dict() - - '''Perform initial load''' - async def load(self, process_type: ProcessType): - logging.info("Loading process by type {}".format(process_type.value)) - match process_type: - case ProcessType.KOBOLD: - from .processes.koboldcpp import KoboldCPPProcess - self.loaded_processes[ProcessType.KOBOLD] = KoboldCPPProcess() - await self.loaded_processes[ProcessType.KOBOLD].reload() - case _: - raise UnknownProcessError(process_type) - - '''Reload any process where reload_signal is True''' - async def reload(self): - for process_type in self.loaded_processes: - if self.loaded_processes[process_type] and self.loaded_processes[process_type].reload_signal: - logging.info("Reloading process {}".format(self.loaded_processes[process_type].id)) - await self.loaded_processes[process_type].reload() - - '''Unload any process where unload_signal is True''' - async def unload(self): - for process_type in self.loaded_processes: - if self.loaded_processes[process_type] and self.loaded_processes[process_type].unload_signal: - logging.info("Unloading process {}".format(self.loaded_processes[process_type].id)) - await self.loaded_processes[process_type].unload() - - async def link(self, link_id: str, process_type: ProcessType): - if not (process_type in self.loaded_processes and self.loaded_processes[process_type]): - await self.load(process_type) - - await self.loaded_processes[process_type].link(link_id) - - async def unlink(self, link_id: str, process_type: ProcessType): - if not (process_type in self.loaded_processes and self.loaded_processes[process_type]): - raise UnloadedProcessError(process_type.value) - - await self.loaded_processes[process_type].unlink(link_id) - - def signal_reload(self, process_type: ProcessType): - if not (process_type in self.loaded_processes and self.loaded_processes[process_type]): - raise UnloadedProcessError(process_type.value) - - self.loaded_processes[process_type].reload_signal = True - - def signal_unload(self, process_type: ProcessType): - if not (process_type in self.loaded_processes and self.loaded_processes[process_type]): - raise UnloadedProcessError(process_type.value) - - self.loaded_processes[process_type].unload_signal = True - - def get_process(self, process_type: ProcessType): - if not (process_type in self.loaded_processes and self.loaded_processes[process_type]): - raise UnloadedProcessError(process_type.value) - - return self.loaded_processes[process_type] \ No newline at end of file diff --git a/src/utils/processes/processes/koboldcpp.py b/src/utils/processes/processes/koboldcpp.py deleted file mode 100644 index ad62104..0000000 --- a/src/utils/processes/processes/koboldcpp.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging -import subprocess -from subprocess import DEVNULL -import socket -from utils.config import Config -from utils.helpers.singleton import Singleton -from ..base import BaseProcess - -class KoboldCPPProcess(BaseProcess, metaclass=Singleton): - def __init__(self): - super().__init__("koboldcpp") - self.reload_signal = True - - async def reload(self): - # Close any existing servers - if self.process is not None: - await self.unload() - - await super().reload() - - # Find open port - config = Config() - sock = socket.socket() - sock.bind(('', 0)) - self.port = sock.getsockname()[1] - sock.close() - - # Start Kobold server on that port - cmd = '{} --quiet --config "{}" --port {}'.format(config.kobold_filepath, config.kcpps_filepath, self.port) - logging.debug(f"Running Koboldcpp server using command: \"{cmd}\"") - self.process = subprocess.Popen(cmd, shell=True, stdout=DEVNULL, stderr=DEVNULL) - logging.info(f"Opened Koboldcpp server (PID: {self.process.pid}) on port {self.port}") \ No newline at end of file diff --git a/src/utils/prompter/__init__.py b/src/utils/prompter/__init__.py index ea4d105..16f4f64 100644 --- a/src/utils/prompter/__init__.py +++ b/src/utils/prompter/__init__.py @@ -1 +1 @@ -from .prompter import Prompter \ No newline at end of file +from .prompter import Prompter as Prompter diff --git a/src/utils/prompter/context.py b/src/utils/prompter/context.py index 17611a0..fb5cd4a 100644 --- a/src/utils/prompter/context.py +++ b/src/utils/prompter/context.py @@ -2,7 +2,7 @@ class ContextMetadata: def __init__(self, id: str, name: str, description: str): assert id and len(id) > 0 assert name and len(name) > 0 - + self.id: str = id self.name: str = name - self.description: str = description or "" \ No newline at end of file + self.description: str = description or "" diff --git a/src/utils/prompter/message.py b/src/utils/prompter/message.py index 7cf585c..eadc078 100644 --- a/src/utils/prompter/message.py +++ b/src/utils/prompter/message.py @@ -1,101 +1,104 @@ import datetime + from .context import ContextMetadata + class Message: def to_line(): raise NotImplementedError - + def to_dict(): raise NotImplementedError - + + class RawMessage(Message): - def __init__(self, message: str,): + def __init__( + self, + message: str, + ): assert message is not None and len(message) > 0 - + self.message = message.replace("\n", "") - + def to_line(self): return self.message - + def to_dict(self): - return { - "type": "raw", - "message": self.message - } - + return {"type": "raw", "message": self.message} + + class RequestMessage(Message): def __init__(self, message: str, time: datetime.datetime): assert message is not None and len(message) > 0 - + self.message = message.replace("\n", "") self.time = time - + def to_line(self): return f"[REQUEST]: {self.message}" - + def to_dict(self): - return { - "type": "request", - "time": self.time.timestamp(), - "message": self.message - } - + return {"type": "request", "time": self.time.timestamp(), "message": self.message} + + class ChatMessage(Message): def __init__(self, user: str, message: str, time: datetime.datetime): assert user is not None assert message is not None and len(message) > 0 - + self.user = user self.message = message.replace("\n", "") self.time = time - + def to_line(self): return f"[{self.user}]: {self.message}" - + def to_dict(self): return { "type": "chat", "user": self.user, "time": self.time.timestamp(), - "message": self.message + "message": self.message, } + class MCPMessage(Message): def __init__(self, tool_name: str, result: str, time: datetime.datetime): assert tool_name is not None assert result is not None and len(result) > 0 - + self.tool_name = tool_name self.result = result.replace("\n", "") self.time = time - + def to_line(self): return f"[MCP#{self.tool_name}]: {self.result}" - + def to_dict(self): return { "type": "tool", "tool": self.tool_name, "time": self.time.timestamp(), - "message": self.result + "message": self.result, } - + + class CustomMessage(Message): def __init__(self, context_metadata: ContextMetadata, message: str, time: datetime.datetime): assert context_metadata is not None assert message is not None and len(message) > 0 - + self.context_metadata = context_metadata self.message = message.replace("\n", "") self.time = time - + def to_line(self): return f"[CONTEXT#{self.context_metadata.name}]: {self.message}" - + def to_dict(self): return { "type": "custom", "id": self.context_metadata.id, "time": self.time.timestamp(), - "message": self.message - } \ No newline at end of file + "message": self.message, + } diff --git a/src/utils/prompter/prompter.py b/src/utils/prompter/prompter.py index c697ccb..66c01e6 100644 --- a/src/utils/prompter/prompter.py +++ b/src/utils/prompter/prompter.py @@ -1,193 +1,226 @@ - -import os import datetime -from typing import AsyncGenerator, Dict, List, Any -from utils.helpers.time import get_current_time -from utils.helpers.singleton import Singleton +import os +from collections.abc import AsyncGenerator +from typing import Any + +from utils.config import config from utils.helpers.path import portable_path -from utils.config import Config +from utils.helpers.time import get_current_time + from .context import ContextMetadata -from .message import Message, ChatMessage, RequestMessage, MCPMessage, CustomMessage +from .message import ChatMessage, CustomMessage, MCPMessage, Message, RequestMessage -class Prompter(metaclass=Singleton): + +class Prompter: def __init__(self): - self.context_metadata: Dict[str, ContextMetadata] = dict() - self.history: List[Message] = list() - - self.instruction_prompt_filename: str = 'example.txt' - self.character_prompt_filename: str = 'example.txt' - self.scene_prompt_filename: str = 'example.txt' + self.context_metadata: dict[str, ContextMetadata] = {} + self.history: list[Message] = [] + + self.instruction_prompt_filename: str = "example.txt" + self.character_prompt_filename: str = "example.txt" + self.scene_prompt_filename: str = "example.txt" self.character_name: str = "J.A.I.son" - self.name_translations: Dict[str, str] = {"old name": "new:name"} + self.name_translations: dict[str, str] = {"old name": "new:name"} self.history_length: int = 50 - + self.tooling_prompt = "" self.response_template = "" - - async def configure(self, config_d: Dict[str, Any]): - if "instruction_prompt_filename" in config_d: self.instruction_prompt_filename = str(config_d["instruction_prompt_filename"]) - if "character_prompt_filename" in config_d: self.character_prompt_filename = str(config_d["character_prompt_filename"]) - if "scene_prompt_filename" in config_d: self.scene_prompt_filename = str(config_d["scene_prompt_filename"]) - if "character_name" in config_d: self.character_name = str(config_d["character_name"]) - if "name_translations" in config_d: self.name_translations = dict(config_d["name_translations"]) - if "history_length" in config_d: self.history_length = int(config_d["history_length"]) - + + async def configure(self, config_d: dict[str, Any]): + if "instruction_prompt_filename" in config_d: + self.instruction_prompt_filename = str(config_d["instruction_prompt_filename"]) + if "character_prompt_filename" in config_d: + self.character_prompt_filename = str(config_d["character_prompt_filename"]) + if "scene_prompt_filename" in config_d: + self.scene_prompt_filename = str(config_d["scene_prompt_filename"]) + if "character_name" in config_d: + self.character_name = str(config_d["character_name"]) + if "name_translations" in config_d: + self.name_translations = dict(config_d["name_translations"]) + if "history_length" in config_d: + self.history_length = int(config_d["history_length"]) + assert ( - self.instruction_prompt_filename is not None and - len(self.instruction_prompt_filename) > 0 and - os.path.isfile(portable_path(os.path.join( - Config().PROMPT_DIR, - Config().PROMPT_INSTRUCTION_SUBDIR, - self.instruction_prompt_filename - ))) + self.instruction_prompt_filename is not None + and len(self.instruction_prompt_filename) > 0 + and os.path.isfile( + portable_path( + os.path.join( + config.PROMPT_DIR, + config.PROMPT_INSTRUCTION_SUBDIR, + self.instruction_prompt_filename, + ) + ) + ) ) assert ( - self.character_prompt_filename is not None and - len(self.character_prompt_filename) > 0 and - os.path.isfile(portable_path(os.path.join( - Config().PROMPT_DIR, - Config().PROMPT_CHARACTER_SUBDIR, - self.character_prompt_filename - ))) + self.character_prompt_filename is not None + and len(self.character_prompt_filename) > 0 + and os.path.isfile( + portable_path( + os.path.join( + config.PROMPT_DIR, + config.PROMPT_CHARACTER_SUBDIR, + self.character_prompt_filename, + ) + ) + ) ) assert ( - self.scene_prompt_filename is not None and - len(self.scene_prompt_filename) > 0 and - os.path.isfile(portable_path(os.path.join( - Config().PROMPT_DIR, - Config().PROMPT_SCENE_SUBDIR, - self.scene_prompt_filename - ))) + self.scene_prompt_filename is not None + and len(self.scene_prompt_filename) > 0 + and os.path.isfile( + portable_path( + os.path.join( + config.PROMPT_DIR, + config.PROMPT_SCENE_SUBDIR, + self.scene_prompt_filename, + ) + ) + ) ) assert self.character_name is not None and len(self.character_name) assert self.history_length > 0 - - + def clear_history(self): - self.history = list() - + self.history = [] + def insert_history(self, message: Message): self.history.append(message) - self.history = self.history[-(self.history_length):] - - with open(Config().history_filepath, 'a', encoding="utf-8") as f: + self.history = self.history[-(self.history_length) :] + + with open(config.history_filepath, "a", encoding="utf-8") as f: f.write(message.to_line()) f.write("\n") - + # Custom context - def register_custom_context(self, context_id: str, context_name: str, context_description: str = None): - - self.context_metadata[context_id] = ContextMetadata(context_id, context_name, context_description) + def register_custom_context( + self, context_id: str, context_name: str, context_description: str = None + ): + + self.context_metadata[context_id] = ContextMetadata( + context_id, context_name, context_description + ) def remove_custom_context(self, context_id: str): assert context_id in self.context_metadata - + del self.context_metadata[context_id] def add_custom_context(self, context_id: str, contents: str, time: datetime.datetime = None): assert context_id in self.context_metadata assert contents and len(contents) > 0 - - if time is None: time = get_current_time(include_ms=False, as_str=False) + + if time is None: + time = get_current_time(include_ms=False, as_str=False) self.insert_history(CustomMessage(self.context_metadata[context_id], contents, time)) # Main conversation def translate_name(self, name: str): return self.name_translations.get(name, name) - + def add_chat(self, name: str, message: str, time: datetime.datetime = None): assert name and len(name) > 0 assert message and len(message) > 0 - - if time is None: time = get_current_time(include_ms=False, as_str=False) + + if time is None: + time = get_current_time(include_ms=False, as_str=False) self.insert_history(ChatMessage(self.translate_name(name), message, time)) - - async def add_chat_stream(self, name: str, in_stream: AsyncGenerator, time: datetime.datetime = None): - if time is None: time = get_current_time(include_ms=False, as_str=False) - - message = '' + + async def add_chat_stream( + self, name: str, in_stream: AsyncGenerator, time: datetime.datetime = None + ): + if time is None: + time = get_current_time(include_ms=False, as_str=False) + + message = "" async for in_d in in_stream: - message += in_d['content'] - + message += in_d["content"] + self.insert_history(ChatMessage(self.translate_name(name), message, time)) # Requests def add_request(self, message: str, time: datetime.datetime = None): assert message and len(message) > 0 - - if time is None: time = get_current_time(include_ms=False, as_str=False) - + + if time is None: + time = get_current_time(include_ms=False, as_str=False) + self.insert_history(RequestMessage(message, time)) - + # Prompt generators def get_instructions_prompt(self): - with open(portable_path(os.path.join( - Config().PROMPT_DIR, - Config().PROMPT_INSTRUCTION_SUBDIR, - self.instruction_prompt_filename - )), 'r', encoding="utf-8") as f: + with open( + portable_path( + os.path.join( + config.PROMPT_DIR, + config.PROMPT_INSTRUCTION_SUBDIR, + self.instruction_prompt_filename, + ) + ), + encoding="utf-8", + ) as f: return f.read() - + def get_context_descriptions(self): result = "" for context_id in self.context_metadata: - result += "{name}: {description}\n".format( - name=self.context_metadata[context_id].name, - description=self.context_metadata[context_id].description - ) - + result += f"{self.context_metadata[context_id].name}: {self.context_metadata[context_id].description}\n" + return result - + def get_character_prompt(self): - with open(portable_path(os.path.join( - Config().PROMPT_DIR, - Config().PROMPT_CHARACTER_SUBDIR, - self.character_prompt_filename - )), 'r', encoding="utf-8") as f: + with open( + portable_path( + os.path.join( + config.PROMPT_DIR, + config.PROMPT_CHARACTER_SUBDIR, + self.character_prompt_filename, + ) + ), + encoding="utf-8", + ) as f: return f.read() - + def get_scene_prompt(self): - with open(portable_path(os.path.join( - Config().PROMPT_DIR, - Config().PROMPT_SCENE_SUBDIR, - self.scene_prompt_filename - )), 'r', encoding="utf-8") as f: + with open( + portable_path( + os.path.join( + config.PROMPT_DIR, config.PROMPT_SCENE_SUBDIR, self.scene_prompt_filename + ) + ), + encoding="utf-8", + ) as f: return f.read() - + def get_sys_prompt(self): - return "{instructions}\n{mcp_usage}\n{contexts}\n### Character ###\n{character}\n### Scene ###\n{scene}".format( - instructions = self.get_instructions_prompt(), - contexts = self.get_context_descriptions(), - mcp_usage = self.response_template, - character = self.get_character_prompt(), - scene = self.get_scene_prompt(), - ) + return f"{self.get_instructions_prompt()}\n{self.response_template}\n{self.get_context_descriptions()}\n### Character ###\n{self.get_character_prompt()}\n### Scene ###\n{self.get_scene_prompt()}" def get_history_text(self): prompt = "" - + for message in self.history: message_line = message.to_line() - prompt += "\n{}".format(message_line) - + prompt += f"\n{message_line}" + return prompt - + def get_history(self): return self.history def add_mcp_usage_prompt(self, tooling_prompt: str, response_template: str): self.tooling_prompt = tooling_prompt self.response_template = response_template - + def generate_mcp_system_context(self): return self.tooling_prompt - + def generate_mcp_user_context(self): user_prompt = self.get_history_text() character = self.get_character_prompt() scene = self.get_scene_prompt() - + return f"{character}{scene}