diff --git a/Cargo.toml b/Cargo.toml index 54f781b..a43ac26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,8 @@ serde_json = "1.0" anyhow = "1.0" notify = "6.1" uuid = { version = "1.0", features = ["v4"] } +oci-distribution = "0.11" +toml = "0.8" [build-dependencies] tonic-build = "0.11" diff --git a/STRUCTURE.tree b/STRUCTURE.tree index 6930b2f..949ffee 100644 --- a/STRUCTURE.tree +++ b/STRUCTURE.tree @@ -8,6 +8,9 @@ ├── CLAUDE.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md +├── deploy +│   └── k8s +│   └── vtuber-image.yaml ├── DEPLOYMENT_GUIDE.md ├── DESIGN_DECISIONS.md ├── docker-compose.yml @@ -50,11 +53,14 @@ │   ├── guard │   │   ├── cache.rs │   │   └── mod.rs -│   └── main.rs +│   ├── main.rs +│   └── registry +│   ├── client.rs +│   └── mod.rs ├── STRATEGY.md ├── STRUCTURE.tree ├── SUPPORT.md ├── TROUBLESHOOTING.md └── VISION.md -11 directories, 47 files +14 directories, 50 files diff --git a/python/comfy_client.py b/python/comfy_client.py index e833539..26ed8b3 100644 --- a/python/comfy_client.py +++ b/python/comfy_client.py @@ -130,8 +130,13 @@ def verify_model(self, model_id, expected_hash, allow_nsfw): try: req = json.loads(input_data) - # 1. Fetch template - workflow = client.fetch_template(req['template_bucket'], req['template_key']) + # 1. Fetch template or use provided workflow_json + if 'workflow_json' in req: + workflow = req['workflow_json'] + if isinstance(workflow, str): + workflow = json.loads(workflow) + else: + workflow = client.fetch_template(req['template_bucket'], req['template_key']) # 2. 
Inject overrides workflow = client.inject_overrides(workflow, req.get('overrides', {})) diff --git a/src/guard/cache.rs index 8a8fdf4..6a729fd 100644 --- a/src/guard/cache.rs +++ b/src/guard/cache.rs @@ -12,9 +12,20 @@ pub struct ModelEntry { pub allow_nsfw: bool, } +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct PersonaAssets { + pub image_registry: String, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct PersonaConfig { + pub assets: PersonaAssets, +} + #[derive(Debug, Clone)] pub struct GuardCache { - cache: Arc<RwLock<HashMap<String, ModelEntry>>>, + model_cache: Arc<RwLock<HashMap<String, ModelEntry>>>, + persona_cache: Arc<RwLock<HashMap<String, PersonaConfig>>>, } impl Default for GuardCache { @@ -26,7 +37,8 @@ impl GuardCache { pub fn new() -> Self { Self { - cache: Arc::new(RwLock::new(HashMap::new())), + model_cache: Arc::new(RwLock::new(HashMap::new())), + persona_cache: Arc::new(RwLock::new(HashMap::new())), } } @@ -35,7 +47,7 @@ let entries: Vec<ModelEntry> = serde_json::from_str(&content)?; let mut cache = self - .cache + .model_cache .write() .map_err(|_| anyhow::anyhow!("Failed to acquire write lock"))?; cache.clear(); @@ -47,7 +59,18 @@ } pub fn get_model(&self, model_id: &str) -> Option<ModelEntry> { - let cache = self.cache.read().ok()?; + let cache = self.model_cache.read().ok()?; cache.get(model_id).cloned() } + + pub fn get_persona(&self, persona_id: &str) -> Option<PersonaConfig> { + let cache = self.persona_cache.read().ok()?; + cache.get(persona_id).cloned() + } + + pub fn insert_persona(&self, persona_id: String, config: PersonaConfig) { + if let Ok(mut cache) = self.persona_cache.write() { + cache.insert(persona_id, config); + } + } } diff --git a/src/main.rs index 6fd45ef..d6d39e8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use notify::{Event, RecursiveMode, Watcher}; use std::io::Write; use std::path::Path; use std::process::Stdio; +use std::sync::Arc; use tonic::{transport::Server, Request, Response, Status}; use 
vtuber_image::v1::image_generator_service_server::{ ImageGeneratorService, ImageGeneratorServiceServer, @@ -9,6 +10,7 @@ use vtuber_image::v1::{GenerateRequest, GenerateResponse}; pub mod guard; +pub mod registry; pub mod vtuber_image { pub mod v1 { @@ -18,6 +20,7 @@ pub struct MyImageGeneratorService { pub guard_cache: guard::cache::GuardCache, + pub registry_client: Arc<registry::OCIClient>, } #[tonic::async_trait] @@ -29,18 +32,46 @@ impl ImageGeneratorService for MyImageGeneratorService { let req = request.into_inner(); println!("Received request for persona: {}", req.persona_id); - // Task 2: Rust gRPC Enforcement - // Placeholder check: Is the persona_id in the allowlist cache? - if self.guard_cache.get_model(&req.persona_id).is_none() { - return Err(Status::permission_denied(format!( - "Requested configuration (persona: {}) is not in the allowlist", - req.persona_id - ))); - } + // Task 2: Reactive Persona Mapping + let config_path = std::env::var("CONFIG_PATH").unwrap_or_else(|_| "config".to_string()); + + // 1. Get Persona Config (from cache or file) + let persona_config = if let Some(config) = self.guard_cache.get_persona(&req.persona_id) { + config + } else { + let persona_path = Path::new(&config_path) + .join("personas") + .join(format!("{}.toml", req.persona_id)); + if !persona_path.exists() { + return Err(Status::not_found(format!( + "Persona config not found for {}", + req.persona_id + ))); + } + + let content = std::fs::read_to_string(&persona_path) + .map_err(|e| Status::internal(format!("Failed to read persona file: {}", e)))?; + let config: guard::cache::PersonaConfig = toml::from_str(&content) + .map_err(|e| Status::internal(format!("Failed to parse persona TOML: {}", e)))?; + self.guard_cache + .insert_persona(req.persona_id.clone(), config.clone()); + config + }; + + // 2. 
Pull workflow from OCI Registry + let image_registry_url = &persona_config.assets.image_registry; + let workflow_json = self + .registry_client + .pull_workflow(image_registry_url) + .await + .map_err(|e| { + Status::internal(format!("Failed to pull workflow from registry: {}", e)) + })?; + + // 3. Prepare payload for Python orchestration let input_payload = serde_json::json!({ - "template_bucket": std::env::var("S3_BUCKET_TEMPLATES").unwrap_or_else(|_| "templates".to_string()), - "template_key": format!("{}.json", req.persona_id), + "workflow_json": workflow_json, "overrides": { "hair_style": req.overrides.as_ref().map(|o| o.hair_style.clone()).unwrap_or_default(), "eye_color": req.overrides.as_ref().map(|o| o.eye_color.clone()).unwrap_or_default(), @@ -136,7 +167,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { } }); - let generator = MyImageGeneratorService { guard_cache }; + let cache_dir = std::env::var("CACHE_DIR").unwrap_or_else(|_| "cache/workflows".to_string()); + let registry_client = Arc::new(registry::OCIClient::new( + Path::new(&cache_dir).to_path_buf(), + )); + + let generator = MyImageGeneratorService { + guard_cache, + registry_client, + }; println!("ImageGeneratorService server listening on {}", addr); diff --git a/src/registry/client.rs b/src/registry/client.rs new file mode 100644 index 0000000..b8021ae --- /dev/null +++ b/src/registry/client.rs @@ -0,0 +1,76 @@ +use anyhow::Context; +use oci_distribution::{client::Client, Reference}; +use std::path::PathBuf; +use tokio::fs; + +pub struct OCIClient { + client: Client, + cache_dir: PathBuf, +} + +impl OCIClient { + pub fn new(cache_dir: PathBuf) -> Self { + Self { + client: Client::default(), + cache_dir, + } + } + + pub async fn pull_workflow(&self, image_url: &str) -> anyhow::Result<String> { + let reference: Reference = image_url.parse().context("Failed to parse image URL")?; + + // Smart Cache Logic: Use sanitized image_url as filename + let cache_filename = format!("{}.json", image_url.replace("/", 
"_").replace(":", "_")); + let cache_path = self.cache_dir.join(cache_filename); + + if cache_path.exists() { + println!("Cache hit for workflow: {}", image_url); + return fs::read_to_string(cache_path) + .await + .context("Failed to read cached workflow"); + } + + println!( + "Cache miss for workflow: {}, pulling from registry...", + image_url + ); + + let auth = oci_distribution::secrets::RegistryAuth::Anonymous; + + // Pull the image data + // For simplicity, we pull the manifest and then the blobs (layers) + let image_data = self + .client + .pull( + &reference, + &auth, + vec![ + "application/vnd.oci.image.layer.v1.tar+gzip", + "application/vnd.docker.image.rootfs.diff.tar.gzip", + ], + ) + .await + .context("Failed to pull image from OCI registry")?; + + let mut workflow_content = None; + for layer in image_data.layers { + if let Ok(content) = String::from_utf8(layer.data) { + if content.trim().starts_with('{') { + workflow_content = Some(content); + break; + } + } + } + + let content = workflow_content + .ok_or_else(|| anyhow::anyhow!("workflow.json not found in OCI image layers"))?; + + // Save to cache + if let Some(parent) = cache_path.parent() { + fs::create_dir_all(parent).await?; + } + fs::write(&cache_path, &content).await?; + + Ok(content) + } +} diff --git a/src/registry/mod.rs b/src/registry/mod.rs new file mode 100644 index 0000000..027f8da --- /dev/null +++ b/src/registry/mod.rs @@ -0,0 +1,18 @@ +pub mod client; + +pub use client::OCIClient; + +pub struct SmartCache { + pub cache_dir: std::path::PathBuf, +} + +impl SmartCache { + pub fn new(cache_dir: std::path::PathBuf) -> Self { + Self { cache_dir } + } + + pub fn get_cache_path(&self, image_url: &str) -> std::path::PathBuf { + let cache_filename = format!("{}.json", image_url.replace("/", "_").replace(":", "_")); + self.cache_dir.join(cache_filename) + } +} diff --git a/workflows/flux_dev_v1.json b/workflows/flux_dev_v1.json new file mode 100644 index 0000000..4c3d7d2 --- /dev/null +++ 
b/workflows/flux_dev_v1.json @@ -0,0 +1,53 @@ +{ + "6": { + "inputs": { + "text": "{{hair_style}}, {{eye_color}}, wearing {{outfit}}, masterpiece, high quality, anime style", + "clip": [ + "11", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "Positive Prompt" + } + }, + "11": { + "inputs": { + "model_name": "flux_dev.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Flux.1 dev" + } + }, + "8": { + "inputs": { + "samples": [ + "13", + 0 + ], + "vae": [ + "11", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "VTuber_Flux", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Result" + } + } +}