Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ serde_json = "1.0"
anyhow = "1.0"
notify = "6.1"
uuid = { version = "1.0", features = ["v4"] }
oci-distribution = "0.11"
toml = "0.8"

[build-dependencies]
tonic-build = "0.11"
10 changes: 8 additions & 2 deletions STRUCTURE.tree
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
β”œβ”€β”€ CLAUDE.md
β”œβ”€β”€ CODE_OF_CONDUCT.md
β”œβ”€β”€ CONTRIBUTING.md
β”œβ”€β”€ deploy
β”‚Β Β  └── k8s
β”‚Β Β  └── vtuber-image.yaml
β”œβ”€β”€ DEPLOYMENT_GUIDE.md
β”œβ”€β”€ DESIGN_DECISIONS.md
β”œβ”€β”€ docker-compose.yml
Expand Down Expand Up @@ -50,11 +53,14 @@
β”‚Β Β  β”œβ”€β”€ guard
β”‚Β Β  β”‚Β Β  β”œβ”€β”€ cache.rs
β”‚Β Β  β”‚Β Β  └── mod.rs
β”‚Β Β  └── main.rs
β”‚Β Β  β”œβ”€β”€ main.rs
β”‚Β Β  └── registry
β”‚Β Β  β”œβ”€β”€ client.rs
β”‚Β Β  └── mod.rs
β”œβ”€β”€ STRATEGY.md
β”œβ”€β”€ STRUCTURE.tree
β”œβ”€β”€ SUPPORT.md
β”œβ”€β”€ TROUBLESHOOTING.md
└── VISION.md

11 directories, 47 files
14 directories, 50 files
9 changes: 7 additions & 2 deletions python/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,13 @@ def verify_model(self, model_id, expected_hash, allow_nsfw):
try:
req = json.loads(input_data)

# 1. Fetch template
workflow = client.fetch_template(req['template_bucket'], req['template_key'])
# 1. Fetch template or use provided workflow_json
if 'workflow_json' in req:
workflow = req['workflow_json']
if isinstance(workflow, str):
workflow = json.loads(workflow)
else:
workflow = client.fetch_template(req['template_bucket'], req['template_key'])

# 2. Inject overrides
workflow = client.inject_overrides(workflow, req.get('overrides', {}))
Expand Down
31 changes: 27 additions & 4 deletions src/guard/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,20 @@ pub struct ModelEntry {
pub allow_nsfw: bool,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PersonaAssets {
pub image_registry: String,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PersonaConfig {
pub assets: PersonaAssets,
}

#[derive(Debug, Clone)]
pub struct GuardCache {
cache: Arc<RwLock<HashMap<String, ModelEntry>>>,
model_cache: Arc<RwLock<HashMap<String, ModelEntry>>>,
persona_cache: Arc<RwLock<HashMap<String, PersonaConfig>>>,
}

impl Default for GuardCache {
Expand All @@ -26,7 +37,8 @@ impl Default for GuardCache {
impl GuardCache {
pub fn new() -> Self {
Self {
cache: Arc::new(RwLock::new(HashMap::new())),
model_cache: Arc::new(RwLock::new(HashMap::new())),
persona_cache: Arc::new(RwLock::new(HashMap::new())),
}
}

Expand All @@ -35,7 +47,7 @@ impl GuardCache {
let entries: Vec<ModelEntry> = serde_json::from_str(&content)?;

let mut cache = self
.cache
.model_cache
.write()
.map_err(|_| anyhow::anyhow!("Failed to acquire write lock"))?;
cache.clear();
Expand All @@ -47,7 +59,18 @@ impl GuardCache {
}

pub fn get_model(&self, model_id: &str) -> Option<ModelEntry> {
let cache = self.cache.read().ok()?;
let cache = self.model_cache.read().ok()?;
cache.get(model_id).cloned()
}

pub fn get_persona(&self, persona_id: &str) -> Option<PersonaConfig> {
let cache = self.persona_cache.read().ok()?;
cache.get(persona_id).cloned()
}

pub fn insert_persona(&self, persona_id: String, config: PersonaConfig) {
if let Ok(mut cache) = self.persona_cache.write() {
cache.insert(persona_id, config);
}
}
}
61 changes: 50 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ use notify::{Event, RecursiveMode, Watcher};
use std::io::Write;
use std::path::Path;
use std::process::Stdio;
use std::sync::Arc;
use tonic::{transport::Server, Request, Response, Status};
use vtuber_image::v1::image_generator_service_server::{
ImageGeneratorService, ImageGeneratorServiceServer,
};
use vtuber_image::v1::{GenerateRequest, GenerateResponse};

pub mod guard;
pub mod registry;

pub mod vtuber_image {
pub mod v1 {
Expand All @@ -18,6 +20,7 @@ pub mod vtuber_image {

pub struct MyImageGeneratorService {
pub guard_cache: guard::cache::GuardCache,
pub registry_client: Arc<registry::OCIClient>,
}

#[tonic::async_trait]
Expand All @@ -29,18 +32,46 @@ impl ImageGeneratorService for MyImageGeneratorService {
let req = request.into_inner();
println!("Received request for persona: {}", req.persona_id);

// Task 2: Rust gRPC Enforcement
// Placeholder check: Is the persona_id in the allowlist cache?
if self.guard_cache.get_model(&req.persona_id).is_none() {
return Err(Status::permission_denied(format!(
"Requested configuration (persona: {}) is not in the allowlist",
req.persona_id
)));
}
// Task 2: Reactive Persona Mapping
let config_path = std::env::var("CONFIG_PATH").unwrap_or_else(|_| "config".to_string());

// 1. Get Persona Config (from cache or file)
let persona_config = if let Some(config) = self.guard_cache.get_persona(&req.persona_id) {
config
} else {
let persona_path = Path::new(&config_path)
.join("personas")
.join(format!("{}.toml", req.persona_id));
if !persona_path.exists() {
return Err(Status::not_found(format!(
"Persona config not found for {}",
req.persona_id
)));
}

let content = std::fs::read_to_string(&persona_path)
.map_err(|e| Status::internal(format!("Failed to read persona file: {}", e)))?;
let config: guard::cache::PersonaConfig = toml::from_str(&content)
.map_err(|e| Status::internal(format!("Failed to parse persona TOML: {}", e)))?;

self.guard_cache
.insert_persona(req.persona_id.clone(), config.clone());
config
};

// 2. Pull workflow from OCI Registry
let image_registry_url = &persona_config.assets.image_registry;
let workflow_json = self
.registry_client
.pull_workflow(image_registry_url)
.await
.map_err(|e| {
Status::internal(format!("Failed to pull workflow from registry: {}", e))
})?;

// 3. Prepare payload for Python orchestration
let input_payload = serde_json::json!({
"template_bucket": std::env::var("S3_BUCKET_TEMPLATES").unwrap_or_else(|_| "templates".to_string()),
"template_key": format!("{}.json", req.persona_id),
"workflow_json": workflow_json,
"overrides": {
"hair_style": req.overrides.as_ref().map(|o| o.hair_style.clone()).unwrap_or_default(),
"eye_color": req.overrides.as_ref().map(|o| o.eye_color.clone()).unwrap_or_default(),
Expand Down Expand Up @@ -136,7 +167,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
});

let generator = MyImageGeneratorService { guard_cache };
let cache_dir = std::env::var("CACHE_DIR").unwrap_or_else(|_| "cache/workflows".to_string());
let registry_client = Arc::new(registry::OCIClient::new(
Path::new(&cache_dir).to_path_buf(),
));

let generator = MyImageGeneratorService {
guard_cache,
registry_client,
};

println!("ImageGeneratorService server listening on {}", addr);

Expand Down
76 changes: 76 additions & 0 deletions src/registry/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use anyhow::Context;
use oci_distribution::{client::Client, Reference};
use std::path::PathBuf;
use tokio::fs;

pub struct OCIClient {
client: Client,
cache_dir: PathBuf,
}

impl OCIClient {
pub fn new(cache_dir: PathBuf) -> Self {
Self {
client: Client::default(),
cache_dir,
}
}

pub async fn pull_workflow(&self, image_url: &str) -> anyhow::Result<String> {
let reference: Reference = image_url.parse().context("Failed to parse image URL")?;

// Smart Cache Logic: Use sanitized image_url as filename
let cache_filename = format!("{}.json", image_url.replace("/", "_").replace(":", "_"));
let cache_path = self.cache_dir.join(cache_filename);

if cache_path.exists() {
println!("Cache hit for workflow: {}", image_url);
return fs::read_to_string(cache_path)
.await
.context("Failed to read cached workflow");
}

println!(
"Cache miss for workflow: {}, pulling from registry...",
image_url
);

let auth = oci_distribution::secrets::RegistryAuth::Anonymous;

// Pull the image data
// For simplicity, we pull the manifest and then the blobs (layers)
let image_data = self
.client
.pull(
&reference,
&auth,
vec![
"application/vnd.oci.image.layer.v1.tar+gzip",
"application/vnd.docker.image.rootfs.diff.tar.gzip",
],
)
.await
.context("Failed to pull image from OCI registry")?;

let mut workflow_content = None;
for layer in image_data.layers {
if let Ok(content) = String::from_utf8(layer.data) {
if content.trim().starts_with('{') {
workflow_content = Some(content);
break;
}
}
}

let content = workflow_content
.ok_or_else(|| anyhow::anyhow!("workflow.json not found in OCI image layers"))?;

// Save to cache
if let Some(parent) = cache_path.parent() {
fs::create_dir_all(parent).await?;
}
fs::write(&cache_path, &content).await?;

Ok(content)
}
}
18 changes: 18 additions & 0 deletions src/registry/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
pub mod client;

pub use client::OCIClient;

pub struct SmartCache {
pub cache_dir: std::path::PathBuf,
}

impl SmartCache {
pub fn new(cache_dir: std::path::PathBuf) -> Self {
Self { cache_dir }
}

pub fn get_cache_path(&self, image_url: &str) -> std::path::PathBuf {
let cache_filename = format!("{}.json", image_url.replace("/", "_").replace(":", "_"));
self.cache_dir.join(cache_filename)
}
}
53 changes: 53 additions & 0 deletions workflows/flux_dev_v1.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"6": {
"inputs": {
"text": "{{hair_style}}, {{eye_color}}, wearing {{outfit}}, masterpiece, high quality, anime style",
"clip": [
"11",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "Positive Prompt"
}
},
"11": {
"inputs": {
"model_name": "flux_dev.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Flux.1 dev"
}
},
"8": {
"inputs": {
"samples": [
"13",
0
],
"vae": [
"11",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "VTuber_Flux",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Result"
}
}
}
Loading