diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 6b13a4a..479d462 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -49,7 +49,7 @@ jobs: - name: Clippy run: | if [ "${{ matrix.feature }}" = "all-features" ]; then - cargo clippy --no-default-features --features "all-models,video,viewer,annotator,hf-hub,ort-download-binaries,ort-load-dynamic" --all-targets -- -D warnings + cargo clippy --no-default-features --features "all-models,video,viewer,annotator,ort-download-binaries,ort-load-dynamic" --all-targets -- -D warnings elif [ "${{ matrix.feature }}" = "" ]; then cargo clippy --no-default-features --all-targets -- -D warnings else @@ -74,7 +74,7 @@ jobs: uses: dtolnay/rust-toolchain@stable - name: Check - run: cargo check --no-default-features --features "all-models,video,viewer,annotator,hf-hub,ort-download-binaries,ort-load-dynamic" --all-targets + run: cargo check --no-default-features --features "all-models,video,viewer,annotator,ort-download-binaries,ort-load-dynamic" --all-targets test: name: cargo-test @@ -94,7 +94,7 @@ jobs: uses: dtolnay/rust-toolchain@nightly - name: Test - run: cargo +nightly test --no-default-features --features "all-models,video,viewer,annotator,hf-hub,ort-download-binaries,ort-load-dynamic" --all-targets + run: cargo +nightly test --no-default-features --features "all-models,video,viewer,annotator,ort-download-binaries,ort-load-dynamic" --all-targets build-linux: needs: test @@ -120,4 +120,4 @@ jobs: uses: dtolnay/rust-toolchain@stable - name: Build - run: cargo build --no-default-features --features "all-models,video,viewer,annotator,hf-hub,ort-download-binaries,ort-load-dynamic" \ No newline at end of file + run: cargo build --no-default-features --features "all-models,video,viewer,annotator,ort-download-binaries,ort-load-dynamic" \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index ea41574..d94691a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,6 
@@ cudarc = { version = "0.19", optional = true, default-features = false, features "dynamic-linking" ] } ureq = { version = "3.1.4", default-features = false, features = ["rustls", "gzip"] } -hf-hub = { version = "0.4.3", default-features = false, features = ["ureq", "rustls-tls"], optional = true } tokenizers = { version = "0.22.1", optional = true } lru = { version = "0.16.2", default-features = false } @@ -200,9 +199,6 @@ image-qoi = ["image/qoi"] image-tga = ["image/tga"] image-all-formats = ["image/default-formats"] -# Hugging Face hub support (for downloading models from Hugging Face) -hf-hub = ["dep:hf-hub"] - # Model Zoo vision = [] vlm = ["vision", "dep:tokenizers", "dep:ndarray-npy"] diff --git a/docs/cargo-features/overview.md b/docs/cargo-features/overview.md index 9e155f1..e131e32 100644 --- a/docs/cargo-features/overview.md +++ b/docs/cargo-features/overview.md @@ -18,7 +18,7 @@ --- - visualization, video I/O, model hub, and annotation utilities. + visualization, video I/O, and annotation utilities. [:octicons-arrow-right-24: Utilities →](./utils.md) diff --git a/docs/cargo-features/utils.md b/docs/cargo-features/utils.md index 461ef8d..a35252a 100644 --- a/docs/cargo-features/utils.md +++ b/docs/cargo-features/utils.md @@ -5,7 +5,6 @@ | ***`annotator`*** | Annotation | Draw bounding boxes, keypoints, masks on images | `ab_glyph`, `imageproc` | ✓ | | **`viewer`** | Visualization | Real-time image/video display (like OpenCV `imshow`) | `minifb` | x | | **`video`** | I/O | Video read/write streaming support | `video-rs` | x | -| **`hf-hub`** | Model Hub | Download models from Hugging Face | `hf-hub` | x | !!! tip "Usage Example" ```toml diff --git a/docs/guides/hub.md b/docs/guides/hub.md new file mode 100644 index 0000000..8e63ba3 --- /dev/null +++ b/docs/guides/hub.md @@ -0,0 +1,86 @@ +# Hub + +`Hub` downloads and caches files from **GitHub Releases** and **Hugging Face** repositories. 
+ +## Supported Formats + +| Source | Format | Example | +|--------|--------|---------| +| Local file | File path | `"./model.onnx"` | +| GitHub Release | `<tag>/<file>` | `"yolo/v5-n-det.onnx"` | +| GitHub Release URL | Full URL | `"https://github.com/<owner>/<repo>/releases/download/<tag>/<file>"` | +| HF (inline) | `<owner>/<repo>/<file>` | `"BAAI/bge-m3/tokenizer.json"` | +| HF (dedicated) | `<file>` via `from_hf` | `"onnx/model.onnx"` | +| HF URL | Full URL (`resolve`/`blob`) | `"https://huggingface.co/<owner>/<repo>/blob/main/<file>"` | + +!!! tip "HF Endpoint" + By default, Hugging Face downloads use `https://huggingface.co`. + + Set the `HF_ENDPOINT` environment variable to use a mirror: + ```bash + export HF_ENDPOINT=https://hf-mirror.com + ``` + +## GitHub Release + +!!! example "Default Repository" + Download files from the default GitHub repository (`jamjamjon/assets`): + + ```rust + let path = Hub::default().try_fetch("images/bus.jpg")?; + ``` + +!!! example "Custom Repository" + ```rust + let mut hub = Hub::new("owner", "repo"); + let path = hub.try_fetch("<tag>/<file>")?; + ``` + +!!! example "Direct GitHub URL" + ```rust + let path = Hub::default().try_fetch( + "https://github.com/<owner>/<repo>/releases/download/<tag>/<file>" + )?; + ``` + +## Hugging Face + +!!! example "Inline Path (Recommended)" + Use `<owner>/<repo>/<file>` format directly — no extra setup needed: + + ```rust + let path = Hub::default().try_fetch("<owner>/<repo>/<dir>/<file>")?; + ``` + +!!! example "Dedicated Hub" + Bind a Hub to a specific HF repository: + + ```rust + let mut hub = Hub::from_hf("<owner>", "<repo>")?; + let path = hub.try_fetch("<file>")?; + let path = hub.try_fetch("<dir>/<file>")?; + ``` + +!!! example "Direct HF URL" + Supports both `/resolve/` and `/blob/` URLs: + + ```rust + let path = Hub::default().try_fetch( + "https://huggingface.co/<owner>/<repo>/blob/main/<file>" + )?; + ``` + +## Repository Info + +!!! example "Inspect Repository" + ```rust + Hub::default().info()?; // GitHub releases + Hub::from_hf("<owner>", "<repo>")?.info()?; // HF file tree with sizes + ``` + +## Caching + +!!!
info "Cache Behavior" + - Files are cached locally after the first download (`~/.cache/usls/` or similar). + - GitHub release metadata: TTL-based (default 10 min, configurable via `with_ttl`). + - Failed or incomplete downloads are discarded (atomic write via temp files). diff --git a/docs/model-zoo/embedding.md b/docs/model-zoo/embedding.md index 6d5cb37..ef473e5 100644 --- a/docs/model-zoo/embedding.md +++ b/docs/model-zoo/embedding.md @@ -13,5 +13,7 @@ hide: | [jina-clip-v1](https://huggingface.co/jinaai/jina-clip-v1) | Vision-Language Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ✅ | ✅ | ✅ | ✅ | | [jina-clip-v2](https://huggingface.co/jinaai/jina-clip-v2) | Vision-Language Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ✅ | ✅ | ✅ | ✅ | | [mobileclip](https://github.com/apple/ml-mobileclip) | Vision-Language Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [SigLIP](https://huggingface.co/collections/google/siglip) | Vision-Language Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [SigLIPv2](https://huggingface.co/collections/google/siglip2) | Vision-Language Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ✅ | ✅ | ✅ | ✅ | | [DINOv2](https://github.com/facebookresearch/dinov2) | Vision Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ❌ | ❌ | ❌ | ❌ | | [DINOv3](https://github.com/facebookresearch/dinov3) | Vision Embedding | [demo](https://github.com/jamjamjon/usls/tree/main/examples/embedding) | ✅ | ❓ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/examples/embedding/README.md b/examples/embedding/README.md index 1dc63c8..3028fdc 100644 --- a/examples/embedding/README.md +++ b/examples/embedding/README.md @@ -8,21 +8,51 @@ This directory contains 
examples for embedding models that extract feature repre Vision-language model for image and text embeddings. **Variants:** -- `clip-b16` - CLIP ViT-B/16 -- `clip-b32` - CLIP ViT-B/32 -- `clip-l14` - CLIP ViT-L/14 -- `jina-clip-v1` - Jina CLIP v1 -- `jina-clip-v2` - Jina CLIP v2 -- `mobileclip-s0` - MobileCLIP S0 -- `mobileclip-s1` - MobileCLIP S1 -- `mobileclip-s2` - MobileCLIP S2 -- `mobileclip-b` - MobileCLIP B -- `mobileclip-blt` - MobileCLIP BLT -- `mobileclip2-s0` - MobileCLIP2 S0 (default) -- `mobileclip2-s2` - MobileCLIP2 S2 -- `mobileclip2-s4` - MobileCLIP2 S4 -- `mobileclip2-b` - MobileCLIP2 B -- `mobileclip2-l14` - MobileCLIP2 L14 + +**OpenAI CLIP:** +- `clip-b16` - ViT-B/16 (85M params) +- `clip-b32` - ViT-B/32 (87M params) +- `clip-l14` - ViT-L/14 (304M params) + +**Jina CLIP:** +- `jina-clip-v1` - Improved performance, 224x224 +- `jina-clip-v2` - 512x512 resolution, better accuracy + +**MobileCLIP (Apple):** +- `mobileclip-s0` - Small variant S0 +- `mobileclip-s1` - Small variant S1 +- `mobileclip-s2` - Small variant S2 +- `mobileclip-b` - Base variant +- `mobileclip-blt` - Base with large text encoder + +**MobileCLIP v2:** +- `mobileclip2-s0` - Enhanced small S0 (default) +- `mobileclip2-s2` - Enhanced small S2 +- `mobileclip2-s4` - Enhanced small S4 +- `mobileclip2-b` - Enhanced base +- `mobileclip2-l14` - Enhanced large + +**SigLIP (Google DeepMind):** +- `siglip-b16-224` - Base, patch16, 224x224 +- `siglip-b16-256` - Base, patch16, 256x256 +- `siglip-b16-384` - Base, patch16, 384x384 +- `siglip-b16-512` - Base, patch16, 512x512 +- `siglip-l16-256` - Large, patch16, 256x256 +- `siglip-l16-384` - Large, patch16, 384x384 + +**SigLIP v2 (Google DeepMind):** +- `siglip2-b16-224` - Base v2, patch16, 224x224 +- `siglip2-b16-256` - Base v2, patch16, 256x256 +- `siglip2-b16-384` - Base v2, patch16, 384x384 +- `siglip2-b16-512` - Base v2, patch16, 512x512 +- `siglip2-l16-256` - Large v2, patch16, 256x256 +- `siglip2-l16-384` - Large v2, patch16, 384x384 +- 
`siglip2-l16-512` - Large v2, patch16, 512x512 +- `siglip2-so400m-patch14-224` - 400M, patch14, 224x224 +- `siglip2-so400m-patch14-384` - 400M, patch14, 384x384 +- `siglip2-so400m-patch16-256` - 400M, patch16, 256x256 +- `siglip2-so400m-patch16-384` - 400M, patch16, 384x384 +- `siglip2-so400m-patch16-512` - 400M, patch16, 512x512 **Usage:** ```bash diff --git a/examples/embedding/clip.rs b/examples/embedding/clip.rs index 0424a2f..202548f 100644 --- a/examples/embedding/clip.rs +++ b/examples/embedding/clip.rs @@ -4,7 +4,7 @@ use usls::{Config, DType, Device}; #[derive(Args, Debug)] pub struct ClipArgs { - /// Variant: clip-b16, clip-b32, clip-l14, jina-clip-v1, jina-clip-v2, mobileclip-s0, mobileclip-s1, mobileclip-s2, mobileclip-b, mobileclip-blt, mobileclip2-s0, mobileclip2-s2, mobileclip2-s4, mobileclip2-b, mobileclip2-l14 + /// Variant: clip-b16, clip-b32, clip-l14, jina-clip-v1, jina-clip-v2, mobileclip-s0, mobileclip-s1, mobileclip-s2, mobileclip-b, mobileclip-blt, mobileclip2-s0, mobileclip2-s2, mobileclip2-s4, mobileclip2-b, mobileclip2-l14, siglip-b16-224, siglip-b16-256, siglip-b16-384, siglip-b16-512, siglip-l16-256, siglip-l16-384, siglip2-b16-224, siglip2-b16-256, siglip2-b16-384, siglip2-b16-512, siglip2-l16-256, siglip2-l16-384, siglip2-l16-512, siglip2-so400m-patch14-224, siglip2-so400m-patch14-384, siglip2-so400m-patch16-256, siglip2-so400m-patch16-384, siglip2-so400m-patch16-512 #[arg(long, default_value = "mobileclip2-s0")] pub variant: String, @@ -62,6 +62,24 @@ pub fn config(args: &ClipArgs) -> Result { "mobileclip2-s4" => Config::mobileclip2_s4(), "mobileclip2-b" => Config::mobileclip2_b(), "mobileclip2-l14" => Config::mobileclip2_l14(), + "siglip-b16-224" => Config::siglip_b16_224(), + "siglip-b16-256" => Config::siglip_b16_256(), + "siglip-b16-384" => Config::siglip_b16_384(), + "siglip-b16-512" => Config::siglip_b16_512(), + "siglip-l16-256" => Config::siglip_l16_256(), + "siglip-l16-384" => Config::siglip_l16_384(), + "siglip2-b16-224" => 
Config::siglip2_b16_224(), + "siglip2-b16-256" => Config::siglip2_b16_256(), + "siglip2-b16-384" => Config::siglip2_b16_384(), + "siglip2-b16-512" => Config::siglip2_b16_512(), + "siglip2-l16-256" => Config::siglip2_l16_256(), + "siglip2-l16-384" => Config::siglip2_l16_384(), + "siglip2-l16-512" => Config::siglip2_l16_512(), + "siglip2-so400m-patch14-224" => Config::siglip2_so400m_patch14_224(), + "siglip2-so400m-patch14-384" => Config::siglip2_so400m_patch14_384(), + "siglip2-so400m-patch16-256" => Config::siglip2_so400m_patch16_256(), + "siglip2-so400m-patch16-384" => Config::siglip2_so400m_patch16_384(), + "siglip2-so400m-patch16-512" => Config::siglip2_so400m_patch16_512(), _ => anyhow::bail!("Unsupported CLIP variant: {}", args.variant), } .with_visual_dtype(args.visual_dtype) diff --git a/mkdocs.yml b/mkdocs.yml index 614fddb..549a728 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -76,6 +76,7 @@ nav: - Guides: - Overview: guides/overview.md - Config System: guides/config.md + - Hub (GitHub / Hugging Face): guides/hub.md - Data Loading: guides/dataloader.md - Execution Providers: guides/ep.md - DType & Quantization: guides/dtype.md diff --git a/src/config/impl_inference.rs b/src/config/impl_inference.rs index 0314c29..cf8fc1c 100644 --- a/src/config/impl_inference.rs +++ b/src/config/impl_inference.rs @@ -102,18 +102,6 @@ impl crate::Config { self } - /// Set max tokens. - pub fn with_max_tokens(mut self, max_tokens: u64) -> Self { - self.inference.max_tokens = Some(max_tokens); - self - } - - /// Set ignore eos flag. - pub fn with_ignore_eos(mut self, ignore_eos: bool) -> Self { - self.inference.ignore_eos = ignore_eos; - self - } - /// Get class confidences (accessor for inference params). 
pub fn class_confs(&self) -> &[f32] { &self.inference.class_confs @@ -135,16 +123,6 @@ impl crate::Config { self } - // pub fn with_temperature(mut self, temperature: f32) -> Self { - // self.inference.temperature = temperature; - // self - // } - - // pub fn with_topp(mut self, topp: f32) -> Self { - // self.inference.topp = topp; - // self - // } - /// Get text confidences. pub fn text_confs(&self) -> &[f32] { &self.inference.text_confs diff --git a/src/config/impl_text_processor.rs b/src/config/impl_text_processor.rs index a9c2b00..9ffd310 100644 --- a/src/config/impl_text_processor.rs +++ b/src/config/impl_text_processor.rs @@ -17,17 +17,17 @@ impl crate::Config { self } - // /// Set maximum number of tokens to generate. - // pub fn with_max_tokens(mut self, n: u64) -> Self { - // self.text_processor.max_tokens = Some(n); - // self - // } - - // /// Set whether to ignore the end-of-sequence token. - // pub fn with_ignore_eos(mut self, ignore_eos: bool) -> Self { - // self.text_processor.ignore_eos = ignore_eos; - // self - // } + /// Set maximum number of tokens to generate. + pub fn with_max_tokens(mut self, n: u64) -> Self { + self.text_processor.max_tokens = Some(n); + self + } + + /// Set whether to ignore the end-of-sequence token. + pub fn with_ignore_eos(mut self, ignore_eos: bool) -> Self { + self.text_processor.ignore_eos = ignore_eos; + self + } /// Set special tokens map file. 
pub fn with_special_tokens_map_file(mut self, file: impl Into) -> Self { @@ -40,4 +40,14 @@ impl crate::Config { self.text_processor.config_file = Some(file.into()); self } + + pub fn with_temperature(mut self, temperature: f32) -> Self { + self.text_processor.temperature = temperature; + self + } + + pub fn with_topp(mut self, topp: f32) -> Self { + self.text_processor.topp = topp; + self + } } diff --git a/src/config/inference_params.rs b/src/config/inference_params.rs index 215fde6..8ae4148 100644 --- a/src/config/inference_params.rs +++ b/src/config/inference_params.rs @@ -34,14 +34,6 @@ pub struct InferenceParams { pub db_unclip_ratio: Option, pub db_binary_thresh: Option, pub token_level_class: bool, - /// Maximum number of tokens to generate. - pub max_tokens: Option, - /// Whether to ignore the end-of-sequence token. - pub ignore_eos: bool, - // /// Temperature parameter for text generation. - // pub temperature: f32, - // /// Top-p parameter for nucleus sampling. - // pub topp: f32, // Task-specific parameters #[cfg(feature = "vision")] @@ -83,10 +75,6 @@ impl Default for InferenceParams { find_contours: Default::default(), up_scale: 2.0, text_names: Default::default(), - max_tokens: Default::default(), - ignore_eos: Default::default(), - // temperature: 1.0, - // topp: 0.9, token_level_class: Default::default(), #[cfg(feature = "vision")] yolo_preds_format: Default::default(), diff --git a/src/dataloader/hub.rs b/src/dataloader/hub.rs index 51d257f..1796818 100644 --- a/src/dataloader/hub.rs +++ b/src/dataloader/hub.rs @@ -23,43 +23,46 @@ pub(crate) struct Release { pub assets: Vec, } -/// Manages interactions with GitHub repository releases and Hugging Face repositories +/// Manages downloading and caching files from GitHub releases and Hugging Face repositories. /// -/// # Format Rules -/// - **GitHub Release** -/// - Use `/` format. -/// - Example: `"yolo/v5-n-det.onnx"`. 
+/// # Supported Formats /// -/// - **Hugging Face** -/// - With `Hub::from_hf(owner, repo)`: File paths are interpreted relative to the repo root. -/// - Example: `"sentencepiece.bpe.model"`, `"onnx/tokenizer.json"`. -/// - With `Hub::default()`: Paths with three segments are interpreted as `//`. -/// - Example: `"BAAI/bge-m3/sentencepiece.bpe.model"`. +/// | Source | Format | Example | +/// |--------|--------|---------| +/// | Local file | File path | `"./model.onnx"` | +/// | GitHub Release URL | Full URL | `"https://github.com/<owner>/<repo>/releases/download/<tag>/<file>"` | +/// | GitHub Release | `<tag>/<file>` | `"yolo/v5-n-det.onnx"` | +/// | HF URL | Full URL | `"https://huggingface.co/<owner>/<repo>/resolve/main/<file>"` | +/// | HF (inline) | `<owner>/<repo>/<file>` | `"BAAI/bge-m3/tokenizer.json"` | +/// | HF (dedicated) | `<file>` (via `from_hf`) | `"onnx/model.onnx"` | /// /// # Examples /// -/// ## GitHub Release Download /// ```rust,ignore +/// // GitHub Release /// let mut hub = Hub::default(); -/// // let mut hub = Hub::new(owner, repo); // Optional: Specify owner and repo if not using default -/// let path = hub.try_fetch("images/bus.jpg")?; // / format -/// ``` +/// let path = hub.try_fetch("images/bus.jpg")?; /// -/// ## Hugging Face Download (Dedicated Hub) -/// ```rust,ignore -/// let mut hub = Hub::from_hf("BAAI", "bge-m3")?; -/// let path = hub.try_fetch("sentencepiece.bpe.model")?; // Any format works -/// let path = hub.try_fetch("onnx/tokenizer.json")?; // Any format works -/// ``` +/// // HF (inline) +/// let path = Hub::default().try_fetch("<owner>/<repo>/<file>")?; +/// let path = Hub::default().try_fetch("<owner>/<repo>/<dir>/<file>")?; /// -/// ## Hugging Face Download (Temporary) -/// ```rust,ignore -/// let mut hub = Hub::default().try_fetch("BAAI/bge-m3/tokenizer_config.json")?; // // format +/// // HF (dedicated) +/// let mut hub = Hub::from_hf("<owner>", "<repo>")?; +/// let path = hub.try_fetch("<file>")?; +/// +/// // HF URL (resolve or blob) +/// let path = Hub::default().try_fetch( +/// "https://huggingface.co/<owner>/<repo>/blob/main/<file>" +/// )?; /// ``` /// +/// # HF
Endpoint +/// Default: `https://huggingface.co`. Override via `HF_ENDPOINT` env var. +/// /// # Errors -/// Methods in `Hub` return `Result` types. Errors may occur due to invalid paths, failed -/// network requests, cache write failures, or mismatched file sizes during downloads. +/// Methods return `Result`. Errors may occur due to invalid paths, failed +/// network requests, cache write failures, or mismatched file sizes. /// #[derive(Debug)] pub struct Hub { @@ -79,13 +82,11 @@ pub struct Hub { /// The maximum number of retry attempts for failed downloads or network operations max_attempts: u32, - /// HF Endpoint (only used when hf-hub feature is enabled) - #[cfg(feature = "hf-hub")] + /// HF Endpoint hf_endpoint: String, - /// HF Api Repo (only used when hf-hub feature is enabled) - #[cfg(feature = "hf-hub")] - hf_repo: Option, + /// Hugging Face repository (owner, repo). When set, try_fetch uses HF download path. + hf_repo: Option<(String, String)>, } impl Default for Hub { @@ -108,10 +109,8 @@ impl Default for Hub { repo, to, max_attempts, - #[cfg(feature = "hf-hub")] hf_endpoint: std::env::var("HF_ENDPOINT") .unwrap_or_else(|_| "https://huggingface.co".to_string()), - #[cfg(feature = "hf-hub")] hf_repo: None, ttl: Duration::from_secs(10 * 60), } @@ -127,274 +126,253 @@ impl Hub { } } - #[cfg(feature = "hf-hub")] pub fn from_hf(owner: &str, repo: &str) -> Result { - let mut self_ = Self::new(owner, repo); - let hf_api = hf_hub::api::sync::ApiBuilder::new() - .with_cache_dir( - self_ - .to - .crate_dir_default() - .expect("Faild to get cache dir"), - ) - .with_endpoint(self_.hf_endpoint.clone()) - .with_retries(self_.max_attempts as usize) - .with_progress(true) - .build()?; - self_.hf_repo = Some(hf_api.model(format!("{owner}/{repo}"))); - - Ok(self_) - } - - #[cfg(not(feature = "hf-hub"))] - pub fn from_hf(_owner: &str, _repo: &str) -> Result { - anyhow::bail!("HF hub support is not enabled. 
Please enable the 'hf-hub' feature.") + Ok(Self { + hf_repo: Some((owner.into(), repo.into())), + ..Default::default() + }) } - /// Attempts to fetch a file from a local path, GitHub release, or Hugging Face repository. - /// - /// The `try_fetch` method supports multiple scenarios: - /// 1. **Local file**: If the provided string is a valid file path, the file is returned without downloading. - /// 2. **GitHub release URL**: If the input matches a valid GitHub release URL, the corresponding file is downloaded. - /// 3. **Hugging Face repository**: If the hub is configured for HF or the path contains HF format, files are downloaded from HF. - /// 4. **Default repository**: If no explicit URL is provided, the method uses the default or configured repository. - /// - /// # Parameters - /// - `s`: A string representing the file to fetch. This can be: - /// - A local file path. - /// - A GitHub release URL (e.g., `https://github.com/owner/repo/releases/download/tag/file`). - /// - A `/` format for fetching from the default GitHub repository. - /// - A HF repository file path (e.g., `"sentencepiece.bpe.model"` when using `from_hf`). - /// - A temporary HF path format (e.g., `"BAAI/bge-m3/sentencepiece.bpe.model"`). - /// - /// # Returns - /// - `Result`: On success, returns the path to the fetched file. + /// Fetches a file from local path, GitHub release, HF URL, or HF repository. /// - /// # Errors - /// - Returns an error if: - /// - The file cannot be found locally. - /// - The URL or tag is invalid. - /// - Network operations fail after the maximum retry attempts. - /// - HF repository access fails. 
- /// - /// # Examples - /// ```rust,ignore - /// let mut hub = Hub::default(); - /// - /// // Fetch a file from a local path - /// let local_path = hub.try_fetch("local/path/to/file").expect("File not found"); - /// - /// // Fetch a file from a GitHub release URL - /// let url_path = hub.try_fetch("https://github.com/owner/repo/releases/download/tag/file") - /// .expect("Failed to fetch file"); - /// - /// // Fetch a file using the default GitHub repository - /// let default_repo_path = hub.try_fetch("yolo/v5-n-det.onnx").expect("Failed to fetch file"); - /// - /// // Method 1: Fetch from HF repository using dedicated hub - /// let mut hf_hub = Hub::from_hf("BAAI", "bge-m3")?; - /// let hf_path = hf_hub.try_fetch("sentencepiece.bpe.model").expect("Failed to fetch HF file"); - /// - /// // Method 2: Fetch from HF repository using temporary path (doesn't change hub's owner/repo) - /// let temp_hf_path = Hub::default().try_fetch("BAAI/bge-m3/sentencepiece.bpe.model") - /// .expect("Failed to fetch HF file"); - /// ``` + /// Resolution order: + /// 1. Local file path + /// 2. Already cached locally (no network access needed) + /// 3. GitHub release URL + /// 4. Hugging Face URL (`/resolve/` or `/blob/`) + /// 5. Dedicated HF repo (if `from_hf` / `with_hf_repo` was used) + /// 6. Inline HF path (`//`, when path has 3+ segments) + /// 7. Default GitHub release (`/`) pub fn try_fetch(&mut self, s: &str) -> Result { let span = tracing::info_span!("hub_fetch", source = s); let _enter = span.enter(); - #[derive(Default, Debug, aksr::Builder)] - struct Pack { - // owner: String, - // repo: String, - url: String, - tag: String, - file_name: String, - file_size: Option, - } - let mut pack = Pack::default(); - - // saveout + // 1. 
Local file path let p = PathBuf::from(s); - let saveout = if p.exists() { - // => Local file - p - } else { - // First, check if it's a valid GitHub release URL - // This must be checked BEFORE HF path check, because GitHub URLs also have parts.len() > 2 - if let Some((owner_, repo_, tag_, file_name_)) = Self::is_valid_github_release_url(s) { - // => Valid GitHub release URL - // keep original owner, repo and tag - let saveout = self - .to - .crate_dir_default_with_subs(&[&owner_, &repo_, &tag_])? - .join(&file_name_); - pack = pack.with_url(s).with_tag(&tag_).with_file_name(&file_name_); - if let Some(n) = retry!(self.max_attempts, self.fetch_get_response(s))? - .headers() - .get(ureq::http::header::CONTENT_LENGTH) - .and_then(|v| v.to_str().ok()?.parse::().ok()) - { - pack = pack.with_file_size(n); - } - - saveout - } else { - // Check for HF repo usage (only after confirming it's not a GitHub URL) - #[cfg(feature = "hf-hub")] - { - if let Some(hf_repo) = &self.hf_repo { - return Ok(hf_repo.get(s)?.to_str().unwrap().to_string()); - // from hf repo - } - let parts: Vec<&str> = s.split('/').filter(|x| !x.is_empty()).collect(); - if parts.len() > 2 { - // from hf repo - // Note: this does not update self.owner or self.repo; they are only used temporarily. - let hf_api = hf_hub::api::sync::ApiBuilder::new() - .with_cache_dir( - self.to.crate_dir_default().expect("Faild to get cache dir"), - ) - .with_endpoint(self.hf_endpoint.clone()) - .with_progress(true) - .with_retries(self.max_attempts as usize) - .build()?; - let hf_repo = hf_api.model(format!("{}/{}", parts[0], parts[1])); - return Ok(hf_repo - .get(&parts[2..].join("/"))? - .to_str() - .unwrap() - .to_string()); - } - } - #[cfg(not(feature = "hf-hub"))] - { - let parts: Vec<&str> = s.split('/').filter(|x| !x.is_empty()).collect(); - if parts.len() > 2 { - anyhow::bail!("HF hub support is not enabled. 
Please enable the 'hf-hub' feature to download from Hugging Face repositories.") - } - } + if p.exists() { + tracing::debug!("Local file accessible: {}", p.display()); + return p + .to_str() + .map(|s| s.to_string()) + .with_context(|| format!("Failed to convert PathBuf: {p:?} to String")); + } - // => Default hub (GitHub release tag/file format) - - // Check remote - match s.split_once('/') { - Some((tag_, file_name_)) => { - let dst = self.to - .crate_dir_default_with_subs(&[tag_])? - .join(file_name_); - - // check if is cached - if !dst.is_file() { - tracing::debug!("File not cached, fetching from remote: {}", dst.display()); - - // Fetch releases - let releases = - match self.get_releases(&self.owner, &self.repo, &self.to, &self.ttl) { - Err(err) => anyhow::bail!( - "Failed to download: No releases found in this repo. Error: {err}" - ), - Ok(releases) => releases, - }; - - // Validate the tag - let tags: Vec = releases.iter().map(|x| x.tag_name.clone()).collect(); - if !tags.contains(&tag_.to_string()) { - anyhow::bail!( - "Failed to download: Tag `{tag_}` not found in GitHub releases. Available tags: {tags:?}" - ); - } else { - // Validate the file - if let Some(release) = releases.iter().find(|r| r.tag_name == tag_) { - let files: Vec<&str> = - release.assets.iter().map(|x| x.name.as_str()).collect(); - if !files.contains(&file_name_) { - anyhow::bail!( - "Failed to download: The file `{file_name_}` is missing in tag `{tag_}`. Available files: {files:?}" - ); - } else { - for f_ in release.assets.iter() { - if f_.name.as_str() == file_name_ { - pack = pack - .with_url(&f_.browser_download_url) - .with_tag(tag_) - .with_file_name(file_name_) - .with_file_size(f_.size); - break; - } - } - } - } - } - } - tracing::debug!("Using cached file: {}", dst.display()); - dst - } - _ => anyhow::bail!( - "Failed to download file from github releases due to invalid format. Expected: /, got: {s}" - ), - } + // 2. 
Check if already cached locally + if let Some(cached) = self.resolve_cache_path(s)? { + if cached.is_file() { + tracing::debug!("Cache file available: {}", cached.display()); + return cached + .to_str() + .map(|s| s.to_string()) + .with_context(|| format!("Failed to convert PathBuf: {cached:?} to String")); } - }; + } - // Commit the downloaded file, downloading if necessary - if !pack.url.is_empty() { - tracing::debug!("Starting remote file download..."); + // 3. Not cached — resolve source and download + let saveout = if let Some((owner_, repo_, tag_, file_name_)) = + Self::is_valid_github_release_url(s) + { + // => GitHub release URL + tracing::debug!( + "Downloading from explicit GitHub URL: {}/{} (tag: {}, file: {})", + owner_, + repo_, + tag_, + file_name_ + ); + let saveout = self + .to + .crate_dir_default_with_subs(&[&owner_, &repo_, &tag_])? + .join(&file_name_); retry!( self.max_attempts, 1000, 3000, - self.download( - &pack.url, - &saveout, - Some(&format!("{}/{}", pack.tag, pack.file_name)), - ) + self.download(s, &saveout, Some(&format!("{tag_}/{file_name_}"))) )?; - // // Download if the file does not exist or if the size of file does not match - // if saveout.is_file() { - // match pack.file_size { - // None => { - // log::warn!( - // "Failed to retrieve the remote file size. \ - // Download will be skipped, which may cause issues. \ - // Please verify your network connection or ensure the local file is valid and complete." - // ); - // } - // Some(file_size) => { - // if std::fs::metadata(&saveout)?.len() != file_size { - // tracing::debug!( - // "Local file size does not match remote. Starting download." - // ); - // retry!( - // self.max_attempts, - // 1000, - // 3000, - // self.download( - // &pack.url, - // &saveout, - // Some(&format!("{}/{}", pack.tag, pack.file_name)), - // ) - // )?; - // } else { - // tracing::debug!("Local file size matches remote. 
No download required."); - // } - // } - // } - // } else { - // tracing::debug!("Starting remote file download..."); - // retry!( - // self.max_attempts, - // 1000, - // 3000, - // self.download( - // &pack.url, - // &saveout, - // Some(&format!("{}/{}", pack.tag, pack.file_name)), - // ) - // )?; - // } + saveout + } else if let Some((hf_owner, hf_repo_name, filename)) = Self::parse_hf_url(s) { + // => Hugging Face URL (resolve/blob) + tracing::debug!( + "Downloading from HF URL: {}/{} -> {}", + hf_owner, + hf_repo_name, + filename + ); + return self.download_hf(&hf_owner, &hf_repo_name, &filename); + } else if let Some((hf_owner, hf_repo)) = self.hf_repo.clone() { + // => Dedicated HF mode + tracing::debug!( + "Downloading from dedicated HF repo: {}/{} -> {}", + hf_owner, + hf_repo, + s + ); + return self.download_hf(&hf_owner, &hf_repo, s); + } else { + let parts: Vec<&str> = s.split('/').filter(|x| !x.is_empty()).collect(); + if parts.len() > 2 { + // => Inline HF path: owner/repo/filename + let hf_owner = parts[0]; + let hf_repo_name = parts[1]; + let filename = parts[2..].join("/"); + tracing::debug!( + "Downloading from HF path: {}/{} -> {}", + hf_owner, + hf_repo_name, + filename + ); + return self.download_hf(hf_owner, hf_repo_name, &filename); + } + + // => Default hub (GitHub release tag/file format) + tracing::debug!("Downloading from default GitHub release: {s}"); + match s.split_once('/') { + Some((tag_, file_name_)) => { + let dst = self + .to + .crate_dir_default_with_subs(&[tag_])? + .join(file_name_); + let releases = self + .get_releases(&self.owner, &self.repo, &self.to, &self.ttl) + .map_err(|e| { + anyhow::anyhow!( + "No releases found in {}/{}. Error: {e}", + self.owner, + self.repo + ) + })?; + + let release = + releases + .iter() + .find(|r| r.tag_name == tag_) + .with_context(|| { + let tags: Vec<&str> = + releases.iter().map(|r| r.tag_name.as_str()).collect(); + format!("Tag `{tag_}` not found. 
Available: {tags:?}") + })?; + + let asset = release + .assets + .iter() + .find(|a| a.name == file_name_) + .with_context(|| { + let files: Vec<&str> = + release.assets.iter().map(|a| a.name.as_str()).collect(); + format!("File `{file_name_}` not in tag `{tag_}`. Available: {files:?}") + })?; + + retry!( + self.max_attempts, + 1000, + 3000, + self.download( + &asset.browser_download_url, + &dst, + Some(&format!("{tag_}/{file_name_}")), + ) + )?; + dst + } + _ => anyhow::bail!("Invalid format. Expected: /, got: {s}"), + } + }; + + tracing::debug!("Download completed: {}", saveout.display()); + saveout + .to_str() + .map(|s| s.to_string()) + .with_context(|| format!("Failed to convert PathBuf: {saveout:?} to String")) + } + + /// Check if a file is already available locally (as a local file or in cache), + /// without downloading. Returns `Some(path)` if found, `None` otherwise. + pub fn cached(&self, s: &str) -> Option { + let p = PathBuf::from(s); + if p.exists() { + return p.to_str().map(|s| s.to_string()); + } + if let Ok(Some(cached)) = self.resolve_cache_path(s) { + if cached.is_file() { + return cached.to_str().map(|s| s.to_string()); + } + } + None + } + + /// Resolves the expected local cache path for a source string without any network calls. + /// + /// Returns `Ok(Some(path))` if a cache path can be determined, `Ok(None)` otherwise. + fn resolve_cache_path(&self, s: &str) -> Result> { + // GitHub release URL + if let Some((owner_, repo_, tag_, file_name_)) = Self::is_valid_github_release_url(s) { + let path = self + .to + .crate_dir_default_with_subs(&[&owner_, &repo_, &tag_])? + .join(&file_name_); + return Ok(Some(path)); + } + + // HF URL (resolve/blob) + if let Some((hf_owner, hf_repo_name, filename)) = Self::parse_hf_url(s) { + let path = self + .to + .crate_dir_default_with_subs(&[&hf_owner, &hf_repo_name])? 
+ .join(&filename); + return Ok(Some(path)); } + // Dedicated HF mode + if let Some((ref hf_owner, ref hf_repo)) = self.hf_repo { + let path = self + .to + .crate_dir_default_with_subs(&[hf_owner, hf_repo])? + .join(s); + return Ok(Some(path)); + } + + let parts: Vec<&str> = s.split('/').filter(|x| !x.is_empty()).collect(); + if parts.len() > 2 { + // Inline HF path: owner/repo/filename + let hf_owner = parts[0]; + let hf_repo_name = parts[1]; + let filename = parts[2..].join("/"); + let path = self + .to + .crate_dir_default_with_subs(&[hf_owner, hf_repo_name])? + .join(&filename); + return Ok(Some(path)); + } + + // Default GitHub release (tag/file) + if let Some((tag_, file_name_)) = s.split_once('/') { + let path = self + .to + .crate_dir_default_with_subs(&[tag_])? + .join(file_name_); + return Ok(Some(path)); + } + + Ok(None) + } + + /// Download a file from Hugging Face, returning the cached local path. + fn download_hf(&self, owner: &str, repo: &str, filename: &str) -> Result { + let saveout = self + .to + .crate_dir_default_with_subs(&[owner, repo])? + .join(filename); + let url = format!( + "{}/{}/{}/resolve/main/{}", + self.hf_endpoint, owner, repo, filename + ); + retry!( + self.max_attempts, + 1000, + 3000, + self.download(&url, &saveout, Some(&format!("{owner}/{repo}/{filename}")),) + )?; saveout .to_str() .map(|s| s.to_string()) @@ -475,7 +453,7 @@ impl Hub { Ok(y) } - /// Download a file from a github release to a specified path with a progress bar + /// Download a file from a URL to a local path with a progress bar. 
pub fn download + std::fmt::Debug>( &self, src: &str, @@ -497,10 +475,12 @@ impl Hub { let ntotal = resp .headers() .get(ureq::http::header::CONTENT_LENGTH) - .and_then(|v| v.to_str().ok()?.parse::().ok()) - .context("Content-Length header is missing or invalid")?; + .and_then(|v| v.to_str().ok()?.parse::().ok()); - let mut pb = crate::PB::fetch(ntotal); + let mut pb = match ntotal { + Some(n) => crate::PB::fetch(n), + None => crate::PB::fetch_stream(), + }; if let Some(msg) = message { pb = pb.with_message(msg); } @@ -527,11 +507,13 @@ impl Hub { pb.inc(bytes_read as u64); } - // Verify download completeness - if downloaded_bytes as u64 != ntotal { - anyhow::bail!( - "The downloaded file is incomplete. Expected: {ntotal} bytes, got: {downloaded_bytes} bytes" - ); + // Verify download completeness (only when Content-Length is known) + if let Some(ntotal) = ntotal { + if downloaded_bytes as u64 != ntotal { + anyhow::bail!( + "The downloaded file is incomplete. Expected: {ntotal} bytes, got: {downloaded_bytes} bytes" + ); + } } // Only persist the temporary file to the final destination if download is complete @@ -620,6 +602,23 @@ impl Hub { } } + /// Parse a Hugging Face URL into (owner, repo, filename). 
+ /// + /// Supports both `/resolve/` and `/blob/` URL formats: + /// - `https://huggingface.co///resolve/main/?download=true` + /// - `https://hf-mirror.com///blob/main/` + pub fn parse_hf_url(url: &str) -> Option<(String, String, String)> { + let re = + Regex::new(r"^https?://[^/]+/([^/]+)/([^/]+)/(?:resolve|blob)/[^/]+/(.+?)(?:\?.*)?$") + .ok()?; + let caps = re.captures(url)?; + Some(( + caps.get(1)?.as_str().to_string(), + caps.get(2)?.as_str().to_string(), + caps.get(3)?.as_str().to_string(), + )) + } + pub fn with_owner(mut self, owner: &str) -> Self { self.owner = owner.to_string(); self @@ -630,12 +629,11 @@ impl Hub { self } - #[cfg(feature = "hf-hub")] - pub fn with_hf_owner_repo(self, owner: &str, repo: &str) -> Result { - Self::from_hf(owner, repo) + pub fn with_hf_repo(mut self, owner: &str, repo: &str) -> Self { + self.hf_repo = Some((owner.into(), repo.into())); + self } - #[cfg(feature = "hf-hub")] pub fn with_hf_endpoint(mut self, x: &str) -> Self { self.hf_endpoint = x.to_string(); self @@ -656,48 +654,78 @@ impl Hub { /// For Hugging Face repositories, shows commit SHA and file tree. /// For GitHub repositories, shows releases with file counts and provides usage examples. 
pub fn info(&self) -> Result<()> { - println!("Repository: {}/{}", self.owner, self.repo); - - #[cfg(feature = "hf-hub")] - { - if let Some(hf_repo) = &self.hf_repo { - let info = hf_repo.info()?; - - println!("Type: Hugging Face Repository"); - println!("Commit SHA: {}", info.sha); - - println!("\nFiles ({} total):", info.siblings.len()); - let mut files: Vec<_> = - info.siblings.iter().map(|s| s.rfilename.as_str()).collect(); - files.sort(); - - Self::print_tree(&files, ""); - return Ok(()); + if let Some((hf_owner, hf_repo)) = &self.hf_repo { + println!("Repository: {hf_owner}/{hf_repo}"); + println!("Type: Hugging Face Repository"); + println!("Endpoint: {}", self.hf_endpoint); + let url = format!( + "{}/api/models/{}/{}?blobs=true", + self.hf_endpoint, hf_owner, hf_repo + ); + match self.fetch_get_response(&url) { + Ok(resp) => { + let body: serde_json::Value = serde_json::from_str( + &resp + .into_body() + .read_to_string() + .context("Failed to read HF API response")?, + )?; + if let Some(sha) = body["sha"].as_str() { + println!("Commit SHA: {sha}"); + } + if let Some(siblings) = body["siblings"].as_array() { + println!("\nFiles ({} total):", siblings.len()); + let mut files: Vec<(String, Option)> = siblings + .iter() + .filter_map(|s| { + let name = s["rfilename"].as_str()?.to_string(); + let size = s["size"].as_u64().or_else(|| { + s["lfs"].as_object().and_then(|lfs| lfs["size"].as_u64()) + }); + Some((name, size)) + }) + .collect(); + files.sort_by(|a, b| a.0.cmp(&b.0)); + let refs: Vec<(&str, Option)> = + files.iter().map(|(n, s)| (n.as_str(), *s)).collect(); + Self::print_tree(&refs, ""); + } + } + Err(e) => { + println!("Failed to fetch repository info: {e}"); + println!("Visit: {}/{}/{}", self.hf_endpoint, hf_owner, hf_repo); + } } + return Ok(()); } + println!("Repository: {}/{}", self.owner, self.repo); println!("Type: GitHub Release Repository"); - let tags = self.tags(); + let releases = self + .get_releases(&self.owner, &self.repo, &self.to, 
&self.ttl) + .unwrap_or_default(); - if tags.is_empty() { + if releases.is_empty() { println!("No releases found in this repository."); } else { - println!("\nReleases ({} total):", tags.len()); - for tag in &tags { - let files = self.files(tag); - println!(" {} ({} files):", tag, files.len()); + println!("\nReleases ({} total):", releases.len()); + for release in &releases { + println!(" {} ({} files):", release.tag_name, release.assets.len()); - if files.is_empty() { + if release.assets.is_empty() { println!(" (no files)"); - } else if files.len() <= 5 { - // Show all files if 5 or fewer - let file_refs: Vec<&str> = files.iter().map(|s| s.as_str()).collect(); - Self::print_tree(&file_refs, " "); } else { - // Show first 5 files and indicate there are more - let file_refs: Vec<&str> = files.iter().take(5).map(|s| s.as_str()).collect(); - Self::print_tree(&file_refs, " "); - println!(" ... and {} more files", files.len() - 5); + let files: Vec<(&str, Option)> = release + .assets + .iter() + .map(|a| (a.name.as_str(), Some(a.size))) + .collect(); + if files.len() <= 5 { + Self::print_tree(&files, " "); + } else { + Self::print_tree(&files[..5], " "); + println!(" ... 
and {} more files", files.len() - 5); + } } println!(); } @@ -715,19 +743,20 @@ impl Hub { Ok(()) } - fn print_tree(files: &[&str], prefix: &str) { + fn print_tree(files: &[(&str, Option)], prefix: &str) { use std::collections::HashMap; #[derive(Default)] struct TreeNode { children: HashMap, is_file: bool, + size: Option, } let mut root = TreeNode::default(); // Build tree structure - for file_path in files { + for (file_path, size) in files { let parts: Vec<&str> = file_path.split('/').collect(); let mut current = &mut root; @@ -736,6 +765,7 @@ impl Hub { current = current.children.entry(part.to_string()).or_default(); if is_last { current.is_file = true; + current.size = *size; } } } @@ -744,7 +774,18 @@ impl Hub { fn print_node(node: &TreeNode, name: &str, prefix: &str, is_last: bool, is_root: bool) { if !is_root { let connector = if is_last { "└── " } else { "├── " }; - println!("{prefix}{connector}{name}"); + if node.is_file { + if let Some(size) = node.size { + println!( + "{prefix}{connector}{name} [{}]", + crate::human_bytes_binary(size as f64, 3) + ); + } else { + println!("{prefix}{connector}{name}"); + } + } else { + println!("{prefix}{connector}{name}"); + } } let mut children: Vec<_> = node.children.iter().collect(); diff --git a/src/models/vision/db/config.rs b/src/models/vision/db/config.rs index ae41a1f..3484fc6 100644 --- a/src/models/vision/db/config.rs +++ b/src/models/vision/db/config.rs @@ -63,8 +63,8 @@ impl crate::Config { fn ppocr_det_v5() -> Self { Self::db() - .with_model_ixx(0, 2, (608, 640, 1600)) - .with_model_ixx(0, 3, (608, 640, 1600)) + .with_model_ixx(0, 2, (608, 960, 1600)) + .with_model_ixx(0, 3, (608, 960, 1600)) } /// PaddleOCR v5 mobile detection diff --git a/src/models/vlm/clip/config.rs b/src/models/vlm/clip/config.rs index c8093a0..e19f505 100644 --- a/src/models/vlm/clip/config.rs +++ b/src/models/vlm/clip/config.rs @@ -13,13 +13,31 @@ /// > /// > # Model Variants /// > -/// > - **clip-vit-b16**: ViT-B/16 model for general 
image-text tasks -/// > - **clip-vit-b32**: ViT-B/32 model for general image-text tasks -/// > - **clip-vit-l14**: ViT-L/14 model for general image-text tasks -/// > - **jina-clip-v1**: Jina CLIP v1 with improved performance -/// > - **jina-clip-v2**: Jina CLIP v2 with 512x512 resolution -/// > - **mobileclip-s0/s1/s2/b/blt**: MobileCLIP variants for mobile devices -/// > - **mobileclip2-s0/s2/s3/s4/b/l14**: MobileCLIP v2 variants +/// > ## OpenAI CLIP +/// > - **clip-vit-b16**: ViT-B/16 (85M params) +/// > - **clip-vit-b32**: ViT-B/32 (87M params) +/// > - **clip-vit-l14**: ViT-L/14 (304M params) +/// > +/// > ## Jina CLIP +/// > - **jina-clip-v1**: Improved performance, 224x224 +/// > - **jina-clip-v2**: 512x512 resolution, better accuracy +/// > +/// > ## MobileCLIP (Apple) +/// > - **mobileclip-s0/s1/s2**: Small variants (0-2) +/// > - **mobileclip-b**: Base variant +/// > - **mobileclip-blt**: Base with large text encoder +/// > +/// > ## MobileCLIP v2 +/// > - **mobileclip2-s0/s2/s4/b/l14**: Enhanced mobile variants +/// > +/// > ## SigLIP (Google DeepMind) +/// > - **siglip-b16-224/256/384/512**: Base models, patch16 +/// > - **siglip-l16-256/384**: Large models, patch16 +/// > +/// > ## SigLIP v2 (Google DeepMind) +/// > - **siglip2-b16-224/256/384/512**: Base models v2 +/// > - **siglip2-l16-256/384/512**: Large models v2 +/// > - **siglip2-so400m-patch14/16**: 400M parameter models /// > /// > # Implemented Features / Tasks /// > @@ -27,7 +45,9 @@ /// > - [X] **Image-Text Retrieval**: Retrieve relevant text for images /// > - [X] **Text-Image Retrieval**: Retrieve relevant images for text /// > - [X] **Mobile Optimization**: Lightweight models for mobile devices -/// > - [X] **Multi-Scale Support**: Various input resolutions +/// > - [X] **Multi-Scale Support**: Various input resolutions (224, 256, 384, 512) +/// > - [X] **Dual Encoder Architecture**: Separate vision and text encoders +/// > - [X] **Contrastive Learning**: Image-text similarity scoring 
/// > /// Model configuration for `CLIP` /// @@ -197,4 +217,305 @@ impl crate::Config { .with_textual_file("l-14-textual.onnx") .with_visual_file("l-14-visual.onnx") } + + pub fn siglip() -> Self { + Self::clip() + .with_name("clip") + .with_batch_size_min_opt_max_all(1, 1, 8) // batch size + .with_visual_ixx(0, 1, 3) // channel + .with_textual_ixx(0, 1, 64) // seq len + .with_image_mean([0.5, 0.5, 0.5]) + .with_image_std([0.5, 0.5, 0.5]) + .with_model_max_length(64) + } + + /// SigLIP Base, patch16, 224x224 + pub fn siglip_b16_224() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 224) + .with_visual_ixx(0, 3, 224) + .with_tokenizer_file("Xenova/siglip-base-patch16-224/tokenizer.json") + .with_tokenizer_config_file("Xenova/siglip-base-patch16-224/tokenizer_config.json") + .with_special_tokens_map_file("Xenova/siglip-base-patch16-224/special_tokens_map.json") + .with_textual_file("Xenova/siglip-base-patch16-224/onnx/text_model.onnx") + .with_visual_file("Xenova/siglip-base-patch16-224/onnx/vision_model.onnx") + } + + /// SigLIP Base, patch16, 256x256 + pub fn siglip_b16_256() -> Self { + Self::siglip() + .with_textual_file("Xenova/siglip-base-patch16-256/onnx/text_model.onnx") + .with_visual_file("Xenova/siglip-base-patch16-256/onnx/vision_model.onnx") + .with_visual_ixx(0, 2, 256) + .with_visual_ixx(0, 3, 256) + .with_tokenizer_file("Xenova/siglip-base-patch16-256/tokenizer.json") + .with_tokenizer_config_file("Xenova/siglip-base-patch16-256/tokenizer_config.json") + .with_special_tokens_map_file("Xenova/siglip-base-patch16-256/special_tokens_map.json") + } + + /// SigLIP Base, patch16, 384x384 + pub fn siglip_b16_384() -> Self { + Self::siglip() + .with_textual_file("Xenova/siglip-base-patch16-384/onnx/text_model.onnx") + .with_visual_file("Xenova/siglip-base-patch16-384/onnx/vision_model.onnx") + .with_visual_ixx(0, 2, 384) + .with_visual_ixx(0, 3, 384) + .with_tokenizer_file("Xenova/siglip-base-patch16-384/tokenizer.json") + 
.with_tokenizer_config_file("Xenova/siglip-base-patch16-384/tokenizer_config.json") + .with_special_tokens_map_file("Xenova/siglip-base-patch16-384/special_tokens_map.json") + } + + /// SigLIP Base, patch16, 512x512 + pub fn siglip_b16_512() -> Self { + Self::siglip() + .with_textual_file("Xenova/siglip-base-patch16-512/onnx/text_model.onnx") + .with_visual_file("Xenova/siglip-base-patch16-512/onnx/vision_model.onnx") + .with_visual_ixx(0, 2, 512) + .with_visual_ixx(0, 3, 512) + .with_tokenizer_file("Xenova/siglip-base-patch16-512/tokenizer.json") + .with_tokenizer_config_file("Xenova/siglip-base-patch16-512/tokenizer_config.json") + .with_special_tokens_map_file("Xenova/siglip-base-patch16-512/special_tokens_map.json") + } + + /// SigLIP Large, patch16, 256x256 + pub fn siglip_l16_256() -> Self { + Self::siglip() + .with_textual_file("Xenova/siglip-large-patch16-256/onnx/text_model.onnx") + .with_visual_file("Xenova/siglip-large-patch16-256/onnx/vision_model.onnx") + .with_visual_ixx(0, 2, 256) + .with_visual_ixx(0, 3, 256) + .with_tokenizer_file("Xenova/siglip-large-patch16-256/tokenizer.json") + .with_tokenizer_config_file("Xenova/siglip-large-patch16-256/tokenizer_config.json") + .with_special_tokens_map_file("Xenova/siglip-large-patch16-256/special_tokens_map.json") + } + + /// SigLIP Large, patch16, 384x384 + pub fn siglip_l16_384() -> Self { + Self::siglip() + .with_textual_file("Xenova/siglip-large-patch16-384/onnx/text_model.onnx") + .with_visual_file("Xenova/siglip-large-patch16-384/onnx/vision_model.onnx") + .with_visual_ixx(0, 2, 384) + .with_visual_ixx(0, 3, 384) + .with_tokenizer_file("Xenova/siglip-large-patch16-384/tokenizer.json") + .with_tokenizer_config_file("Xenova/siglip-large-patch16-384/tokenizer_config.json") + .with_special_tokens_map_file("Xenova/siglip-large-patch16-384/special_tokens_map.json") + } + + /// SigLIP v2 Base, patch16, 224x224 + pub fn siglip2_b16_224() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 224) + 
.with_visual_ixx(0, 3, 224) + .with_tokenizer_file("onnx-community/siglip2-base-patch16-224-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-base-patch16-224-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-base-patch16-224-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-base-patch16-224-ONNX/onnx/text_model.onnx") + .with_visual_file("onnx-community/siglip2-base-patch16-224-ONNX/onnx/vision_model.onnx") + } + + /// SigLIP v2 Base, patch16, 256x256 + pub fn siglip2_b16_256() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 256) + .with_visual_ixx(0, 3, 256) + .with_tokenizer_file("onnx-community/siglip2-base-patch16-256-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-base-patch16-256-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-base-patch16-256-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-base-patch16-256-ONNX/onnx/text_model.onnx") + .with_visual_file("onnx-community/siglip2-base-patch16-256-ONNX/onnx/vision_model.onnx") + } + + /// SigLIP v2 Base, patch16, 384x384 + pub fn siglip2_b16_384() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 384) + .with_visual_ixx(0, 3, 384) + .with_tokenizer_file("onnx-community/siglip2-base-patch16-384-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-base-patch16-384-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-base-patch16-384-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-base-patch16-384-ONNX/onnx/text_model.onnx") + .with_visual_file("onnx-community/siglip2-base-patch16-384-ONNX/onnx/vision_model.onnx") + } + + /// SigLIP v2 Base, patch16, 512x512 + pub fn siglip2_b16_512() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 512) + .with_visual_ixx(0, 3, 512) + 
.with_tokenizer_file("onnx-community/siglip2-base-patch16-512-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-base-patch16-512-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-base-patch16-512-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-base-patch16-512-ONNX/onnx/text_model.onnx") + .with_visual_file("onnx-community/siglip2-base-patch16-512-ONNX/onnx/vision_model.onnx") + } + + /// SigLIP v2 Large, patch16, 256x256 + pub fn siglip2_l16_256() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 256) + .with_visual_ixx(0, 3, 256) + .with_tokenizer_file("onnx-community/siglip2-large-patch16-256-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-large-patch16-256-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-large-patch16-256-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-large-patch16-256-ONNX/onnx/text_model.onnx") + .with_visual_file( + "onnx-community/siglip2-large-patch16-256-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 Large, patch16, 384x384 + pub fn siglip2_l16_384() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 384) + .with_visual_ixx(0, 3, 384) + .with_tokenizer_file("onnx-community/siglip2-large-patch16-384-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-large-patch16-384-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-large-patch16-384-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-large-patch16-384-ONNX/onnx/text_model.onnx") + .with_visual_file( + "onnx-community/siglip2-large-patch16-384-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 Large, patch16, 512x512 + pub fn siglip2_l16_512() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 512) + .with_visual_ixx(0, 3, 512) + 
.with_tokenizer_file("onnx-community/siglip2-large-patch16-512-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-large-patch16-512-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-large-patch16-512-ONNX/special_tokens_map.json", + ) + .with_textual_file("onnx-community/siglip2-large-patch16-512-ONNX/onnx/text_model.onnx") + .with_visual_file( + "onnx-community/siglip2-large-patch16-512-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 400M, patch14, 224x224 + pub fn siglip2_so400m_patch14_224() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 224) + .with_visual_ixx(0, 3, 224) + .with_tokenizer_file("onnx-community/siglip2-so400m-patch14-224-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-so400m-patch14-224-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-so400m-patch14-224-ONNX/special_tokens_map.json", + ) + .with_textual_file( + "onnx-community/siglip2-so400m-patch14-224-ONNX/onnx/text_model.onnx", + ) + .with_visual_file( + "onnx-community/siglip2-so400m-patch14-224-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 400M, patch14, 384x384 + pub fn siglip2_so400m_patch14_384() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 384) + .with_visual_ixx(0, 3, 384) + .with_tokenizer_file("onnx-community/siglip2-so400m-patch14-384-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-so400m-patch14-384-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-so400m-patch14-384-ONNX/special_tokens_map.json", + ) + .with_textual_file( + "onnx-community/siglip2-so400m-patch14-384-ONNX/onnx/text_model.onnx", + ) + .with_visual_file( + "onnx-community/siglip2-so400m-patch14-384-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 400M, patch16, 256x256 + pub fn siglip2_so400m_patch16_256() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 256) + 
.with_visual_ixx(0, 3, 256) + .with_tokenizer_file("onnx-community/siglip2-so400m-patch16-256-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-so400m-patch16-256-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-so400m-patch16-256-ONNX/special_tokens_map.json", + ) + .with_textual_file( + "onnx-community/siglip2-so400m-patch16-256-ONNX/onnx/text_model.onnx", + ) + .with_visual_file( + "onnx-community/siglip2-so400m-patch16-256-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 400M, patch16, 384x384 + pub fn siglip2_so400m_patch16_384() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 384) + .with_visual_ixx(0, 3, 384) + .with_tokenizer_file("onnx-community/siglip2-so400m-patch16-384-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-so400m-patch16-384-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-so400m-patch16-384-ONNX/special_tokens_map.json", + ) + .with_textual_file( + "onnx-community/siglip2-so400m-patch16-384-ONNX/onnx/text_model.onnx", + ) + .with_visual_file( + "onnx-community/siglip2-so400m-patch16-384-ONNX/onnx/vision_model.onnx", + ) + } + + /// SigLIP v2 400M, patch16, 512x512 + pub fn siglip2_so400m_patch16_512() -> Self { + Self::siglip() + .with_visual_ixx(0, 2, 512) + .with_visual_ixx(0, 3, 512) + .with_tokenizer_file("onnx-community/siglip2-so400m-patch16-512-ONNX/tokenizer.json") + .with_tokenizer_config_file( + "onnx-community/siglip2-so400m-patch16-512-ONNX/tokenizer_config.json", + ) + .with_special_tokens_map_file( + "onnx-community/siglip2-so400m-patch16-512-ONNX/special_tokens_map.json", + ) + .with_textual_file( + "onnx-community/siglip2-so400m-patch16-512-ONNX/onnx/text_model.onnx", + ) + .with_visual_file( + "onnx-community/siglip2-so400m-patch16-512-ONNX/onnx/vision_model.onnx", + ) + } } diff --git a/src/models/vlm/clip/impl.rs b/src/models/vlm/clip/impl.rs index 955295d..b3c3b3d 
100644 --- a/src/models/vlm/clip/impl.rs +++ b/src/models/vlm/clip/impl.rs @@ -1,5 +1,6 @@ use aksr::Builder; use anyhow::Result; +use ndarray::Dimension; use crate::{ inputs, Config, Engine, Engines, FromConfig, Image, ImageProcessor, Model, Module, @@ -66,14 +67,22 @@ impl Model for Clip { self.image_processor.process(images)? }); let ys = crate::perf!("CLIP::visual-inference", engines.run(&Module::Visual, &ys)?); - let y = crate::perf!( - "CLIP::visual-postprocess", - ys.get::(0) - .ok_or_else(|| anyhow::anyhow!("Failed to get visual output"))? - .to_owned() - ); + let y = crate::perf!("CLIP::visual-postprocess", { + let mut y = Y::default(); + for i in 0..ys.len() { + if let Some(x) = ys.get::(i) { + match x.dim().ndim() { + 2 => y = y.with_embedding(x.to_owned()), + 3 => y = y.with_last_hidden_state(x.to_owned()), + _ => continue, + } + } + } + + y + }); - Ok(Y::default().with_embedding(y)) + Ok(y) } fn encode_texts(&mut self, engines: &mut Engines, texts: &[&str]) -> Result { @@ -91,13 +100,21 @@ impl Model for Clip { "CLIP::textual-inference", engines.run(&Module::Textual, inputs![ys]?)? ); - let y = crate::perf!( - "CLIP::textual-postprocess", - ys.get::(0) - .ok_or_else(|| anyhow::anyhow!("Failed to get textual output"))? 
- .to_owned() - ); + let y = crate::perf!("CLIP::textual-postprocess", { + let mut y = Y::default(); + for i in 0..ys.len() { + if let Some(x) = ys.get::(i) { + match x.dim().ndim() { + 2 => y = y.with_embedding(x.to_owned()), + 3 => y = y.with_last_hidden_state(x.to_owned()), + _ => continue, + } + } + } + + y + }); - Ok(Y::default().with_embedding(y)) + Ok(y) } } diff --git a/src/models/vlm/fastvlm/impl.rs b/src/models/vlm/fastvlm/impl.rs index 629325a..639a366 100644 --- a/src/models/vlm/fastvlm/impl.rs +++ b/src/models/vlm/fastvlm/impl.rs @@ -206,7 +206,7 @@ impl Model for FastVLM { let text_embed = Engine::from_config(config.take_module(&Module::Textual)?)?; let decoder = Engine::from_config(config.take_module(&Module::TextualDecoderMerged)?)?; - let max_length = config.inference.max_tokens.unwrap_or(1024); + let max_length = config.text_processor.max_tokens.unwrap_or(1024); let image_token = "".to_string(); let image_token_id = 151646; let eos_token = "<|im_end|>".to_string(); diff --git a/src/models/vlm/smolvlm/impl.rs b/src/models/vlm/smolvlm/impl.rs index 484e436..c386ad5 100644 --- a/src/models/vlm/smolvlm/impl.rs +++ b/src/models/vlm/smolvlm/impl.rs @@ -251,8 +251,8 @@ impl Model for SmolVLM { let eos_token_id = 49279; let image_token_id = 49190; let image_seq_len = 64; - let max_length = config.inference.max_tokens.unwrap_or(1024); - let ignore_eos = config.inference.ignore_eos; + let max_length = config.text_processor.max_tokens.unwrap_or(1024); + let ignore_eos = config.text_processor.ignore_eos; let scale = config .scale .take() diff --git a/src/ort/config.rs b/src/ort/config.rs index 283e5ac..9814d9d 100644 --- a/src/ort/config.rs +++ b/src/ort/config.rs @@ -1,5 +1,6 @@ use aksr::Builder; use anyhow::Result; +use std::collections::HashSet; use crate::{try_fetch_file_stem, DType, Device, EpConfig, Hub, Iiix}; @@ -39,79 +40,212 @@ impl Default for ORTConfig { impl ORTConfig { pub fn try_commit(mut self, name: &str) -> Result { + tracing::debug!( + 
"Model commit: resolving '{}' with file '{}'", + name, + self.file + ); + // Identify the local model or fetch the remote model if std::path::PathBuf::from(&self.file).exists() { - // Local + // Local file detected - no download required + tracing::debug!("Local model file found, skipping download: {}", &self.file); self.spec = format!("{}/{}", name, try_fetch_file_stem(&self.file)?); } else { if self.file.is_empty() && name.is_empty() { anyhow::bail!( - "Failed to commit model. Invalid model config: neither `name` nor `file` were specified. Failed to fetch model from Hub." + "Failed to commit model. Invalid model config: neither `name` nor `file` were specified. Failed to fetch model from HuggingFace Hub or GitHub release." ) } // Remote match Hub::is_valid_github_release_url(&self.file) { Some((owner, repo, tag, _file_name)) => { + // Explicit GitHub release URL detected + tracing::debug!( + "Explicit GitHub URL detected: {}/{} (tag: {})", + owner, + repo, + tag + ); let stem = try_fetch_file_stem(&self.file)?; self.spec = format!("{name}/{owner}-{repo}-{tag}-{stem}"); self.file = Hub::default().try_fetch(&self.file)?; } None => { - // append dtype to model file - match self.dtype { + // Not an explicit GitHub URL — could be a HuggingFace Hub path or + // a GitHub release file. + // + // Determine whether to prepend `name` as a tag prefix: + // - 1 segment (e.g. "model.onnx") + // → bare filename, prepend `name` to form "tag/file" + // - 2+ segments (e.g. "tag/model.onnx" or "owner/repo/dir/model.onnx") + // → already contains path structure, keep as-is + // + // Note: the actual HF-vs-GitHub distinction is handled inside + // `Hub::try_fetch` (≥3 segments → HF Hub, 2 segments → GitHub release). 
+ let parts = self.file.split('/').filter(|s| !s.is_empty()).count(); + if parts > 1 { + tracing::debug!( + "File path has directory structure, using as-is: {}", + self.file + ); + } else { + tracing::debug!( + "Bare filename, prepending model name as tag: {}/{}", + name, + self.file + ); + self.file = format!("{}/{}", name, self.file); + } + + // Save original path for resolving external data files later + let ext_file = self.file.clone(); + + // Build candidate file paths based on DType: + // - Auto/Fp32: use the file path as-is (or "{dtype}.onnx" if empty) + // - Other dtypes + non-empty file: try "model{delim}{dtype}.onnx" + // with delimiters ['-', '_', '.'] (e.g., "model-fp16.onnx") + // - Other dtypes + empty file: use "{dtype}.onnx" directly + let candidates: Vec = match self.dtype { d @ (DType::Auto | DType::Fp32) => { if self.file.is_empty() { - self.file = format!("{d}.onnx"); + vec![format!("{d}.onnx")] + } else { + vec![self.file.clone()] } } dtype => { if self.file.is_empty() { - self.file = format!("{dtype}.onnx"); + vec![format!("{dtype}.onnx")] } else { - let pos = self.file.len() - 5; // .onnx - let suffix = self.file.split_off(pos); - self.file = format!("{}-{}{}", self.file, dtype, suffix); + ['-', '_', '.'] + .iter() + .map(|delim| { + let mut base = self.file.clone(); + let suffix = base.split_off(base.len() - 5); // 5 -> ".onnx" + format!("{base}{delim}{dtype}{suffix}") + }) + .collect() } } + }; + tracing::debug!( + "Generated {} candidate file paths for resolution: {:?}", + candidates.len(), + candidates + ); + + // Phase 1: Check if any candidate is already cached locally (no HTTP) + let mut hub = Hub::default(); + let mut fetch_success = false; + for file in &candidates { + if let Some(cached_path) = hub.cached(file) { + self.file = cached_path; + let stem = try_fetch_file_stem(file)?; + self.spec = format!("{name}/{stem}"); + tracing::debug!("Cache hit: {} -> {}", file, &self.file); + fetch_success = true; + break; + } } - let stem = 
try_fetch_file_stem(&self.file)?; - self.spec = format!("{name}/{stem}"); + // Phase 2: Nothing cached — try fetching each candidate from remote + if !fetch_success { + for file in &candidates { + tracing::debug!("Requesting Hub to download: {file}"); + match hub.try_fetch(file) { + Ok(f) => { + self.file = f; + let stem = try_fetch_file_stem(file)?; + self.spec = format!("{name}/{stem}"); + tracing::debug!( + "Successfully resolved candidate '{}' to spec: {}", + file, + &self.spec + ); + fetch_success = true; + break; + } + Err(err) => { + tracing::warn!("Failed to download candidate '{file}': {err}"); + } + } + } + } - let parts: Vec<&str> = self.file.split('/').filter(|x| !x.is_empty()).collect(); - if parts.len() > 1 { - self.file = Hub::default().try_fetch(&self.file)?; - } else { - self.file = Hub::default().try_fetch(&format!("{}/{}", name, self.file))?; + if !fetch_success { + anyhow::bail!( + "Failed to fetch ONNX model file. \ + Neither a GitHub release file nor a HuggingFace Hub file \ + could be resolved. \ + Please verify the model file path: {:?}", + self.file + ); } - // try fetch external data file if it exists - if self.external_data_file { - let external_data_file = format!("{}_data", self.file); - tracing::info!("Trying to fetch external data file {}", external_data_file); + // Attempt to fetch external data files referenced by the model + let proto = crate::load_onnx(&self.file)?; + let graph = proto.graph.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "No graph found in ONNX proto. 
Invalid model: {}", + self.file + ) + })?; - match Hub::default().try_fetch(&external_data_file) { - Ok(external_data_file) => { - tracing::info!( - "Successfully fetched external data file: {}", - external_data_file + // Collect all external data file locations + let external_files: HashSet = graph + .initializer + .iter() + .filter(|t| t.data_location == 1) + .flat_map(|t| t.external_data.iter()) + .filter(|x| x.key == "location") + .map(|x| x.value.clone()) + .collect(); + + if !external_files.is_empty() { + tracing::debug!( + "Found {} external data files, requesting downloads: {:?}", + external_files.len(), + external_files + ); + } + + // Download all external data files + let base_path = ext_file.rsplit_once('/').map(|(base, _)| base); + for f in &external_files { + let base = base_path.ok_or_else(|| { + anyhow::anyhow!( + "Cannot resolve external data file path: \ + no parent directory in '{ext_file}'" + ) + })?; + let file_path = format!("{base}/{f}"); + match Hub::default().try_fetch(&file_path) { + Ok(local) => { + tracing::debug!( + "Successfully fetched external data file: {} -> {}", + file_path, + local ); } - Err(_) => { - tracing::warn!( - "No external data file found for model {}", - self.file + Err(err) => { + anyhow::bail!( + "Found external data reference '{file_path}' \ + but failed to fetch it: {err}" ); } } - } else { - tracing::info!("External data file is not enabled for model {}", self.file); } } } } + tracing::debug!( + "Model commit completed: spec='{}', file='{}'", + self.spec, + self.file + ); Ok(self) } } diff --git a/src/ort/engine.rs b/src/ort/engine.rs index 10c69a2..9ca7383 100644 --- a/src/ort/engine.rs +++ b/src/ort/engine.rs @@ -8,7 +8,6 @@ use ort::{ tensor::TensorElementType, value::{DynValue, Value}, }; -use prost::Message; use std::{ collections::{HashMap, HashSet}, sync::Arc, @@ -113,7 +112,7 @@ impl FromConfig for Engine { let session: Session; crate::perf!(&format!("ORT Engine ({})::init", config.spec), { - let proto = 
Self::load_onnx(&config.file)?; + let proto = crate::load_onnx(&config.file)?; let graph = match &proto.graph { Some(graph) => graph, None => { @@ -1346,17 +1345,17 @@ impl Engine { Ok(OrtTensorAttr::new(names, dtypes, dimss, vec![])) } - pub fn load_onnx>(p: P) -> Result { - let path_ref = p.as_ref(); - let f = std::fs::read(path_ref).map_err(|err| { - anyhow::anyhow!("Failed to read ONNX file '{path_ref:?}': {err}. Error: {err}") - })?; - onnx::ModelProto::decode(f.as_slice()).map_err(|err| { - anyhow::anyhow!( - "Failed to read the ONNX model: The file might be incomplete or corrupted. More detailed: {err}" - ) - }) - } + // pub fn load_onnx>(p: P) -> Result { + // let path_ref = p.as_ref(); + // let f = std::fs::read(path_ref).map_err(|err| { + // anyhow::anyhow!("Failed to read ONNX file '{path_ref:?}': {err}. Error: {err}") + // })?; + // onnx::ModelProto::decode(f.as_slice()).map_err(|err| { + // anyhow::anyhow!( + // "Failed to read the ONNX model: The file might be incomplete or corrupted. More detailed: {err}" + // ) + // }) + // } pub fn batch(&self) -> &MinOptMax { &self.inputs.minoptmax[0][0] diff --git a/src/ort/mod.rs b/src/ort/mod.rs index e5819ed..f959656 100644 --- a/src/ort/mod.rs +++ b/src/ort/mod.rs @@ -22,3 +22,16 @@ pub(crate) use iiix::*; pub use inputs::*; pub use min_opt_max::*; pub use xs::*; + +pub fn load_onnx>(p: P) -> anyhow::Result { + use prost::Message; + let path_ref = p.as_ref(); + let f = std::fs::read(path_ref).map_err(|err| { + anyhow::anyhow!("Failed to read ONNX file '{path_ref:?}': {err}. Error: {err}") + })?; + onnx::ModelProto::decode(f.as_slice()).map_err(|err| { + anyhow::anyhow!( + "Failed to read the ONNX model: The file might be incomplete or corrupted. 
More detailed: {err}" + ) + }) +} diff --git a/src/processor/text/config.rs b/src/processor/text/config.rs index 3e094d8..75b886b 100644 --- a/src/processor/text/config.rs +++ b/src/processor/text/config.rs @@ -24,6 +24,10 @@ pub struct TextProcessorConfig { pub temperature: f32, /// Top-p parameter for nucleus sampling. pub topp: f32, + /// Maximum number of tokens to generate. + pub max_tokens: Option, + /// Whether to ignore the end-of-sequence token. + pub ignore_eos: bool, } impl Default for TextProcessorConfig { @@ -38,38 +42,72 @@ impl Default for TextProcessorConfig { vocab_file: None, temperature: 1.0, topp: 0.9, + max_tokens: Default::default(), + ignore_eos: Default::default(), } } } impl TextProcessorConfig { - // TODO /// Build tokenizer from configuration. pub fn try_build_tokenizer(&self) -> anyhow::Result { + tracing::debug!("Building tokenizer from config"); let mut hub = crate::Hub::default(); + + // Resolve tokenizer file: check cache first, then fetch let mut tokenizer: Tokenizer = match &self.tokenizer_file { None => return Err(anyhow::anyhow!("tokenizer_file is required")), - Some(file) => Tokenizer::from_file(hub.try_fetch(file)?) - .map_err(|err| anyhow::anyhow!("Failed to build tokenizer: {err}"))?, + Some(file) => { + let path = if let Some(cached) = hub.cached(file) { + tracing::debug!("Tokenizer file cache hit: {cached}"); + cached + } else { + tracing::debug!("Tokenizer file not cached, requesting fetch: {file}"); + hub.try_fetch(file)? + }; + Tokenizer::from_file(&path).map_err(|err| { + anyhow::anyhow!("Failed to build tokenizer from '{path}': {err}") + })? 
+ } }; + // TODO + // Resolve tokenizer config file for pad_id let pad_id = match &self.tokenizer_config_file { None => 0u32, - Some(file) => match hub.try_fetch(file) { - Ok(x) => { - let config: serde_json::Value = - serde_json::from_str(&std::fs::read_to_string(x)?)?; - config["pad_token_id"].as_u64().unwrap_or(0) as u32 + Some(file) => { + let path = if let Some(cached) = hub.cached(file) { + tracing::debug!("Tokenizer config file cache hit: {cached}"); + Some(cached) + } else { + tracing::debug!("Tokenizer config file not cached, requesting fetch: {file}"); + hub.try_fetch(file).ok() + }; + match path { + Some(x) => { + let config: serde_json::Value = + serde_json::from_str(&std::fs::read_to_string(&x)?)?; + let id = config["pad_token_id"].as_u64().unwrap_or(0) as u32; + tracing::debug!("Resolved pad_token_id: {id}"); + id + } + None => 0u32, } - Err(_) => 0u32, - }, + } }; let mut max_length = None; let mut pad_token = String::from("[PAD]"); + // TODO + // Resolve tokenizer config for max_length and pad_token if let Some(file) = &self.tokenizer_config_file { - if let Ok(x) = hub.try_fetch(file) { + let path = if let Some(cached) = hub.cached(file) { + Some(cached) + } else { + hub.try_fetch(file).ok() + }; + if let Some(x) = path { let tokenizer_config: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(x)?)?; max_length = tokenizer_config["model_max_length"].as_u64(); @@ -77,15 +115,23 @@ impl TextProcessorConfig { .as_str() .unwrap_or("[PAD]") .to_string(); + tracing::debug!( + "Tokenizer config resolved: max_length={:?}, pad_token='{}'", + max_length, + pad_token + ); } } + // TODO + // Apply padding and truncation let tokenizer = match self.model_max_length { Some(n) => { let n = match max_length { None => n, Some(x) => x.min(n), }; + tracing::debug!("Applying fixed padding: length={n}, pad_id={pad_id}"); tokenizer .with_padding(Some(PaddingParams { strategy: PaddingStrategy::Fixed(n as _), @@ -96,30 +142,39 @@ impl TextProcessorConfig { 
.clone() } None => match max_length { - Some(n) => tokenizer - .with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - pad_token, - pad_id, - ..Default::default() - })) - .with_truncation(Some(TruncationParams { - max_length: n as _, - ..Default::default() - })) - .map_err(|err| anyhow::anyhow!("Failed to truncate: {err}"))? - .clone(), - None => tokenizer - .with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - pad_token, - pad_id, - ..Default::default() - })) - .clone(), + Some(n) => { + tracing::debug!("Applying batch-longest padding with truncation: max_length={n}, pad_id={pad_id}"); + tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: n as _, + ..Default::default() + })) + .map_err(|err| anyhow::anyhow!("Failed to truncate: {err}"))? + .clone() + } + None => { + tracing::debug!( + "Applying batch-longest padding without truncation, pad_id={pad_id}" + ); + tokenizer + .with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + pad_token, + pad_id, + ..Default::default() + })) + .clone() + } }, }; + tracing::debug!("Tokenizer built successfully"); Ok(tokenizer.into()) } } diff --git a/src/processor/text/mod.rs b/src/processor/text/mod.rs index 9c8a0e7..6fe7f15 100644 --- a/src/processor/text/mod.rs +++ b/src/processor/text/mod.rs @@ -3,9 +3,11 @@ mod macros; mod chat_template; mod config; mod logits_sampler; +mod pooling; mod processor; pub use chat_template::ChatTemplate; pub use config::TextProcessorConfig; pub use logits_sampler::LogitsSampler; +pub use pooling::Pooling; pub use processor::TextProcessor; diff --git a/src/processor/text/pooling.rs b/src/processor/text/pooling.rs new file mode 100644 index 0000000..3d58fbb --- /dev/null +++ b/src/processor/text/pooling.rs @@ -0,0 +1,153 @@ +use ndarray::{s, Array2, Axis, Ix3}; + +use 
crate::tensor::{XView, X}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Pooling { + Cls, + First, + Last, + Mean, + Max, +} + +impl Pooling { + /// Apply pooling on hidden states. + /// + /// - last_hidden_states: [B, T, D] + /// - mask (optional): [B, T], true = valid token + /// + /// Returns: + /// - pooled embedding: [B, D] + pub fn apply(&self, last_hidden_states: &XView, mask: Option<&XView>) -> X { + let states = last_hidden_states + .0 + .view() + .into_dimensionality::() + .expect("pooling expects a 3D tensor"); + + let mask_view = mask.map(|m| { + m.0.view() + .into_dimensionality::>() + .expect("mask must be 2D") + }); + + let pooled = match self { + Pooling::Cls | Pooling::First => states.index_axis(Axis(1), 0).to_owned(), + Pooling::Last => { + let t = states.len_of(Axis(1)); + states.index_axis(Axis(1), t - 1).to_owned() + } + Pooling::Mean => { + let (b, _, d) = states.dim(); + let mut out = Array2::zeros((b, d)); + for i in 0..b { + let mut count = 0.0; + for j in 0..states.len_of(Axis(1)) { + let allow = mask_view.as_ref().map(|m| m[[i, j]]).unwrap_or(true); + if allow { + out.row_mut(i).scaled_add(1.0, &states.slice(s![i, j, ..])); + count += 1.0; + } + } + if count > 0.0 { + *out.row_mut(i) /= count; + } + } + out + } + Pooling::Max => { + let (b, _, d) = states.dim(); + let mut out = Array2::from_elem((b, d), f32::NEG_INFINITY); + for i in 0..b { + for j in 0..states.len_of(Axis(1)) { + let allow = mask_view.as_ref().map(|m| m[[i, j]]).unwrap_or(true); + if allow { + let row = states.slice(s![i, j, ..]); + for k in 0..d { + out[[i, k]] = out[[i, k]].max(row[k]); + } + } + } + } + out + } + }; + + X::from(pooled.into_dyn()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn build_states() -> X { + let data = + ndarray::Array3::from_shape_vec((1, 3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + X::from(data) + } + + fn build_mask() -> X { + let mask_arr = ndarray::Array2::from_shape_vec((1, 3), vec![true, false, 
true]).unwrap(); + X::from(mask_arr.map(|v| *v as u8 != 0)) + } + + #[test] + fn pooling_cls_first_use_first_token() { + let states = build_states(); + let view = states.view(); + let pooled = Pooling::Cls.apply(&view, None); + let out = pooled + .0 + .into_dimensionality::>() + .unwrap(); + assert_eq!(out[[0, 0]], 1.0); + assert_eq!(out[[0, 1]], 2.0); + } + + #[test] + fn pooling_last() { + let states = build_states(); + let view = states.view(); + let pooled = Pooling::Last.apply(&view, None); + let out = pooled + .0 + .into_dimensionality::>() + .unwrap(); + assert_eq!(out[[0, 0]], 5.0); + assert_eq!(out[[0, 1]], 6.0); + } + + #[test] + fn pooling_mean_with_mask() { + let states = build_states(); + let view = states.view(); + let mask = build_mask(); + let mask_view = mask.view(); + let pooled = Pooling::Mean.apply(&view, Some(&mask_view)); + let out = pooled + .0 + .into_dimensionality::>() + .unwrap(); + // mask keeps token 0 and 2 (values 1,2 and 5,6) + assert_eq!(out[[0, 0]], (1.0 + 5.0) / 2.0); + assert_eq!(out[[0, 1]], (2.0 + 6.0) / 2.0); + } + + #[test] + fn pooling_max_ignores_masked_positions() { + let states = build_states(); + let view = states.view(); + let mask = build_mask(); + let mask_view = mask.view(); + let pooled = Pooling::Max.apply(&view, Some(&mask_view)); + let out = pooled + .0 + .into_dimensionality::>() + .unwrap(); + // only allowed rows are 0 and 2, max per column is max(1,5)=5 and max(2,6)=6 + assert_eq!(out[[0, 0]], 5.0); + assert_eq!(out[[0, 1]], 6.0); + } +} diff --git a/src/results/y.rs b/src/results/y.rs index d2d95ed..d45aef0 100644 --- a/src/results/y.rs +++ b/src/results/y.rs @@ -21,6 +21,7 @@ pub struct Y { pub masks: Vec, pub images: Vec, pub embedding: X, + pub last_hidden_state: X, // un-pooled, [B, T, D] / [B, N, D] pub extra: HashMap, } @@ -56,7 +57,10 @@ impl std::fmt::Debug for Y { s.field("Images", &self.images); } if !self.embedding.is_empty() { - s.field("Embeddings", &self.embedding); + s.field("Embedding", 
&self.embedding); + } + if !self.last_hidden_state.is_empty() { + s.field("LastHiddenState(un-pooled)", &self.last_hidden_state); } if !self.extra.is_empty() { s.field("Extra", &self.extra); @@ -77,6 +81,7 @@ impl Y { && self.masks.is_empty() && self.images.is_empty() && self.embedding.is_empty() + && self.last_hidden_state.is_empty() && self.extra.is_empty() } } diff --git a/src/utils/progressbar.rs b/src/utils/progressbar.rs index 2c94bdd..615e8d4 100644 --- a/src/utils/progressbar.rs +++ b/src/utils/progressbar.rs @@ -166,6 +166,28 @@ impl PB { ]) } + pub fn fetch_stream() -> Self { + let pb = Self { + inner: ProgressBar::new_spinner(), + prefix: "Fetching", + completion_prefix: "Fetched", + ..Default::default() + }; + pb.apply_style(false); + pb.with_layout(vec![ + PBComponent::Prefix, + PBComponent::Message, + PBComponent::BinaryBytes, + PBComponent::BinarySpeed, + ]) + .with_completion_layout(vec![ + PBComponent::FinishedPrefix, + PBComponent::Message, + PBComponent::BinaryBytes, + PBComponent::InElapsed, + ]) + } + pub fn iterating(total: u64) -> Self { let counter = if total == u64::MAX { PBComponent::CounterWithInfinity