diff --git a/README.md b/README.md
index 5d08140..ba676de 100644
--- a/README.md
+++ b/README.md
@@ -1,35 +1,37 @@
 # usls
 
-A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics) `(Classification, Segmentation, Detection and Pose Detection)`, [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), and others. Many execution providers are supported, sunch as `CUDA`, `TensorRT` and `CoreML`.
+A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vision** and **Vision-Language** models including [YOLOv8](https://github.com/ultralytics/ultralytics) `(Classification, Segmentation, Detection and Pose Detection)`, [YOLOv9](https://github.com/WongKinYiu/yolov9), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), and others. Many execution providers are supported, such as `CUDA`, `TensorRT` and `CoreML`.
 
 ## Supported Models
 
-| Model | Example | CUDA(f32) | CUDA(f16) | TensorRT(f32) | TensorRT(f16) |
-| :-------------------: | :----------------------: | :----------------: | :----------------: | :------------------------: | :-----------------------: |
-| YOLOv8-detection | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| YOLOv8-pose | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| YOLOv8-classification | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| YOLOv8-segmentation | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
-| YOLOv8-OBB | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** |
-| YOLOv9 | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
-| RT-DETR | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
-| FastSAM | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
-| YOLO-World | [demo](examples/yolo-world) | ✅ | ✅ | ✅ | ✅ |
-| DINOv2 | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
-| CLIP | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
-| BLIP | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
-| OCR(DB, SVTR) | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** | ***TODO*** |
+| Model | Example | CUDA<br />f32 | CUDA<br />f16 | TensorRT<br />f32 | TensorRT<br />f16 |
+| :-----------------------------: | :----------------------: | :-----------: | :-----------: | :------------------------: | :-----------------------: |
+| **YOLOv8-detection** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **YOLOv8-pose** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **YOLOv8-classification** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **YOLOv8-segmentation** | [demo](examples/yolov8) | ✅ | ✅ | ✅ | ✅ |
+| **YOLOv8-OBB** | TODO | TODO | TODO | TODO | TODO |
+| **YOLOv9** | [demo](examples/yolov9) | ✅ | ✅ | ✅ | ✅ |
+| **RT-DETR** | [demo](examples/rtdetr) | ✅ | ✅ | ✅ | ✅ |
+| **FastSAM** | [demo](examples/fastsam) | ✅ | ✅ | ✅ | ✅ |
+| **YOLO-World** | [demo](examples/yolo-world) | ✅ | ✅ | ✅ | ✅ |
+| **DINOv2** | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ |
+| **CLIP** | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
+| **BLIP** | [demo](examples/blip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual |
+| [**DB (Text Detection)**](https://arxiv.org/abs/1911.08947) | [demo](examples/db) | ✅ | ❌ | ✅ | ✅ |
+| **SVTR, TrOCR** | TODO | TODO | TODO | TODO | TODO |
 
 ## Solution Models
 
 Additionally, this repo also provides some solution models such as pedestrian `fall detection`, `head detection`, `trash detection`, and more.
 
-| Model | Example |
-| :---------------------: | :------------------------------: |
-| face-landmark detection | [demo](examples/yolov8-face) |
-| head detection | [demo](examples/yolov8-head) |
-| fall detection | [demo](examples/yolov8-falldown) |
-| trash detection | [demo](examples/yolov8-plastic-bag) |
+| Model | Example |
+| :-------------------------------------------------------: | :------------------------------: |
+| **face-landmark detection**<br />**人脸 & 关键点检测** | [demo](examples/yolov8-face) |
+| **head detection**<br />**人头检测** | [demo](examples/yolov8-head) |
+| **fall detection**<br />**摔倒检测** | [demo](examples/yolov8-falldown) |
+| **trash detection**<br />**垃圾检测** | [demo](examples/yolov8-plastic-bag) |
+| **text detection (PPOCR-det v3, v4)**<br />**PPOCR文本检测** | [demo](examples/db) |
 
 ## Demo
@@ -63,8 +65,8 @@ cargo add --git https://github.com/jamjamjon/usls
 cargo add usls
 ```
 
-
 #### 3. Set `Options` and build model
+
 ```Rust
 let options = Options::default()
     .with_model("../models/yolov8m-seg-dyn-f16.onnx")
@@ -100,7 +102,6 @@ let x = DataLoader::try_read("./assets/bus.jpg")?;
 let _y = model.run(&[x])?;
 ```
 
-
 ## Script: convert ONNX model from `float32` to `float16`
 
 ```python
diff --git a/assets/math.jpg b/assets/math.jpg
new file mode 100644
index 0000000..0b5b656
Binary files /dev/null and b/assets/math.jpg differ
diff --git a/examples/db/README.md b/examples/db/README.md
new file mode 100644
index 0000000..91c681c
--- /dev/null
+++ b/examples/db/README.md
@@ -0,0 +1,39 @@
+## Quick Start
+
+```shell
+cargo run -r --example db
+```
+
+## Or you can do it manually
+
+### 1. Download ONNX Model
+
+[ppocr-v3-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v3-db-dyn.onnx)
+[ppocr-v4-db-dyn](https://github.com/jamjamjon/assets/releases/download/v0.0.1/ppocr-v4-db-dyn.onnx)
+
+### 2. Specify the ONNX model path in `main.rs`
+
+```Rust
+let options = Options::default()
+    .with_model("ONNX_PATH") // <= modify this
+    .with_profile(false);
+```
+
+### 3. Run
+
+```bash
+cargo run -r --example db
+```
+
+### Speed test
+
+| Model | Image size | TensorRT<br />f16 | TensorRT<br />f32 | CUDA<br />f32 |
+| --------------- | ---------- | ----------------- | ----------------- | ------------- |
+| ppocr-v3-db-dyn | 640x640 | 1.8585ms | 2.5739ms | 4.3314ms |
+| ppocr-v4-db-dyn | 640x640 | 2.0507ms | 2.8264ms | 6.6064ms |
+
+***Tested on RTX 3060***
+
+## Results
+
+![](./demo.jpg)
diff --git a/examples/db/demo.jpg b/examples/db/demo.jpg
new file mode 100644
index 0000000..bb0db03
Binary files /dev/null and b/examples/db/demo.jpg differ
diff --git a/examples/db/main.rs b/examples/db/main.rs
new file mode 100644
index 0000000..24093a5
--- /dev/null
+++ b/examples/db/main.rs
@@ -0,0 +1,25 @@
+use usls::{models::DB, DataLoader, Options};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // build model
+    let options = Options::default()
+        .with_model("../models/ppocr-v4-db-dyn.onnx")
+        .with_i00((1, 1, 4).into())
+        .with_i02((608, 640, 960).into())
+        .with_i03((608, 640, 960).into())
+        .with_confs(&[0.7])
+        .with_saveout("DB-Text-Detection")
+        .with_dry_run(5)
+        // .with_trt(0)
+        // .with_fp16(true)
+        .with_profile(true);
+    let mut model = DB::new(&options)?;
+
+    // load image
+    let x = DataLoader::try_read("./assets/math.jpg")?;
+
+    // run
+    let _y = model.run(&[x])?;
+
+    Ok(())
+}
diff --git a/src/lib.rs b/src/lib.rs
index 9004f52..749d12a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,6 +12,7 @@ pub mod models;
 pub mod ops;
 mod options;
 mod point;
+mod polygon;
 mod rect;
 mod results;
 mod rotated_rect;
@@ -30,6 +31,7 @@ pub use metric::Metric;
 pub use min_opt_max::MinOptMax;
 pub use options::Options;
 pub use point::Point;
+pub use polygon::Polygon;
 pub use rect::Rect;
 pub use results::Results;
 pub use rotated_rect::RotatedRect;
diff --git a/src/models/db.rs b/src/models/db.rs
new file mode 100644
index 0000000..870f0e1
--- /dev/null
+++ b/src/models/db.rs
@@ -0,0 +1,155 @@
+use crate::{
+    ops, Annotator, Bbox, DynConf, MinOptMax, Options, OrtEngine, Point, Polygon, Results,
+};
+use anyhow::Result;
+use image::{DynamicImage, ImageBuffer};
+use ndarray::{Array, Axis, IxDyn};
+
+#[derive(Debug)]
+pub struct DB {
+    engine: OrtEngine,
+    height: MinOptMax,
+    width: MinOptMax,
+    batch: MinOptMax,
+    annotator: Annotator,
+    confs: DynConf,
+    saveout: Option<String>,
+    names: Option<Vec<String>>,
+}
+
+impl DB {
+    pub fn new(options: &Options) -> Result<Self> {
+        let engine = OrtEngine::new(options)?;
+        let (batch, height, width) = (
+            engine.inputs_minoptmax()[0][0].to_owned(),
+            engine.inputs_minoptmax()[0][2].to_owned(),
+            engine.inputs_minoptmax()[0][3].to_owned(),
+        );
+        let annotator = Annotator::default();
+        let names = Some(vec!["Text".to_string()]);
+        let confs = DynConf::new(&options.confs, 1);
+        engine.dry_run()?;
+
+        Ok(Self {
+            engine,
+            names,
+            confs,
+            height,
+            width,
+            batch,
+            saveout: options.saveout.to_owned(),
+            annotator,
+        })
+    }
+
+    pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Results>> {
+        let xs_ = ops::letterbox(xs, self.height.opt as u32, self.width.opt as u32)?;
+        let ys = self.engine.run(&[xs_])?;
+        let ys = self.postprocess(ys, xs)?;
+        match &self.saveout {
+            None => {}
+            Some(saveout) => {
+                for (img0, y) in xs.iter().zip(ys.iter()) {
+                    let mut img = img0.to_rgb8();
+                    self.annotator.plot(&mut img, y);
+                    self.annotator.save(&img, saveout);
+                }
+            }
+        }
+        Ok(ys)
+    }
+
+    pub fn postprocess(
+        &self,
+        xs: Vec<Array<f32, IxDyn>>,
+        xs0: &[DynamicImage],
+    ) -> Result<Vec<Results>> {
+        let mut ys = Vec::new();
+        for (idx, mask) in xs[0].axis_iter(Axis(0)).enumerate() {
+            let mut ys_bbox = Vec::new();
+            // input image
+            let image_width = xs0[idx].width() as f32;
+            let image_height = xs0[idx].height() as f32;
+
+            // h,w,1
+            let h = mask.dim()[1];
+            let w = mask.dim()[2];
+            let mask = mask.into_shape((h, w, 1))?.into_owned();
+
+            // build image from ndarray
+            let mask_im: ImageBuffer<image::Luma<u8>, Vec<u8>> =
+                ImageBuffer::from_raw(w as u32, h as u32, mask.into_raw_vec())
+                    .expect("Failed to create image from ndarray");
+            let mut mask_im = image::DynamicImage::from(mask_im);
+
+            // rescale
+            let (_, w_mask, h_mask) = ops::scale_wh(image_width, image_height, w as f32, h as f32);
+            let mask_original = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
+            let mask_original = mask_original.resize_exact(
+                image_width as u32,
+                image_height as u32,
+                image::imageops::FilterType::Triangle,
+            );
+
+            // contours
+            let contours: Vec<imageproc::contours::Contour<i32>> =
+                imageproc::contours::find_contours(&mask_original.into_luma8());
+
+            for contour in contours.iter() {
+                // polygon
+                let points: Vec<Point> = contour
+                    .points
+                    .iter()
+                    .map(|p| Point::new(p.x as f32, p.y as f32))
+                    .collect();
+                let polygon = Polygon::new(&points);
+                let mut rect = polygon.find_min_rect();
+
+                // min size filter
+                if rect.height() < 3.0 || rect.width() < 3.0 {
+                    continue;
+                }
+
+                // confs filter
+                let confidence = polygon.area() / rect.area();
+                if confidence < self.confs[0] {
+                    continue;
+                }
+
+                // TODO: expand polygon
+                let unclip_ratio = 1.5;
+                let delta = rect.area() * unclip_ratio / rect.perimeter();
+
+                // save
+                let y_bbox = Bbox::new(
+                    rect.expand(delta, delta, image_width, image_height),
+                    0,
+                    confidence,
+                    self.names.as_ref().map(|names| names[0].clone()),
+                );
+                ys_bbox.push(y_bbox);
+            }
+            let y = Results {
+                probs: None,
+                bboxes: Some(ys_bbox),
+                keypoints: None,
+                masks: None,
+            };
+            ys.push(y);
+        }
+
+        Ok(ys)
+    }
+
+    pub fn batch(&self) -> isize {
+        self.batch.opt
+    }
+
+    pub fn width(&self) -> isize {
+        self.width.opt
+    }
+
+    pub fn height(&self) -> isize {
+        self.height.opt
+    }
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 9dc0d3f..61a95b0 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -1,11 +1,13 @@
 mod blip;
 mod clip;
+mod db;
 mod dinov2;
 mod rtdetr;
 mod yolo;
 
 pub use blip::Blip;
 pub use clip::Clip;
+pub use db::DB;
 pub use dinov2::Dinov2;
 pub use rtdetr::RTDETR;
 pub use yolo::YOLO;
diff --git a/src/models/yolo.rs b/src/models/yolo.rs
index e783bda..f2f37da 100644
--- a/src/models/yolo.rs
+++ b/src/models/yolo.rs
@@ -128,13 +128,6 @@ impl YOLO {
         })
     }
 
-    // pub fn run_with_dl(&mut self, dl: &Dataloader) -> Result<Vec<Results>> {
-    //     for (images, paths) in dataloader {
-    //         self.run(&images)
-    //     }
-    //     Ok(())
-    // }
-
     pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Results>> {
         let xs_ = ops::letterbox(xs, self.height() as u32, self.width() as u32)?;
         let ys = self.engine.run(&[xs_])?;
@@ -296,10 +289,9 @@ impl YOLO {
 
                 // build image from ndarray
                 let mask_im: ImageBuffer<image::Luma<u8>, Vec<u8>> =
-                    match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
-                        Some(image) => image,
-                        None => panic!("can not create image from ndarray"),
-                    };
+                    ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec())
+                        .expect("Failed to create image from ndarray");
+
                 let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
 
                 // rescale masks
diff --git a/src/polygon.rs b/src/polygon.rs
new file mode 100644
index 0000000..42d646e
--- /dev/null
+++ b/src/polygon.rs
@@ -0,0 +1,54 @@
+use crate::{Point, Rect, RotatedRect};
+
+#[derive(Default, Debug, PartialOrd, PartialEq, Clone)]
+pub struct Polygon {
+    points: Vec<Point>,
+}
+
+impl Polygon {
+    pub fn new(points: &[Point]) -> Self {
+        // TODO: refactor
+        Self {
+            points: points.to_vec(),
+        }
+    }
+
+    pub fn area(&self) -> f32 {
+        // make sure points are already sorted
+        let mut area = 0.0;
+        let n = self.points.len();
+        for i in 0..n {
+            let j = (i + 1) % n;
+            area += self.points[i].x * self.points[j].y;
+            area -= self.points[j].x * self.points[i].y;
+        }
+        area.abs() / 2.0
+    }
+
+    pub fn find_min_rect(&self) -> Rect {
+        let (mut min_x, mut min_y, mut max_x, mut max_y) = (f32::MAX, f32::MAX, f32::MIN, f32::MIN);
+        for point in self.points.iter() {
+            if point.x <= min_x {
+                min_x = point.x
+            }
+            if point.x > max_x {
+                max_x = point.x
+            }
+            if point.y <= min_y {
+                min_y = point.y
+            }
+            if point.y > max_y {
+                max_y = point.y
+            }
+        }
+        ((min_x, min_y), (max_x, max_y)).into()
+    }
+
+    pub fn find_min_rotated_rect() -> RotatedRect {
+        todo!()
+    }
+
+    pub fn expand(&mut self) -> Self {
+        todo!()
+    }
+}
diff --git a/src/rect.rs b/src/rect.rs
index 8ce25f7..80d5c2a 100644
--- a/src/rect.rs
+++ b/src/rect.rs
@@ -120,6 +120,10 @@ impl Rect {
         self.height() * self.width()
     }
 
+    pub fn perimeter(&self) -> f32 {
+        (self.height() + self.width()) * 2.0
+    }
+
     pub fn is_empty(&self) -> bool {
         self.area() == 0.0
     }
@@ -150,6 +154,15 @@
             && self.ymin() <= other.ymin()
             && self.ymax() >= other.ymax()
     }
+
+    pub fn expand(&mut self, x: f32, y: f32, max_x: f32, max_y: f32) -> Self {
+        Self::from_xyxy(
+            (self.xmin() - x).max(0.0f32).min(max_x),
+            (self.ymin() - y).max(0.0f32).min(max_y),
+            (self.xmax() + x).max(0.0f32).min(max_x),
+            (self.ymax() + y).max(0.0f32).min(max_y),
+        )
+    }
 }
 
 #[cfg(test)]
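Note on the new `Rect::perimeter` / `Rect::expand` helpers: `DB::postprocess` uses them to grow each detected text box by the DB-style unclip offset `delta = area * unclip_ratio / perimeter` before clamping to the image bounds. A minimal standalone sketch of that step (not part of the patch), reusing only items this patch adds or re-exports; the box coordinates and image size below are made-up illustration values:

```Rust
use usls::Rect;

fn main() {
    // hypothetical detected text box, built the same way `Polygon::find_min_rect` does:
    // ((xmin, ymin), (xmax, ymax)).into()
    let mut rect: Rect = ((100.0_f32, 50.0_f32), (220.0_f32, 80.0_f32)).into();
    let (image_width, image_height) = (640.0_f32, 480.0_f32);

    // same offset formula as in DB::postprocess
    let unclip_ratio = 1.5;
    let delta = rect.area() * unclip_ratio / rect.perimeter();

    // grow the box by `delta` on every side, clamped to the image bounds
    let expanded = rect.expand(delta, delta, image_width, image_height);
    println!("expanded box: {} x {}", expanded.width(), expanded.height());
}
```

Since `expand` clamps each coordinate to `[0, max]`, boxes near the border are truncated rather than spilling outside the image.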