diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..0bb75f7 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.onnx filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index f6559a4..9e67325 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ model-archive/ *.pt test_data/* -models/ +models/* +!models/table-transformer-structure-recognition_fp16.onnx **/.env +**/examples/ diff --git a/Cargo.lock b/Cargo.lock index e7141ce..2469a17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1109,6 +1109,7 @@ dependencies = [ "criterion", "dirs", "futures", + "half", "html2md", "image", "imageproc", diff --git a/examples/sample-tables.pdf b/examples/sample-tables.pdf new file mode 100644 index 0000000..e80692a Binary files /dev/null and b/examples/sample-tables.pdf differ diff --git a/ferrules-cli/src/main.rs b/ferrules-cli/src/main.rs index aa5eb4e..b8019b3 100644 --- a/ferrules-cli/src/main.rs +++ b/ferrules-cli/src/main.rs @@ -229,6 +229,9 @@ fn parse_ep_args(args: &Args) -> Vec { #[tokio::main(flavor = "multi_thread")] async fn main() { let args = Args::parse(); + if args.debug || std::env::var("RUST_LOG").is_ok() { + tracing_subscriber::fmt::init(); + } // Check providers let providers = parse_ep_args(&args); @@ -239,8 +242,6 @@ async fn main() { inter_threads: args.inter_threads, opt_level: args.graph_opt_level.map(|v| v.try_into().unwrap()), }; - // Global tasks - let parser = FerrulesParser::new(ort_config); let page_range = match args.page_range { Some(ref page_range_str) => match parse_page_range(page_range_str) { @@ -263,10 +264,12 @@ async fn main() { }, None => None, }; - let pb = setup_progress_bar(&args.file_path, None, page_range.clone()); let pbc = pb.clone(); + // Global tasks + let parser = FerrulesParser::new(ort_config); + let doc_name = args .file_path .file_name() @@ -440,6 +443,20 @@ async fn main() { ], ); } + ferrules_core::error::FerrulesError::TableTransformerModelError(e) => { + format_error( + "Table Transformation Failed", + "Failed to process table using the vision model.", + vec![ + ("Error", e), + ("File", args.file_path.display().to_string()), + ( + "Suggestion", + "Check if the model files are present and valid.".to_string(), + ), + ], + ); + } } std::process::exit(1); } diff --git a/ferrules-core/Cargo.toml b/ferrules-core/Cargo.toml index b509977..b2e24ab 100644 --- a/ferrules-core/Cargo.toml +++ b/ferrules-core/Cargo.toml @@ -10,6 +10,7 @@ itertools = "0.14.0" futures = "0.3.31" colored = "3.0.0" dirs = "6.0.0" +half = "2.4.1" anyhow = { workspace = true } uuid = { workspace = true } tracing = { workspace = true } @@ -39,7 +40,7 @@ regex = "1.11.1" html2md = "0.2.15" [target.'cfg(target_os = "macos")'.dependencies] -ort = { version = "=2.0.0-rc.9", features = ["coreml", "fetch-models"] } +ort = { version = "=2.0.0-rc.9", features = ["coreml", "fetch-models", "half"] } objc2 = { version = "^0.5.2" } objc2-foundation = { version = "^0.2.2" } objc2-vision = { version = "^0.2.2", features = [ diff --git a/ferrules-core/src/blocks.rs b/ferrules-core/src/blocks.rs index 9bf27f0..6666531 100644 --- a/ferrules-core/src/blocks.rs +++ b/ferrules-core/src/blocks.rs @@ -28,6 +28,46 @@ pub struct List { pub(crate) items: Vec, } +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub enum TableAlgorithm { + #[default] + Unknown, + Lattice, + Stream, + Vision, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct TableBlock { + pub(crate) id: usize, + pub(crate) caption: Option, + pub rows: Vec, + pub has_borders: bool, + pub algorithm: TableAlgorithm, +} + +impl TableBlock { + pub(crate) fn path(&self) -> String { + format!("table_{}.png", self.id) + } +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct TableRow { + pub cells: Vec, + pub is_header: bool, + pub bbox: BBox, +} + +#[derive(Clone, Debug, Default, Deserialize, Serialize)] +pub struct TableCell { + pub content: Vec, + pub text: String, + pub row_span: u8, + pub col_span: u8, + pub bbox: BBox, +} + #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct Title { pub level: TitleLevel, @@ -43,7 +83,7 @@ pub enum BlockType { ListBlock(List), TextBlock(TextBlock), Image(ImageBlock), - Table, + Table(TableBlock), } impl std::fmt::Display for BlockType { @@ -121,7 +161,21 @@ impl Block { } BlockType::Title(_title) => todo!(), BlockType::Image(_image_block) => todo!(), - BlockType::Table => todo!(), + BlockType::Table(table) => { + if let ElementType::Table(incoming_table_opt) = &element.kind { + self.bbox.merge(&element.bbox); + if let Some(incoming_table) = incoming_table_opt { + table.rows.extend(incoming_table.rows.clone()); + } + Ok(()) + } else { + Err(FerrulesError::BlockMergeError { + element, + block_id: self.id, + kind: self.kind.clone(), + }) + } + } } } @@ -133,7 +187,7 @@ impl Block { BlockType::Title(_) => "TITLE", BlockType::ListBlock(_) => "LIST", BlockType::Image(_) => "IMAGE", - BlockType::Table => "TABLE", + BlockType::Table(_) => "TABLE", } } } diff --git a/ferrules-core/src/draw.rs b/ferrules-core/src/draw.rs index bf304b3..c2d4deb 100644 --- a/ferrules-core/src/draw.rs +++ b/ferrules-core/src/draw.rs @@ -16,6 +16,10 @@ const BLOCK_COLOR: [u8; 4] = [209, 139, 0, 255]; const LAYOUT_COLOR: [u8; 4] = [0, 0, 255, 255]; const LINE_OCR_COLOR: [u8; 4] = [17, 138, 1, 255]; const LINE_PDFIRUM_COLOR: [u8; 4] = [255, 0, 0, 255]; +const PATH_COLOR: [u8; 4] = [0, 255, 0, 255]; // Green for paths +const TABLE_ROW_COLOR: [u8; 4] = [0, 0, 255, 255]; // Blue +const TABLE_CELL_COLOR: [u8; 4] = [255, 0, 0, 255]; // Red +const VISION_COLOR: [u8; 4] = [138, 43, 226, 96]; // Violet/Purple with medium alpha fn load_font() -> FontArc { FontArc::try_from_slice(FONT_BYTES).unwrap() @@ -54,7 +58,7 @@ pub(crate) fn draw_text_lines( pub(crate) fn draw_layout_bboxes( bboxes: &[LayoutBBox], page_img: &DynamicImage, -) -> anyhow::Result, Vec>> { +) -> Result, Vec>, FerrulesError> { // Convert the dynamic image to RGBA for in-place drawing. let mut out_img = page_img.to_rgba8(); @@ -92,7 +96,7 @@ pub(crate) fn draw_layout_bboxes( pub(crate) fn draw_ocr_bboxes( bboxes: &[OCRLines], page_img: &DynamicImage, -) -> anyhow::Result, Vec>> { +) -> Result, Vec>, FerrulesError> { // Convert the dynamic image to RGBA for in-place drawing. let mut out_img = page_img.to_rgba8(); @@ -126,38 +130,273 @@ pub(crate) fn draw_ocr_bboxes( Ok(out_img) } +pub(crate) fn draw_paths( + paths: &[crate::entities::PDFPath], + page_img: &DynamicImage, +) -> Result, Vec>, FerrulesError> { + let mut out_img = page_img.to_rgba8(); + + for path in paths { + for segment in &path.segments { + match segment { + crate::entities::Segment::Line { start, end } => { + let start = (start.0 as f32, start.1 as f32); + let end = (end.0 as f32, end.1 as f32); + imageproc::drawing::draw_line_segment_mut( + &mut out_img, + start, + end, + Rgba(PATH_COLOR), + ); + } + crate::entities::Segment::Rect { bbox } => { + let x0 = bbox.x0 as i32; + let y0 = bbox.y0 as i32; + let width = (bbox.width() as u32).max(1); + let height = (bbox.height() as u32).max(1); + let rect = Rect::at(x0, y0).of_size(width, height); + draw_hollow_rect_mut(&mut out_img, rect, Rgba(PATH_COLOR)); + } + } + } + } + + Ok(out_img) +} + pub(crate) fn draw_blocks( bboxes: &[Block], page_img: &DynamicImage, -) -> anyhow::Result, Vec>> { +) -> Result, Vec>, FerrulesError> { // Convert the dynamic image to RGBA for in-place drawing. let mut out_img = page_img.to_rgba8(); let font: FontArc = load_font(); for block in bboxes { - let x0 = block.bbox.x0 as i32; - let y0 = block.bbox.y0 as i32; - let x1 = block.bbox.x1 as i32; - let y1 = block.bbox.y1 as i32; + match &block.kind { + crate::blocks::BlockType::Table(table) => { + draw_table_structure(table, &mut out_img); + } + _ => { + let x0 = block.bbox.x0 as i32; + let y0 = block.bbox.y0 as i32; + let x1 = block.bbox.x1 as i32; + let y1 = block.bbox.y1 as i32; + + let width = (x1 - x0).max(1) as u32; + let height = (y1 - y0).max(1) as u32; + + let rect = Rect::at(x0, y0).of_size(width, height); + + draw_hollow_rect_mut(&mut out_img, rect, Rgba(BLOCK_COLOR)); + let scale = 70; + let legend_size = page_img.width().max(page_img.height()) / scale; + imageproc::drawing::draw_text_mut( + &mut out_img, + image::Rgba(BLOCK_COLOR), + block.bbox.x0 as i32, + (block.bbox.y0 - legend_size as f32) as i32, + legend_size as f32, + &font, + block.label(), + ); + } + } + } + + Ok(out_img) +} + +fn draw_filled_rect_alpha(img: &mut image::RgbaImage, rect: Rect, color: Rgba) { + let alpha = color[3] as f32 / 255.0; + let (w, h) = img.dimensions(); + + let left = rect.left().max(0); + let top = rect.top().max(0); + let right = rect.right().min(w as i32); + let bottom = rect.bottom().min(h as i32); + + for y in top..bottom { + for x in left..right { + let px = img.get_pixel_mut(x as u32, y as u32); + px[0] = ((1.0 - alpha) * px[0] as f32 + alpha * color[0] as f32) as u8; + px[1] = ((1.0 - alpha) * px[1] as f32 + alpha * color[1] as f32) as u8; + px[2] = ((1.0 - alpha) * px[2] as f32 + alpha * color[2] as f32) as u8; + } + } +} + +fn draw_table_structure( + table_block: &crate::blocks::TableBlock, + out_img: &mut ImageBuffer, Vec>, +) { + // 1. Draw Vision hints (bottom layer) + if let crate::blocks::TableAlgorithm::Vision = &table_block.algorithm { + for row in &table_block.rows { + // Row detection + let row_rect = Rect::at(row.bbox.x0 as i32, row.bbox.y0 as i32).of_size( + row.bbox.width().max(1.0) as u32, + row.bbox.height().max(1.0) as u32, + ); + draw_filled_rect_alpha(out_img, row_rect, Rgba(VISION_COLOR)); + } + + // Column detections based on first row cells + if let Some(first_row) = table_block.rows.first() { + let y_start = first_row.bbox.y0; + let y_end = table_block + .rows + .last() + .map(|r| r.bbox.y1) + .unwrap_or(y_start); + for cell in &first_row.cells { + let col_rect = Rect::at(cell.bbox.x0 as i32, y_start as i32).of_size( + cell.bbox.width().max(1.0) as u32, + (y_end - y_start).max(1.0) as u32, + ); + draw_filled_rect_alpha(out_img, col_rect, Rgba(VISION_COLOR)); + } + } + } + + // 2. Draw Table BBox + // We can assume the table block itself is drawn by the caller if needed, + // but here we focus on internal structure. + // Draw Rows + for row in &table_block.rows { + let x0 = row.bbox.x0 as i32; + let y0 = row.bbox.y0 as i32; + let x1 = row.bbox.x1 as i32; + let y1 = row.bbox.y1 as i32; let width = (x1 - x0).max(1) as u32; let height = (y1 - y0).max(1) as u32; - let rect = Rect::at(x0, y0).of_size(width, height); + draw_hollow_rect_mut(out_img, rect, Rgba(TABLE_ROW_COLOR)); - draw_hollow_rect_mut(&mut out_img, rect, Rgba(BLOCK_COLOR)); - let scale = 70; - let legend_size = page_img.width().max(page_img.height()) / scale; - imageproc::drawing::draw_text_mut( - &mut out_img, - image::Rgba(BLOCK_COLOR), - block.bbox.x0 as i32, - (block.bbox.y0 - legend_size as f32) as i32, - legend_size as f32, - &font, - block.label(), - ); + // Draw Cells + for cell in &row.cells { + let x0 = cell.bbox.x0 as i32; + let y0 = cell.bbox.y0 as i32; + let x1 = cell.bbox.x1 as i32; + let y1 = cell.bbox.y1 as i32; + let width = (x1 - x0).max(1) as u32; + let height = (y1 - y0).max(1) as u32; + let rect = Rect::at(x0, y0).of_size(width, height); + draw_hollow_rect_mut(out_img, rect, Rgba(TABLE_CELL_COLOR)); + } } +} - Ok(out_img) +#[cfg(test)] +mod tests { + use super::*; + use crate::blocks::{TableBlock, TableCell, TableRow}; + use crate::entities::BBox; + use image::RgbaImage; + + #[test] + fn test_draw_table_structure() { + let page_img = DynamicImage::ImageRgba8(RgbaImage::new(100, 100)); + let mut out_img = page_img.to_rgba8(); + + // Create dummy table + let table_block = TableBlock { + id: 1, + caption: None, + has_borders: true, + rows: vec![ + TableRow { + is_header: true, + bbox: BBox { + x0: 10.0, + y0: 10.0, + x1: 90.0, + y1: 30.0, + }, + cells: vec![ + TableCell { + text: "Header".to_string(), + bbox: BBox { + x0: 10.0, + y0: 10.0, + x1: 50.0, + y1: 30.0, + }, + row_span: 1, + col_span: 1, + content: vec![], + }, + TableCell { + text: "Header 2".to_string(), + bbox: BBox { + x0: 50.0, + y0: 10.0, + x1: 90.0, + y1: 30.0, + }, + row_span: 1, + col_span: 1, + content: vec![], + }, + ], + }, + TableRow { + is_header: false, + bbox: BBox { + x0: 10.0, + y0: 30.0, + x1: 90.0, + y1: 50.0, + }, + cells: vec![ + TableCell { + text: "Cell 1".to_string(), + bbox: BBox { + x0: 10.0, + y0: 30.0, + x1: 50.0, + y1: 50.0, + }, + row_span: 1, + col_span: 1, + content: vec![], + }, + TableCell { + text: "Cell 2".to_string(), + bbox: BBox { + x0: 50.0, + y0: 30.0, + x1: 90.0, + y1: 50.0, + }, + row_span: 1, + col_span: 1, + content: vec![], + }, + ], + }, + ], + algorithm: crate::blocks::TableAlgorithm::Unknown, + }; + + // Directly call the internal function to test it + draw_table_structure(&table_block, &mut out_img); + + // Also test via draw_blocks + let block = crate::blocks::Block { + id: 1, + kind: crate::blocks::BlockType::Table(table_block), + pages_id: vec![0], + bbox: BBox { + x0: 10.0, + y0: 10.0, + x1: 90.0, + y1: 50.0, + }, + }; + + let result = draw_blocks(&[block], &page_img); + assert!(result.is_ok()); + } } diff --git a/ferrules-core/src/entities.rs b/ferrules-core/src/entities.rs index 1303a73..7af4063 100644 --- a/ferrules-core/src/entities.rs +++ b/ferrules-core/src/entities.rs @@ -5,7 +5,10 @@ use std::{path::PathBuf, time::Duration}; use pdfium_render::prelude::{PdfFontWeight, PdfPageTextChar, PdfRect}; -use crate::{blocks::Block, layout::model::LayoutBBox}; +use crate::{ + blocks::{Block, TableBlock}, + layout::model::LayoutBBox, +}; pub type PageID = usize; pub type ElementID = usize; @@ -71,7 +74,7 @@ impl BBox { self.y1 = self.y1.max(other.y1); } #[inline(always)] - fn overlap_x(&self, other: &Self) -> f32 { + pub fn overlap_x(&self, other: &Self) -> f32 { f32::max( 0f32, f32::min(self.x1, other.x1) - f32::max(self.x0, other.x0), @@ -155,7 +158,7 @@ pub enum ElementType { ListItem, Caption, Image, - Table, + Table(Option), } impl std::fmt::Display for ElementType { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -169,7 +172,7 @@ pub struct Element { pub layout_block_id: i32, pub text_block: ElementText, pub kind: ElementType, - pub page_id: usize, + pub page_id: PageID, pub bbox: BBox, } @@ -184,7 +187,7 @@ impl Element { "Page-header" => ElementType::Header, "Title" => ElementType::Title, "Section-header" => ElementType::Subtitle, - "Table" => ElementType::Table, + "Table" => ElementType::Table(None), "Picture" => ElementType::Image, _ => { unreachable!("can't have other type of layout bbox") @@ -256,7 +259,7 @@ pub struct ParsedDocument { pub metadata: DocumentMetadata, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct CharSpan { pub bbox: BBox, pub text: String, @@ -305,7 +308,7 @@ impl CharSpan { } } } -#[derive(Default)] +#[derive(Clone, Default)] pub struct Line { pub text: String, pub bbox: BBox, @@ -368,6 +371,20 @@ impl Line { } } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub enum Segment { + Line { start: (f32, f32), end: (f32, f32) }, + Rect { bbox: BBox }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct PDFPath { + pub segments: Vec, + pub is_stroke: bool, + pub is_fill: bool, + pub stroke_width: Option, +} + #[cfg(test)] mod tests { use super::*; diff --git a/ferrules-core/src/error.rs b/ferrules-core/src/error.rs index d776244..590b16e 100644 --- a/ferrules-core/src/error.rs +++ b/ferrules-core/src/error.rs @@ -23,7 +23,8 @@ pub enum FerrulesError { }, #[error("saving error page number {page_idx} in :{tmp_dir:?}")] DebugPageError { tmp_dir: PathBuf, page_idx: PageID }, - #[error("saving error page number {page_idx} in :{tmp_dir:?}")] ParseTextError { tmp_dir: PathBuf, page_idx: PageID }, + #[error("table transformer model error: {0}")] + TableTransformerModelError(String), } diff --git a/ferrules-core/src/layout/model.rs b/ferrules-core/src/layout/model.rs index 648414b..32f5da8 100644 --- a/ferrules-core/src/layout/model.rs +++ b/ferrules-core/src/layout/model.rs @@ -399,7 +399,7 @@ impl ORTLayoutParser { } /// runs nms on without taking into account which class -fn nms(raw_bboxes: &mut Vec, iou_threshold: f32) { +pub(crate) fn nms(raw_bboxes: &mut Vec, iou_threshold: f32) { raw_bboxes.sort_by(|r1, r2| r2.proba.partial_cmp(&r1.proba).unwrap()); let mut current_index = 0; for index in 0..raw_bboxes.len() { diff --git a/ferrules-core/src/lib.rs b/ferrules-core/src/lib.rs index 6712438..5ed6c0e 100644 --- a/ferrules-core/src/lib.rs +++ b/ferrules-core/src/lib.rs @@ -44,7 +44,7 @@ //! &doc_bytes, //! "document".into(), //! Default::default(), -//! None, // No progress callback +//! None::, // No progress callback //! ).await?; //! //! Ok(()) diff --git a/ferrules-core/src/ocr/mod.rs b/ferrules-core/src/ocr/mod.rs index dbfec7b..e0dffe2 100644 --- a/ferrules-core/src/ocr/mod.rs +++ b/ferrules-core/src/ocr/mod.rs @@ -120,23 +120,25 @@ mod ocr_mac { mod tests { use super::*; use image::ImageReader; - use std::time::Instant; + use std::{path::Path, time::Instant}; #[test] fn test_ocr_apple_vision() { - let image = ImageReader::open("./test_data/double_cols.jpg") - .unwrap() - .decode() - .unwrap(); - - let s = Instant::now(); - let ocr_result = parse_image_ocr(&image, 1f32); - assert!(ocr_result.is_ok()); - - println!( - "OCR took: {}ms", - Instant::now().duration_since(s).as_millis() - ); + if Path::new("./test_data/double_cols.jpg").exists() { + let image = ImageReader::open("./test_data/double_cols.jpg") + .unwrap() + .decode() + .unwrap(); + + let s = Instant::now(); + let ocr_result = parse_image_ocr(&image, 1f32); + assert!(ocr_result.is_ok()); + + println!( + "OCR took: {}ms", + Instant::now().duration_since(s).as_millis() + ); + } } } } diff --git a/ferrules-core/src/parse/document.rs b/ferrules-core/src/parse/document.rs index f3e97cb..8c0205b 100644 --- a/ferrules-core/src/parse/document.rs +++ b/ferrules-core/src/parse/document.rs @@ -18,6 +18,7 @@ use crate::{ model::{ORTConfig, ORTLayoutParser}, ParseLayoutQueue, }, + parse::table::{ParseTableQueue, TableParser, TableTransformer}, }; /// Configuration options for parsing documents with FerrulesParser @@ -53,6 +54,7 @@ impl Default for FerrulesParseConfig<'_> { async fn parse_task( parse_native_result: ParseNativePageResult, layout_queue: ParseLayoutQueue, + table_queue: ParseTableQueue, debug_dir: Option, callback: Option, ) -> Result @@ -61,7 +63,13 @@ where { let page_id = parse_native_result.page_id; - let result = parse_page_full(parse_native_result, debug_dir, layout_queue.clone()).await; + let result = parse_page_full( + parse_native_result, + debug_dir, + layout_queue.clone(), + table_queue.clone(), + ) + .await; if let Some(callback) = callback { callback(page_id) } @@ -76,6 +84,7 @@ where pub struct FerrulesParser { layout_queue: ParseLayoutQueue, native_queue: ParseNativeQueue, + table_queue: ParseTableQueue, } impl FerrulesParser { @@ -91,12 +100,16 @@ impl FerrulesParser { /// Panics if the layout model cannot be loaded with the given configuration pub fn new(layout_config: ORTConfig) -> Self { let layout_model = - Arc::new(ORTLayoutParser::new(layout_config).expect("can't load layout model")); + Arc::new(ORTLayoutParser::new(layout_config.clone()).expect("can't load layout model")); let native_queue = ParseNativeQueue::new(); let layout_queue = ParseLayoutQueue::new(layout_model); + let transformer = TableTransformer::new(&layout_config).ok(); + let table_parser = Arc::new(TableParser::new(transformer)); + let table_queue = ParseTableQueue::new(table_parser); Self { layout_queue, native_queue, + table_queue, } } /// Parses a document into a structured format with optional page-level progress callback @@ -219,6 +232,7 @@ impl FerrulesParser { parse_task( parse_native_result, self.layout_queue.clone(), + self.table_queue.clone(), tmp_dir, callback, ) diff --git a/ferrules-core/src/parse/merge.rs b/ferrules-core/src/parse/merge.rs index b85392f..7809278 100644 --- a/ferrules-core/src/parse/merge.rs +++ b/ferrules-core/src/parse/merge.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use tracing::instrument; use crate::{ - blocks::{Block, BlockType, ImageBlock, List, TextBlock, Title, TitleLevel}, + blocks::{Block, BlockType, ImageBlock, List, TableBlock, TextBlock, Title, TitleLevel}, entities::{Element, ElementID, ElementType, Line, PageID}, error::FerrulesError, layout::model::LayoutBBox, @@ -209,27 +209,27 @@ pub(crate) fn merge_elements_into_blocks( while let Some(mut curr_el) = element_it.next() { match &mut curr_el.kind { ElementType::Text => { - let text_block = Block { + let mut text_block = Block { id: block_id, kind: crate::blocks::BlockType::TextBlock(TextBlock { - text: curr_el.text_block.text, + text: curr_el.text_block.text.clone(), }), pages_id: vec![curr_el.page_id], bbox: curr_el.bbox, }; - // TODO : Change this to use some minimum gap + // TODO: This might be a bug here // Check to see if we have another text block that is close - // while let Some(next_el) = element_it.peek() { - // if matches!(next_el.kind, crate::entities::ElementType::Text(_)) - // && (curr_el.bbox.distance(&next_el.bbox, 1.0, 1.0) - // < MAXIMUM_ASSIGNMENT_DISTANCE) - // { - // text_block.merge(next_el)?; - // element_it.next(); - // } else { - // break; - // } - // } + while let Some(next_el) = element_it.peek() { + if matches!(next_el.kind, crate::entities::ElementType::Text) + && (text_block.bbox.distance(&next_el.bbox, 1.0, 1.0) + < MAXIMUM_ASSIGNMENT_DISTANCE) + { + let next_el = element_it.next().unwrap(); + text_block.merge(next_el)?; + } else { + break; + } + } block_id += 1; blocks.push(text_block); } @@ -431,8 +431,21 @@ pub(crate) fn merge_elements_into_blocks( block_id += 1; blocks.push(title); } - _ => { - continue; + ElementType::Table(table_opt) => { + let table_block = Block { + id: block_id, + kind: BlockType::Table(table_opt.clone().unwrap_or_else(|| TableBlock { + id: block_id, + caption: None, + rows: Vec::new(), + has_borders: false, + algorithm: crate::blocks::TableAlgorithm::Unknown, + })), + pages_id: vec![curr_el.page_id], + bbox: curr_el.bbox, + }; + block_id += 1; + blocks.push(table_block); } } } @@ -773,4 +786,46 @@ mod tests { } Ok(()) } + + #[test] + fn test_merge_consecutive_tables() -> anyhow::Result<()> { + let table1_bbox = BBox { + x0: 0.0, + y0: 0.0, + x1: 2.0, + y1: 2.0, + }; + let table2_bbox = BBox { + x0: 0.0, + y0: 2.5, + x1: 2.0, + y1: 4.5, + }; + + let elements = vec![ + Element { + id: 0, + layout_block_id: 0, + kind: ElementType::Table(None), + text_block: ElementText::default(), + page_id: 1, + bbox: table1_bbox, + }, + Element { + id: 1, + layout_block_id: 1, + kind: ElementType::Table(None), + text_block: ElementText::default(), + page_id: 1, + bbox: table2_bbox, + }, + ]; + + let blocks = merge_elements_into_blocks(elements, HashMap::new())?; + + assert_eq!(blocks.len(), 2); + assert!(matches!(blocks[0].kind, BlockType::Table(_))); + assert!(matches!(blocks[1].kind, BlockType::Table(_))); + Ok(()) + } } diff --git a/ferrules-core/src/parse/mod.rs b/ferrules-core/src/parse/mod.rs index 8fc95df..c6b7833 100644 --- a/ferrules-core/src/parse/mod.rs +++ b/ferrules-core/src/parse/mod.rs @@ -2,4 +2,5 @@ pub mod document; pub(crate) mod merge; pub mod native; mod page; +pub mod table; pub mod titles; diff --git a/ferrules-core/src/parse/native.rs b/ferrules-core/src/parse/native.rs index d8cb525..cf23279 100644 --- a/ferrules-core/src/parse/native.rs +++ b/ferrules-core/src/parse/native.rs @@ -1,11 +1,12 @@ use std::{ops::Range, sync::Arc, time::Instant}; use image::DynamicImage; -use pdfium_render::prelude::{PdfPage, PdfPageTextChar, PdfRenderConfig, Pdfium}; +use pdfium_render::prelude::*; + use tracing::{instrument, Span}; use crate::{ - entities::{BBox, CharSpan, Line, PageID}, + entities::{BBox, CharSpan, Line, PDFPath, PageID, Segment}, error::FerrulesError, layout::model::ORTLayoutParser, }; @@ -96,6 +97,7 @@ pub struct ParseNativePageResult { // TODO: page_native_rotation pub page_id: PageID, pub text_lines: Vec, + pub paths: Vec, pub page_bbox: BBox, pub page_image: Arc, pub page_image_scale1: DynamicImage, @@ -142,6 +144,20 @@ pub(crate) fn parse_page_native( required_raster_height: u32, ) -> anyhow::Result { let start_time = Instant::now(); + + let page_bbox = BBox { + x0: 0f32, + y0: 0f32, + x1: page.width().value, + y1: page.height().value, + }; + + // NOTE: Extract paths BEFORE flatten. `page.flatten()` merges annotations and + // form fields into the page content stream, which invalidates pdfium's + // internal page‐object list. Calling `page.objects()` after flatten + // dereferences stale pointers and segfaults. + let paths = extract_page_paths(page, &page_bbox); + if flatten_page { page.flatten()?; } @@ -152,12 +168,6 @@ pub(crate) fn parse_page_native( }; let downscale_factor = 1f32 / rescale_factor; - let page_bbox = BBox { - x0: 0f32, - y0: 0f32, - x1: page.width().value, - y1: page.height().value, - }; let page_image = page .render_with_config(&PdfRenderConfig::default().scale_page_by_factor(rescale_factor)) .map(|bitmap| bitmap.as_image())?; @@ -179,6 +189,7 @@ pub(crate) fn parse_page_native( Ok(ParseNativePageResult { page_id, text_lines, + paths, page_bbox, page_image: Arc::new(page_image), page_image_scale1, @@ -189,6 +200,60 @@ pub(crate) fn parse_page_native( }) } +fn extract_page_paths(page: &PdfPage, page_bbox: &BBox) -> Vec { + let mut paths = Vec::new(); + + for object in page.objects().iter() { + if let Some(path_obj) = object.as_path_object() { + let mut segments = Vec::new(); + let mut current_point: Option<(f32, f32)> = None; + + for segment in path_obj.segments().iter() { + match segment.segment_type() { + PdfPathSegmentType::LineTo => { + let point = segment.point(); + let (x, y) = (point.0.value, point.1.value); + // NOTE: PDF coordinates are bottom-up, convert to top-down + let converted_y = page_bbox.height() - y; + let converted_point = (x, converted_y); + + if let Some(start) = current_point { + segments.push(Segment::Line { + start, + end: converted_point, + }); + current_point = Some(converted_point); + } else { + current_point = Some(converted_point); + } + } + PdfPathSegmentType::MoveTo => { + let point = segment.point(); + let (x, y) = (point.0.value, point.1.value); + // PDF coordinates are bottom-up, convert to top-down + let converted_y = page_bbox.height() - y; + current_point = Some((x, converted_y)); + } + _ => {} + } + } + + if !segments.is_empty() { + paths.push(PDFPath { + segments, + is_stroke: path_obj.is_stroked().unwrap_or(false), + is_fill: path_obj + .fill_mode() + .map(|m| m != PdfPathFillMode::None) + .unwrap_or(false), + stroke_width: path_obj.stroke_width().ok().map(|p| p.value), + }); + } + } + } + paths +} + fn handle_parse_native_req( pdfium: &Pdfium, req: ParseNativeRequest, diff --git a/ferrules-core/src/parse/page.rs b/ferrules-core/src/parse/page.rs index 9f16341..640c6d1 100644 --- a/ferrules-core/src/parse/page.rs +++ b/ferrules-core/src/parse/page.rs @@ -4,18 +4,20 @@ use std::{ sync::Arc, time::Instant, }; +use tokio::task::JoinSet; use image::DynamicImage; use tracing::instrument; use crate::{ draw::{draw_blocks, draw_layout_bboxes, draw_text_lines}, - entities::{Element, Line, PageID, StructuredPage}, + entities::{Element, ElementType, Line, PDFPath, PageID, StructuredPage}, error::FerrulesError, layout::{ model::LayoutBBox, Metadata, ParseLayoutQueue, ParseLayoutRequest, ParseLayoutResponse, }, ocr::parse_image_ocr, + parse::table::ParseTableQueue, }; use super::{ @@ -104,11 +106,13 @@ pub async fn parse_page_full( parse_native_result: ParseNativePageResult, debug_dir: Option, layout_queue: ParseLayoutQueue, + table_queue: ParseTableQueue, ) -> Result { let span = tracing::Span::current(); let ParseNativePageResult { page_id, text_lines, + paths, page_bbox, page_image, page_image_scale1, @@ -143,7 +147,36 @@ pub async fn parse_page_full( parse_page_text(text_lines, &page_layout, &page_image, downscale_factor)?; // Merging elements with layout - let elements = build_page_elements(&page_layout, &text_lines, page_id)?; + let mut elements = build_page_elements(&page_layout, &text_lines, page_id)?; + let text_lines_arc = Arc::new(text_lines.clone()); + let paths_arc = Arc::new(paths); + + // Table parsing + let mut set = JoinSet::new(); + for (idx, element) in elements.iter().enumerate() { + if matches!(element.kind, ElementType::Table(_)) { + let (tx, rx) = tokio::sync::oneshot::channel(); + let req = crate::parse::table::ParseTableRequest { + page_id, + page_image: Arc::clone(&page_image), + lines: Arc::clone(&text_lines_arc), + paths: Arc::clone(&paths_arc), + table_bbox: element.bbox.clone(), + downscale_factor, + metadata: crate::parse::table::TableMetadata { response_tx: tx }, + }; + table_queue.push(req).await?; + set.spawn(async move { (idx, rx.await) }); + } + } + + while let Some(res) = set.join_next().await { + if let Ok((idx, Ok(Ok(resp)))) = res { + if let ElementType::Table(ref mut table_opt) = elements[idx].kind { + *table_opt = Some(resp.table_block); + } + } + } if let Some(tmp_dir) = debug_dir { debug_page( &tmp_dir, @@ -153,6 +186,7 @@ pub async fn parse_page_full( need_ocr, &page_layout, &elements, + &paths_arc, )? }; @@ -191,6 +225,7 @@ fn debug_page( need_ocr: bool, page_layout: &[LayoutBBox], elements: &[Element], + paths: &[PDFPath], ) -> Result<(), FerrulesError> { let output_file = tmp_dir.join(format!("page_{}.png", page_idx)); let final_output_file = tmp_dir.join(format!("page_blocks_{}.png", page_idx)); @@ -209,11 +244,22 @@ fn debug_page( // Draw the final prediction - // TODO: Implement titles hashmap for titles in the page let blocks = merge_elements_into_blocks(elements.to_vec(), HashMap::new())?; - let final_img = + let final_img_buffer = draw_blocks(&blocks, page_image).map_err(|_| FerrulesError::DebugPageError { tmp_dir: tmp_dir.to_path_buf(), page_idx, })?; + + // Draw paths on final image for debugging + let dynamic_final_img = image::DynamicImage::ImageRgba8(final_img_buffer); + let final_img_with_paths = + crate::draw::draw_paths(paths, &dynamic_final_img).map_err(|_| { + FerrulesError::DebugPageError { + tmp_dir: tmp_dir.to_path_buf(), + page_idx, + } + })?; + out_img .save(output_file) .map_err(|_| FerrulesError::DebugPageError { @@ -221,7 +267,7 @@ fn debug_page( page_idx, })?; - final_img + final_img_with_paths .save(final_output_file) .map_err(|_| FerrulesError::DebugPageError { tmp_dir: tmp_dir.to_path_buf(), diff --git a/ferrules-core/src/parse/table.rs b/ferrules-core/src/parse/table.rs new file mode 100644 index 0000000..7538dbd --- /dev/null +++ b/ferrules-core/src/parse/table.rs @@ -0,0 +1,1026 @@ +use image::imageops::FilterType; +use image::{DynamicImage, GenericImageView}; +use ndarray::{Array4, Axis}; +use ort::execution_providers::{ + CPUExecutionProvider, CUDAExecutionProvider, CoreMLExecutionProvider, TensorRTExecutionProvider, +}; +use ort::session::builder::GraphOptimizationLevel; +use ort::session::Session; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, Semaphore}; +use tracing::{Instrument, Span}; + +use crate::blocks::{TableAlgorithm, TableBlock}; +use crate::entities::{BBox, PDFPath, PageID}; +use crate::error::FerrulesError; +use crate::layout::model::{nms, LayoutBBox}; + +#[derive(Debug)] +pub struct TableMetadata { + pub(crate) response_tx: oneshot::Sender>, +} + +#[derive(Debug)] +pub(crate) struct ParseTableRequest { + pub(crate) page_id: PageID, + pub(crate) page_image: Arc, + pub(crate) lines: Arc>, + pub(crate) paths: Arc>, + pub(crate) table_bbox: BBox, + pub(crate) downscale_factor: f32, + pub(crate) metadata: TableMetadata, +} + +#[derive(Debug)] +pub(crate) struct ParseTableResponse { + pub(crate) table_block: TableBlock, +} + +#[derive(Debug, Clone)] +pub struct ParseTableQueue { + queue: mpsc::Sender<(ParseTableRequest, Span)>, +} + +impl ParseTableQueue { + pub fn new(table_parser: Arc) -> Self { + let (queue_sender, queue_receiver) = mpsc::channel(16); // Buffer size + + tokio::task::spawn(start_table_parser(table_parser, queue_receiver)); + Self { + queue: queue_sender, + } + } + + pub(crate) async fn push(&self, req: ParseTableRequest) -> Result<(), FerrulesError> { + let span = Span::current(); + self.queue + .send((req, span)) + .await + .map_err(|_| FerrulesError::LayoutParsingError) // Reusing error or add specific + } +} + +async fn start_table_parser( + table_parser: Arc, + mut input_rx: mpsc::Receiver<(ParseTableRequest, Span)>, +) { + let s = Arc::new(Semaphore::new(4)); // Limit concurrent table workers + while let Some((req, span)) = input_rx.recv().await { + let _guard = span.enter(); + tokio::spawn(handle_table_request(s.clone(), table_parser.clone(), req).in_current_span()); + } +} + +async fn handle_table_request( + s: Arc, + _parser: Arc, + req: ParseTableRequest, +) { + let _permit = s.acquire().await.unwrap(); + + let ParseTableRequest { + page_id, + page_image, + lines, + paths, + table_bbox, + downscale_factor, + metadata, + } = req; + + let parser = _parser.clone(); + let lines = lines.clone(); + let paths = paths.clone(); + let page_image = page_image.clone(); + let table_bbox = table_bbox.clone(); + + let table_result = tokio::task::spawn_blocking(move || { + parser.parse( + page_id, + &lines, + &paths, + &table_bbox, + &page_image, + downscale_factor, + ) + }) + .await; + + // Handle JoinError from spawn_blocking + let table_result = match table_result { + Ok(res) => res, + Err(_e) => return, // Task panicked or cancelled + }; + + drop(_permit); + + let response = table_result.map(|t| ParseTableResponse { table_block: t }); + + let _ = metadata.response_tx.send(response); +} + +pub const TABLE_MODEL_BYTES: &[u8] = + include_bytes!("../../../models/table-transformer-structure-recognition_fp16.onnx"); + +pub struct TableParser { + transformer: Option, + table_id_counter: AtomicUsize, +} + +pub struct TableTransformer { + session: Session, + _output_names: Vec, +} + +impl TableTransformer { + const SHORTEST_EDGE: usize = 800; + const MAX_SIZE: usize = 1333; + const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; + const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225]; + + // Structure Recognition Labels: + // 0: table, 1: column, 2: row, 3: column header, 4: projected row header, 5: spanning cell + const TABLE_LABELS: [&'static str; 6] = [ + "table", + "column", + "row", + "column_header", + "projected_row_header", + "spanning_cell", + ]; + + const CONFIDENCE_THRESHOLD: f32 = 0.6; + + fn scale_wh(&self, w0: f32, h0: f32) -> (f32, f32, f32) { + let mut r = Self::SHORTEST_EDGE as f32 / w0.min(h0); + if (w0.max(h0) * r) > Self::MAX_SIZE as f32 { + r = Self::MAX_SIZE as f32 / w0.max(h0); + } + (r, (w0 * r).round(), (h0 * r).round()) + } + + pub fn new(config: &crate::layout::model::ORTConfig) -> Result { + let mut execution_providers = Vec::new(); + + // Providers + for provider in &config.execution_providers { + match provider { + crate::layout::model::OrtExecutionProvider::Trt(device_id) => { + execution_providers.push( + TensorRTExecutionProvider::default() + .with_device_id(*device_id) + .build(), + ); + } + crate::layout::model::OrtExecutionProvider::CUDA(device_id) => { + execution_providers.push( + CUDAExecutionProvider::default() + .with_device_id(*device_id) + .build(), + ); + } + crate::layout::model::OrtExecutionProvider::CoreML { ane_only } => { + let provider = CoreMLExecutionProvider::default(); + let provider = if *ane_only { + provider.with_ane_only().build() + } else { + provider.build() + }; + execution_providers.push(provider) + } + crate::layout::model::OrtExecutionProvider::CPU => { + execution_providers.push(CPUExecutionProvider::default().build()); + } + } + } + + let opt_lvl = match config.opt_level { + Some(crate::layout::model::ORTGraphOptimizationLevel::Level1) => { + GraphOptimizationLevel::Level1 + } + Some(crate::layout::model::ORTGraphOptimizationLevel::Level2) => { + GraphOptimizationLevel::Level2 + } + Some(crate::layout::model::ORTGraphOptimizationLevel::Level3) => { + GraphOptimizationLevel::Level3 + } + None => GraphOptimizationLevel::Disable, + }; + + let session = Session::builder() + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))? + .with_execution_providers(execution_providers) + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))? + .with_optimization_level(opt_lvl) + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))? + .with_intra_threads(config.intra_threads) + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))? + .with_inter_threads(config.inter_threads) + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))? + .commit_from_memory(TABLE_MODEL_BYTES) + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))?; + + let output_names = session.outputs.iter().map(|o| o.name.clone()).collect(); + + Ok(Self { + session, + _output_names: output_names, + }) + } + + pub fn preprocess(&self, img: &DynamicImage) -> Array4 { + let (w0, h0) = img.dimensions(); + let (_, w_new, h_new) = self.scale_wh(w0 as f32, h0 as f32); + + let resized = img.resize_exact(w_new as u32, h_new as u32, FilterType::Triangle); + let (w_final, h_final) = resized.dimensions(); + + let mut input = Array4::zeros([1, 3, h_final as usize, w_final as usize]); + + for (x, y, pixel) in resized.pixels() { + let [r, g, b, _] = pixel.0; + // Normalize with ImageNet mean/std + input[[0, 0, y as usize, x as usize]] = + (r as f32 / 255.0 - Self::IMAGENET_MEAN[0]) / Self::IMAGENET_STD[0]; + input[[0, 1, y as usize, x as usize]] = + (g as f32 / 255.0 - Self::IMAGENET_MEAN[1]) / Self::IMAGENET_STD[1]; + input[[0, 2, y as usize, x as usize]] = + (b as f32 / 255.0 - Self::IMAGENET_MEAN[2]) / Self::IMAGENET_STD[2]; + } + + input + } + + pub fn run( + &self, + input: Array4, + ) -> Result, FerrulesError> { + let input_f16 = input.mapv(half::f16::from_f32); + let outputs = self + .session + .run( + ort::inputs![input_f16] + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))?, + ) + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))?; + Ok(outputs) + } + + /// Decode the DETR-style output from the Table Transformer. + /// Boxes are [center_x, center_y, width, height] normalized. + pub fn postprocess( + &self, + outputs: &ort::session::SessionOutputs, + orig_width: u32, + orig_height: u32, + ) -> Result, FerrulesError> { + let logits = outputs["logits"] + .try_extract_tensor::() + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))?; + let boxes = outputs["pred_boxes"] + .try_extract_tensor::() + .map_err(|e| FerrulesError::TableTransformerModelError(e.to_string()))?; + let logits = logits.mapv(|x| x.to_f32()); + let boxes = boxes.mapv(|x| x.to_f32()); + + // logits: [1, 125, 7] (Structure Recognition has 6 classes + 1 Background) + // boxes: [1, 125, 4] + + let mut results = Vec::new(); + + let logits = logits.index_axis(Axis(0), 0); + let boxes = boxes.index_axis(Axis(0), 0); + + for i in 0..125 { + let logit = logits.index_axis(Axis(0), i); + let box_coords = boxes.index_axis(Axis(0), i); + + // Apply softmax to get proper probabilities + let max_logit = logit.iter().copied().fold(f32::NEG_INFINITY, f32::max); + let exp_sum: f32 = logit.iter().map(|&v| (v - max_logit).exp()).sum(); + let softmax_probs: Vec = logit + .iter() + .map(|&v| (v - max_logit).exp() / exp_sum) + .collect(); + + // Find best class + let (max_idx, &max_prob) = softmax_probs + .iter() + .enumerate() + .take(Self::TABLE_LABELS.len()) // only first 6 are valid classes + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + .unwrap(); + + if max_prob < Self::CONFIDENCE_THRESHOLD { + continue; + } + + let cx = box_coords[0] * orig_width as f32; + let cy = box_coords[1] * orig_height as f32; + let w = box_coords[2] * orig_width as f32; + let h = box_coords[3] * orig_height as f32; + + results.push(LayoutBBox { + id: i as i32, + bbox: BBox { + x0: (cx - w / 2.0).max(0.0).min(orig_width as f32), + y0: (cy - h / 2.0).max(0.0).min(orig_height as f32), + x1: (cx + w / 2.0).max(0.0).min(orig_width as f32), + y1: (cy + h / 2.0).max(0.0).min(orig_height as f32), + }, + + label: Self::TABLE_LABELS[max_idx], + proba: max_prob, + }); + } + + Ok(results) + } +} + +impl TableParser { + const ROW_OVERLAP_THRESHOLD: f32 = 0.5; + + /// Minimum cell count below which a table is suspicious (when paired with a large area). + const MIN_CELL_COUNT: usize = 2; + /// Area threshold used when the cell count is low. + const SUSPICIOUS_AREA_LOW_CELLS: f32 = 1500.0; + /// Minimum row count below which a table is suspicious (when paired with a large area). + const MIN_ROW_COUNT: usize = 1; + /// Area threshold used when the row count is low. + const SUSPICIOUS_AREA_LOW_ROWS: f32 = 3000.0; + /// Large-area threshold: tables above this always try vision. + const LARGE_AREA_THRESHOLD: f32 = 5000.0; + /// Minimum ratio of total cell area to table area. + /// Below this the stream result is considered incomplete. + const CELL_COVERAGE_THRESHOLD: f32 = 0.3; + + pub fn new(transformer: Option) -> Self { + Self { + transformer, + table_id_counter: AtomicUsize::new(0), + } + } + + /// Heuristic to decide whether the Vision (Table Transformer) fallback + /// should be attempted after a Stream parse. + /// + /// Returns `true` when: + /// - Stream found **no rows** at all. + /// - Few cells in a suspiciously large area. + /// - Few rows in a suspiciously large area. + /// - The table area exceeds `LARGE_AREA_THRESHOLD`. + /// - The total cell area covers less than `CELL_COVERAGE_THRESHOLD` of the table area. + fn should_try_vision(&self, table: &TableBlock, table_area: f32) -> bool { + let row_count = table.rows.len(); + if row_count == 0 { + return true; + } + + let cell_count: usize = table.rows.iter().map(|r| r.cells.len()).sum(); + + let is_suspicious = (cell_count <= Self::MIN_CELL_COUNT + && table_area > Self::SUSPICIOUS_AREA_LOW_CELLS) + || (row_count <= Self::MIN_ROW_COUNT && table_area > Self::SUSPICIOUS_AREA_LOW_ROWS); + + if is_suspicious || table_area > Self::LARGE_AREA_THRESHOLD { + return true; + } + + // Check whether the detected cells actually cover a reasonable + // fraction of the table bounding box. A low ratio means stream + // parsing likely missed significant content. + if table_area > 0.0 { + let total_cell_area: f32 = table + .rows + .iter() + .flat_map(|r| &r.cells) + .map(|c| c.bbox.area()) + .sum(); + if total_cell_area / table_area < Self::CELL_COVERAGE_THRESHOLD { + return true; + } + } + + false + } + + /// Main entry point to parse a table within a given bounding box on a page. + pub fn parse( + &self, + _page_id: PageID, + lines: &[crate::entities::Line], + paths: &[PDFPath], + table_bbox: &BBox, + page_image: &DynamicImage, + downscale_factor: f32, + ) -> Result { + // TODO: Decide between Lattice and Stream + if !paths.is_empty() { + tracing::debug!( + "Page {} - BBox {:?} - {} paths. Trying Lattice...", + _page_id, + table_bbox, + paths.len() + ); + if let Some(mut table) = self.parse_lattice(lines, paths, table_bbox) { + table.algorithm = TableAlgorithm::Lattice; + tracing::debug!("Page {} - Lattice successful.", _page_id); + return Ok(table); + } + tracing::debug!("Page {} - Lattice failed.", _page_id); + } else { + tracing::debug!("Page {} has no paths. Skipping Lattice.", _page_id); + } + + let mut table = self.parse_stream(lines, table_bbox)?; + table.algorithm = TableAlgorithm::Stream; + + let area = table_bbox.area(); + let cell_count: usize = table.rows.iter().map(|r| r.cells.len()).sum(); + let row_count = table.rows.len(); + + if self.should_try_vision(&table, area) { + tracing::debug!( + "Page {} - Stream suspicious ({} cells, {} rows). Trying Vision comparison...", + _page_id, + cell_count, + row_count + ); + + if let Ok(vision_table) = + self.parse_vision(page_image, lines, table_bbox, downscale_factor) + { + let vision_cell_count: usize = + vision_table.rows.iter().map(|r| r.cells.len()).sum(); + + // Pick vision if it found significantly more cells OR if stream was empty + if vision_cell_count > cell_count + || (cell_count == 0 && !vision_table.rows.is_empty()) + { + tracing::debug!( + "Page {} - Vision ({} cells) preferred over Stream ({} cells).", + _page_id, + vision_cell_count, + cell_count + ); + return Ok(vision_table); + } + } + } + + tracing::debug!( + "Page {} - Stream successful ({} cells).", + _page_id, + cell_count + ); + Ok(table) + } + + fn parse_lattice( + &self, + lines: &[crate::entities::Line], + paths: &[crate::entities::PDFPath], + table_bbox: &BBox, + ) -> Option { + let padding = 5.0; + let mut h_lines = Vec::new(); + let mut v_lines = Vec::new(); + + for path in paths { + for segment in &path.segments { + match segment { + crate::entities::Segment::Line { start, end } => { + let (x1, y1) = *start; + let (x2, y2) = *end; + + // Horizontal line + if (y1 - y2).abs() < 1.0 { + let y = (y1 + y2) / 2.0; + if y >= table_bbox.y0 - padding && y <= table_bbox.y1 + padding { + let x_min = x1.min(x2); + let x_max = x1.max(x2); + if x_min < table_bbox.x1 + padding + && x_max > table_bbox.x0 - padding + { + h_lines.push((y, x_min, x_max)); + } + } + } + // Vertical line + else if (x1 - x2).abs() < 1.0 { + let x = (x1 + x2) / 2.0; + if x >= table_bbox.x0 - padding && x <= table_bbox.x1 + padding { + let y_min = y1.min(y2); + let y_max = y1.max(y2); + if y_min < table_bbox.y1 + padding + && y_max > table_bbox.y0 - padding + { + v_lines.push((x, y_min, y_max)); + } + } + } + } + crate::entities::Segment::Rect { bbox } => { + if table_bbox.intersection(bbox) > 0.0 { + h_lines.push((bbox.y0, bbox.x0, bbox.x1)); + h_lines.push((bbox.y1, bbox.x0, bbox.x1)); + v_lines.push((bbox.x0, bbox.y0, bbox.y1)); + v_lines.push((bbox.x1, bbox.y0, bbox.y1)); + } + } + } + } + } + + tracing::debug!( + "Lattice - BBox {:?} - segments: H={}, V={}", + table_bbox, + h_lines.len(), + v_lines.len() + ); + + if h_lines.is_empty() || v_lines.is_empty() { + return None; + } + + // Simple clustering + h_lines.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + v_lines.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + + let mut unique_h = Vec::new(); + if !h_lines.is_empty() { + let mut curr = h_lines[0]; + for next in h_lines.iter().skip(1) { + if (next.0 - curr.0).abs() < 2.0 { + curr.1 = curr.1.min(next.1); + curr.2 = curr.2.max(next.2); + } else { + unique_h.push(curr); + curr = *next; + } + } + unique_h.push(curr); + } + + let mut unique_v = Vec::new(); + if !v_lines.is_empty() { + let mut curr = v_lines[0]; + for next in v_lines.iter().skip(1) { + if (next.0 - curr.0).abs() < 2.0 { + curr.1 = curr.1.min(next.1); + curr.2 = curr.2.max(next.2); + } else { + unique_v.push(curr); + curr = *next; + } + } + unique_v.push(curr); + } + + let h_coords: Vec = unique_h.iter().map(|l| l.0).collect(); + let v_coords: Vec = unique_v.iter().map(|l| l.0).collect(); + + if h_coords.len() < 2 || v_coords.len() < 2 { + tracing::debug!( + "Lattice failed: insufficient grid lines. H={:?}, V={:?}", + h_coords, + v_coords + ); + return None; + } + tracing::debug!("Lattice Grid: H={:?}, V={:?}", h_coords, v_coords); + + // Pre-filter lines that intersect the table + let table_lines: Vec<_> = lines + .iter() + .filter(|l| table_bbox.intersection(&l.bbox) > 0.0) + .collect(); + + let mut rows = Vec::new(); + + // Matrix to track visited grid cells (for spanning) + let num_rows = h_coords.len() - 1; + let num_cols = v_coords.len() - 1; + let mut visited = vec![vec![false; num_cols]; num_rows]; + + for i in 0..num_rows { + let y0 = h_coords[i]; + let y1 = h_coords[i + 1]; + let mut row_cells = Vec::new(); + + for j in 0..num_cols { + if visited[i][j] { + continue; + } + + // Determine horizontal span + let mut col_span = 1; + while j + col_span < num_cols { + let next_x = v_coords[j + col_span]; + // Is there a vertical line at next_x covering this row? + let has_boundary = v_lines.iter().any(|(sx, ymin, ymax)| { + if (sx - next_x).abs() > 1.5 { + return false; + } + let overlap_start = y0.max(*ymin); + let overlap_end = y1.min(*ymax); + let overlap = (overlap_end - overlap_start).max(0.0); + overlap > (y1 - y0) * 0.1 + }); + + if has_boundary { + break; + } + col_span += 1; + } + + let row_span = 1; + let x0 = v_coords[j]; + let x1 = v_coords[j + col_span]; + let cell_bbox = BBox { x0, y0, x1, y1 }; + + // Mark visited + for r in i..i + row_span { + for c in j..j + col_span { + visited[r][c] = true; + } + } + + // Collect spans that belong to this cell + let mut cell_spans = Vec::new(); + for line in &table_lines { + for span in &line.spans { + let (cx, cy) = span.bbox.center(); + if cx >= x0 && cx <= x1 && cy >= y0 && cy <= y1 { + cell_spans.push(span); + } + } + } + + // Sort spans by Y then X to maintain reading order + cell_spans.sort_by(|a, b| { + a.bbox + .y0 + .partial_cmp(&b.bbox.y0) + .unwrap() + .then(a.bbox.x0.partial_cmp(&b.bbox.x0).unwrap()) + }); + + let mut cell_text = String::new(); + let mut last_bbox: Option = None; + + for span in cell_spans { + // Clean text: replace newlines with space + let cleaned = span.text.replace('\n', " "); + let cleaned_trimmed = cleaned.trim(); + if cleaned_trimmed.is_empty() { + continue; + } + + if let Some(last) = last_bbox { + let is_new_line = (span.bbox.y0 - last.y0).abs() > last.height() * 0.3; + let gap_x = span.bbox.x0 - last.x1; + + // Add space if there is a significant vertical or horizontal gap + if is_new_line || gap_x > 2.0 { + if !cell_text.ends_with(' ') { + cell_text.push(' '); + } + } + } + + cell_text.push_str(cleaned_trimmed); + last_bbox = Some(span.bbox.clone()); + } + + row_cells.push(crate::blocks::TableCell { + content: vec![], + text: cell_text, + row_span: row_span as u8, + col_span: col_span as u8, + bbox: cell_bbox, + }); + } + + if !row_cells.is_empty() { + let mut row_bbox = row_cells[0].bbox.clone(); + for cell in &row_cells[1..] { + row_bbox.merge(&cell.bbox); + } + rows.push(crate::blocks::TableRow { + cells: row_cells, + is_header: false, + bbox: row_bbox, + }); + } + } + + if rows.is_empty() { + return None; + } + + let table_id = self.table_id_counter.fetch_add(1, Ordering::SeqCst); + Some(TableBlock { + id: table_id, + caption: None, + rows, + has_borders: true, + algorithm: TableAlgorithm::Lattice, + }) + } + + fn parse_vision( + &self, + image: &DynamicImage, + lines: &[crate::entities::Line], + table_bbox: &BBox, + downscale_factor: f32, + ) -> Result { + let transformer = match &self.transformer { + Some(t) => t, + None => return Ok(TableBlock::default()), + }; + + // 1. Crop image to table_bbox (in image coordinates) + let scale = 1.0 / downscale_factor; + let x0_f = table_bbox.x0 * scale; + let y0_f = table_bbox.y0 * scale; + let x0 = x0_f.floor() as u32; + let y0 = y0_f.floor() as u32; + + // Ensure we don't go out of bounds + let x0 = x0.min(image.width()); + let y0 = y0.min(image.height()); + + // Calculate width/height in image coordinates + let w_img = ((table_bbox.width() * scale) as u32).max(1); + let h_img = ((table_bbox.height() * scale) as u32).max(1); + + let w = w_img.min(image.width() - x0).max(1); + let h = h_img.min(image.height() - y0).max(1); + + let crop = image.crop_imm(x0, y0, w, h); + + // 2. Preprocess + let input = transformer.preprocess(&crop); + + // 3. Run Inference + let outputs = transformer.run(input).map_err(|e| { + tracing::error!("parse_vision: Inference failed: {:?}", e); + FerrulesError::TableTransformerModelError(e.to_string()) + })?; + + // 4. Postprocess + let detections = transformer.postprocess(&outputs, w, h).map_err(|e| { + tracing::error!("parse_vision: Postprocess failed: {:?}", e); + FerrulesError::TableTransformerModelError(e.to_string()) + })?; + + // 5. Map detections to Table structure + // Simple mapping: find all 'row' and 'column' labels + let mut rows: Vec = detections + .iter() + .filter(|d| d.label == "row") + .cloned() + .collect(); + let mut cols: Vec = detections + .iter() + .filter(|d| d.label == "column") + .cloned() + .collect(); + + // 5a. Apply NMS to rows and columns independently + nms(&mut rows, 0.5); + nms(&mut cols, 0.5); + + rows.sort_by(|a, b| a.bbox.y0.partial_cmp(&b.bbox.y0).unwrap()); + cols.sort_by(|a, b| a.bbox.x0.partial_cmp(&b.bbox.x0).unwrap()); + + // Snap outermost column/row edges to the table bbox so cells cover + // the full table area. The model detects content areas which are + // typically slightly narrower than the full table boundaries. + if let Some(first_col) = cols.first_mut() { + first_col.bbox.x0 = 0.0; + } + if let Some(last_col) = cols.last_mut() { + last_col.bbox.x1 = w as f32; + } + if let Some(first_row) = rows.first_mut() { + first_row.bbox.y0 = 0.0; + } + if let Some(last_row) = rows.last_mut() { + last_row.bbox.y1 = h as f32; + } + + // Extract spanning cells and column headers + let spanning_cells: Vec<&LayoutBBox> = detections + .iter() + .filter(|d| d.label == "spanning_cell") + .collect(); + let header_dets: Vec<&LayoutBBox> = detections + .iter() + .filter(|d| d.label == "column_header") + .collect(); + + let mut table_rows = Vec::new(); + for row_det in &rows { + let row_y0_pdf = (row_det.bbox.y0 + y0 as f32) * downscale_factor; + let row_y1_pdf = (row_det.bbox.y1 + y0 as f32) * downscale_factor; + + // Check if this row is a header row + let is_header = header_dets.iter().any(|hdr| { + let row_bbox_crop = &row_det.bbox; + row_bbox_crop.intersection(&hdr.bbox) / row_bbox_crop.area() > 0.5 + }); + + let mut cells = Vec::new(); + let mut col_idx = 0; + while col_idx < cols.len() { + // Build the cell bbox for current (row, col) in crop-pixel space + let cell_crop = BBox { + x0: cols[col_idx].bbox.x0, + y0: row_det.bbox.y0, + x1: cols[col_idx].bbox.x1, + y1: row_det.bbox.y1, + }; + + // Check if a spanning cell covers this position + let spanning = spanning_cells + .iter() + .find(|sc| cell_crop.intersection(&sc.bbox) / cell_crop.area() > 0.5); + + let col_span = if let Some(sc) = spanning { + // Count how many consecutive columns this spanning cell covers + let mut span = 1; + for j in (col_idx + 1)..cols.len() { + let col_overlap = cols[j].bbox.overlap_x(&sc.bbox); + if col_overlap / cols[j].bbox.width() > 0.5 { + span += 1; + } else { + break; + } + } + span + } else { + 1usize + }; + + // Build the merged cell bbox spanning col_idx..col_idx+col_span + let last_col = &cols[(col_idx + col_span - 1).min(cols.len() - 1)]; + let cell_x0_pdf = (cols[col_idx].bbox.x0 + x0 as f32) * downscale_factor; + let cell_x1_pdf = (last_col.bbox.x1 + x0 as f32) * downscale_factor; + + let cell_bbox = BBox { + x0: cell_x0_pdf.max(table_bbox.x0), + y0: row_y0_pdf.max(table_bbox.y0), + x1: cell_x1_pdf.min(table_bbox.x1), + y1: row_y1_pdf.min(table_bbox.y1), + }; + + let cell_text = lines + .iter() + .filter(|l| cell_bbox.intersection(&l.bbox) / l.bbox.area() > 0.5) + .map(|l| l.text.as_str()) + .collect::>() + .join(" "); + + cells.push(crate::blocks::TableCell { + text: cell_text, + bbox: cell_bbox, + col_span: col_span as u8, + row_span: 1, + content: Vec::new(), + }); + + col_idx += col_span; + } + + table_rows.push(crate::blocks::TableRow { + cells, + bbox: BBox { + x0: table_bbox.x0, + y0: row_y0_pdf, + x1: table_bbox.x1, + y1: row_y1_pdf, + }, + is_header, + }); + } + + let table_id = self.table_id_counter.fetch_add(1, Ordering::SeqCst); + Ok(TableBlock { + id: table_id, + caption: None, + rows: table_rows, + has_borders: true, + algorithm: TableAlgorithm::Vision, + }) + } + + fn parse_stream( + &self, + lines: &[crate::entities::Line], + table_bbox: &BBox, + ) -> Result { + // 1. Filter lines within table_bbox + let mut table_lines: Vec<_> = lines + .iter() + .filter(|l| table_bbox.contains(&l.bbox)) + .collect(); + + // 2. Sort lines by Y (vertical) + table_lines.sort_by(|a, b| a.bbox.y0.partial_cmp(&b.bbox.y0).unwrap()); + + // 3. Group lines into rows based on Y overlap + let mut rows = Vec::new(); + if table_lines.is_empty() { + let table_id = self.table_id_counter.fetch_add(1, Ordering::SeqCst); + return Ok(TableBlock { + id: table_id, + caption: None, + rows: vec![], + has_borders: false, + algorithm: TableAlgorithm::Unknown, + }); + } + + let mut current_row_lines = vec![table_lines[0]]; + for line in table_lines.iter().skip(1) { + let last_line = current_row_lines.last().unwrap(); + // NOTE: If the next line significantly overlaps vertically or is very close, it's the same row + if line.bbox.y0 + < last_line.bbox.y1 - last_line.bbox.height() * Self::ROW_OVERLAP_THRESHOLD + { + current_row_lines.push(line); + } else { + rows.push(self.process_row_lines(¤t_row_lines)); + current_row_lines = vec![line]; + } + } + rows.push(self.process_row_lines(¤t_row_lines)); + + let table_id = self.table_id_counter.fetch_add(1, Ordering::SeqCst); + Ok(TableBlock { + id: table_id, + caption: None, + rows, + has_borders: false, + algorithm: TableAlgorithm::Stream, + }) + } + + fn process_row_lines(&self, row_lines: &[&crate::entities::Line]) -> crate::blocks::TableRow { + if row_lines.is_empty() { + return crate::blocks::TableRow::default(); + } + + let mut sorted_lines = row_lines.to_vec(); + sorted_lines.sort_by(|a, b| a.bbox.x0.partial_cmp(&b.bbox.x0).unwrap()); + + let mut cells = Vec::new(); + let mut current_cell_text = sorted_lines[0].text.clone(); + let mut current_cell_bbox = sorted_lines[0].bbox.clone(); + + for line in sorted_lines.iter().skip(1) { + // Threshold for horizontal gap between words/cells in a table + // Usually tables have larger gaps than normal text + let horizontal_gap = line.bbox.x0 - current_cell_bbox.x1; + if horizontal_gap < 10.0 { + current_cell_text.push(' '); + current_cell_text.push_str(&line.text); + current_cell_bbox.merge(&line.bbox); + } else { + cells.push(crate::blocks::TableCell { + content: vec![], + text: current_cell_text.trim().to_string(), + bbox: current_cell_bbox, + col_span: 1, + row_span: 1, + }); + current_cell_text = line.text.clone(); + current_cell_bbox = line.bbox.clone(); + } + } + + cells.push(crate::blocks::TableCell { + content: vec![], + text: current_cell_text.trim().to_string(), + bbox: current_cell_bbox, + col_span: 1, + row_span: 1, + }); + + let mut row_bbox = cells[0].bbox.clone(); + for cell in &cells[1..] { + row_bbox.merge(&cell.bbox); + } + + crate::blocks::TableRow { + cells, + bbox: row_bbox, + is_header: false, + } + } +} diff --git a/ferrules-core/src/render/html.rs b/ferrules-core/src/render/html.rs index 53c8d9c..1e2d2ff 100644 --- a/ferrules-core/src/render/html.rs +++ b/ferrules-core/src/render/html.rs @@ -1,6 +1,6 @@ use std::path::PathBuf; -use build_html::{Html, HtmlContainer, HtmlElement, HtmlPage, HtmlTag}; +use build_html::{Html, HtmlChild, HtmlContainer, HtmlElement, HtmlPage, HtmlTag}; use regex::Regex; use crate::blocks::{Block, BlockType}; @@ -34,12 +34,13 @@ impl HTMLRenderer { .with_html(self.root_element) .to_html_string() } -} -impl Renderer for HTMLRenderer { - type Ok = (); - - fn render_block(&mut self, block: &Block) -> anyhow::Result { + fn render_block_to_container( + block: &Block, + container: &mut HtmlElement, + img_src_path: Option<&PathBuf>, + list_regex: &Regex, + ) -> anyhow::Result<()> { match &block.kind { BlockType::Title(title) => { let level = title.level.clamp(1, 6); @@ -54,39 +55,39 @@ impl Renderer for HTMLRenderer { let el = HtmlElement::new(tag) .with_child(title.text.as_str().into()) .into(); - self.root_element.add_child(el); + container.add_child(el); } BlockType::Header(text_block) => { let el = HtmlElement::new(HtmlTag::Header) .with_child(text_block.text.as_str().into()) .into(); - self.root_element.add_child(el); + container.add_child(el); } BlockType::Footer(text_block) => { let el = HtmlElement::new(HtmlTag::Footer) .with_child(text_block.text.as_str().into()) .into(); - self.root_element.add_child(el); + container.add_child(el); } BlockType::ListBlock(list) => { let mut ul = HtmlElement::new(HtmlTag::UnorderedList); for item in &list.items { - let clean_text = self.list_regex.replace(item, "").into_owned(); + let clean_text = list_regex.replace(item, "").into_owned(); let li = HtmlElement::new(HtmlTag::ListElement) .with_child(clean_text.as_str().into()) .into(); ul.add_child(li); } - self.root_element.add_child(ul.into()); + container.add_child(ul.into()); } BlockType::TextBlock(text_block) => { let el = HtmlElement::new(HtmlTag::ParagraphText) .with_child(text_block.text.as_str().into()) .into(); - self.root_element.add_child(el); + container.add_child(el); } BlockType::Image(image_block) => { - if let Some(img_src_path) = &self.img_src_path { + if let Some(img_src_path) = img_src_path { let mut figure = HtmlElement::new(HtmlTag::Figure); let img_src = img_src_path .join(image_block.path()) @@ -103,17 +104,68 @@ impl Renderer for HTMLRenderer { figure.add_child(figcaption); } - self.root_element.add_child(figure.into()); + container.add_child(figure.into()); } } - _ => { - eprintln!("not implemented yet") + BlockType::Table(table) => { + let mut table_html = String::from(""); + if let Some(caption) = &table.caption { + table_html.push_str(&format!("", caption)); + } + for row in &table.rows { + table_html.push_str(""); + for cell in &row.cells { + let tag = if row.is_header { "th" } else { "td" }; + table_html.push_str(&format!("<{}", tag)); + if cell.col_span > 1 { + table_html.push_str(&format!(" colspan=\"{}\"", cell.col_span)); + } + if cell.row_span > 1 { + table_html.push_str(&format!(" rowspan=\"{}\"", cell.row_span)); + } + table_html.push_str(">"); + + if !cell.content.is_empty() { + // Render content into a temporary container (div) + let mut cell_container = HtmlElement::new(HtmlTag::Div); + for block in &cell.content { + Self::render_block_to_container( + block, + &mut cell_container, + img_src_path, + list_regex, + )?; + } + table_html.push_str(&cell_container.to_html_string()); + } else if !cell.text.is_empty() { + table_html.push_str(&cell.text); + } + + table_html.push_str(&format!("", tag)); + } + table_html.push_str(""); + } + table_html.push_str("
{}
"); + container.add_child(HtmlChild::Raw(table_html)); } } Ok(()) } } +impl Renderer for HTMLRenderer { + type Ok = (); + + fn render_block(&mut self, block: &Block) -> anyhow::Result { + Self::render_block_to_container( + block, + &mut self.root_element, + self.img_src_path.as_ref(), + &self.list_regex, + ) + } +} + #[tracing::instrument(skip_all)] pub fn to_html( blocks: R, diff --git a/ferrules-core/src/utils.rs b/ferrules-core/src/utils.rs index bcdc647..5b89510 100644 --- a/ferrules-core/src/utils.rs +++ b/ferrules-core/src/utils.rs @@ -80,7 +80,28 @@ fn save_doc_images(imgs_dir: &Path, doc: &ParsedDocument) -> anyhow::Result<()> None => continue, } } - blocks::BlockType::Table => todo!(), + blocks::BlockType::Table(table_block) => { + let page_id = block.pages_id.first().unwrap(); + match doc.pages.iter().find(|&p| p.id == *page_id) { + Some(page) => { + assert!(page.height as u32 > 0); + assert!(page.width as u32 > 0); + + let x = (block.bbox.x0 - IMAGE_PADDING as f32) as u32; + let y = (block.bbox.y0 - IMAGE_PADDING as f32) as u32; + let width = (block.bbox.width().max(1.0) as u32 + 2 * IMAGE_PADDING) + .min(page.width as u32); + let height = (block.bbox.height().max(1.0) as u32 + 2 * IMAGE_PADDING) + .min(page.height as u32); + + let crop = page.image.clone().crop(x, y, width, height); + + let output_file = imgs_dir.join(table_block.path()); + crop.save(output_file)?; + } + None => continue, + } + } _ => continue, } } diff --git a/models/table-transformer-structure-recognition_fp16.onnx b/models/table-transformer-structure-recognition_fp16.onnx new file mode 100644 index 0000000..d6aab2d --- /dev/null +++ b/models/table-transformer-structure-recognition_fp16.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ef70ccff2de697c61a8c836ec7ed896d2a58416fe5cfde6f266dc7999006ec3 +size 58355554 diff --git a/models/yolov8s-doclaynet.onnx b/models/yolov8s-doclaynet.onnx index 55cb32f..42dd844 100644 Binary files a/models/yolov8s-doclaynet.onnx and b/models/yolov8s-doclaynet.onnx differ diff --git a/python/export_table_transformer.py b/python/export_table_transformer.py new file mode 100644 index 0000000..1f40348 --- /dev/null +++ b/python/export_table_transformer.py @@ -0,0 +1,60 @@ +import torch +from transformers import TableTransformerForObjectDetection, DetrImageProcessor +import onnx +import onnxruntime +import os + +# Define paths +MODEL_NAME = "microsoft/table-transformer-structure-recognition" +OUTPUT_DIR = "../models" +ONNX_PATH = os.path.join(OUTPUT_DIR, "table-transformer-structure-recognition.onnx") + +os.makedirs(OUTPUT_DIR, exist_ok=True) + +print(f"Loading model: {MODEL_NAME}") +model = TableTransformerForObjectDetection.from_pretrained(MODEL_NAME) +processor = DetrImageProcessor.from_pretrained(MODEL_NAME) +model.eval() + +# Create dummy input +# The model expects pixel_values of shape [batch_size, 3, height, width] +# and pixel_mask of shape [batch_size, height, width] +# Standard size for this model is often 800-1000, let's use 1000x1000 +dummy_input = torch.randn(1, 3, 1000, 1000) +# mask is optional in some exports but good to check. For standard export we usually just pass pixel_values + +print("Exporting to ONNX...") +# Dynamic axes are crucial for variable image sizes +dynamic_axes = { + "pixel_values": {0: "batch_size", 2: "height", 3: "width"}, + "logits": {0: "batch_size", 1: "num_queries"}, + "pred_boxes": {0: "batch_size", 1: "num_queries"}, +} + +torch.onnx.export( + model, + dummy_input, + ONNX_PATH, + export_params=True, + opset_version=14, + do_constant_folding=True, + input_names=["pixel_values"], + output_names=["logits", "pred_boxes"], + dynamic_axes=dynamic_axes, +) + +print(f"Model exported to {ONNX_PATH}") + +# Verify +print("Verifying ONNX model...") +onnx_model = onnx.load(ONNX_PATH) +onnx.checker.check_model(onnx_model) +print("ONNX model verified.") + +# Quantization / FP16 (Optional but recommended for size) +print("Converting to FP16...") +from onnxconverter_common import float16 +model_fp16 = float16.convert_float_to_float16(onnx_model) +ONNX_FP16_PATH = ONNX_PATH.replace(".onnx", "_fp16.onnx") +onnx.save(model_fp16, ONNX_FP16_PATH) +print(f"FP16 Model saved to {ONNX_FP16_PATH}")