diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index e2c1b0a..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,47 +0,0 @@ -## v0.0.5 - 2024-07-12 - -### Changed - -- Accelerated `YOLO`'s post-processing using `Rayon`. Now, `YOLOv8-seg` takes only around **~8ms (~20ms in the previous version)**, depending on your machine. Note that this repo's implementation of `YOLOv8-Segment` saves not only the masks but also their contour points. The official `YOLOv8` Python version only saves the masks, making it appear much faster. -- Merged all `YOLOv8-related` solution models into YOLO examples. -- Consolidated all `YOLO-series` model examples into the YOLO example. -- Refactored the `YOLO` struct to unify all `YOLO versions` and `YOLO tasks`. It now supports user-defined YOLO models with different `Preds Tensor Formats`. -- Introduced a new `Nms` trait, combining `apply_bboxes_nms()` and `apply_mbrs_nms()` into `apply_nms()`. - -### Added - -- Added support for `YOLOv6` and `YOLOv7`. -- Updated documentation for `y.rs`. -- Updated documentation for `bbox.rs`. -- Updated the `README.md`. -- Added `with_yolo_preds()` to `Options`. -- Added support for `Depth-Anything-v2`. -- Added `RTDETR` to the `YOLOVersion` struct. - -### Removed - -- Merged the following models' examples into the YOLOv8 example: `yolov8-face`, `yolov8-falldown`, `yolov8-head`, `yolov8-trash`, `fastsam`, and `face-parsing`. -- Removed `anchors_first`, `conf_independent`, and their related methods from `Options`. - - -## v0.0.4 - 2024-06-30 - -### Added - -- Add X struct to handle input and preprocessing -- Add Ops struct to manage common operations -- Use SIMD (fast_image_resize) to accelerate model pre-processing and post-processing.YOLOv8-seg post-processing (~120ms => ~20ms), Depth-Anything post-processing (~23ms => ~2ms). - -### Deprecated - -- Mark `Ops::descale_mask()` as deprecated. - -### Fixed - -### Changed - -### Removed - -### Refactored - -### Others diff --git a/Cargo.toml b/Cargo.toml index a54224d..ad1b662 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "usls" -version = "0.0.6" +version = "0.0.7" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." 
repository = "https://github.com/jamjamjon/usls" authors = ["Jamjamjon "] license = "MIT" readme = "README.md" -exclude = ["assets/*", "examples/*"] +exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"] [dependencies] clap = { version = "4.2.4", features = ["derive"] } @@ -44,4 +44,15 @@ ab_glyph = "0.2.23" geo = "0.28.0" prost = "0.12.4" human_bytes = "0.4.3" -fast_image_resize = { version = "4.0.0", git = "https://github.com/jamjamjon/fast_image_resize", branch = "dev" , features = ["image"]} +fast_image_resize = { version = "4.2.1", features = ["image"]} + + +[dev-dependencies] +criterion = "0.5.1" + +[[bench]] +name = "yolo" +harness = false + +[lib] +bench = false diff --git a/README.md b/README.md index f1e9195..6104bf6 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,21 @@ # usls -[![Static Badge](https://img.shields.io/crates/v/usls.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/usls) ![Static Badge](https://img.shields.io/crates/d/usls?style=for-the-badge) [![Static Badge](https://img.shields.io/badge/Documents-usls-blue?style=for-the-badge&logo=docs.rs)](https://docs.rs/usls) [![Static Badge](https://img.shields.io/badge/GitHub-black?style=for-the-badge&logo=github)](https://github.com/jamjamjon/usls) +[![Static Badge](https://img.shields.io/crates/v/usls.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/usls) [![Static Badge](https://img.shields.io/badge/ONNXRuntime-v1.17.x-yellow?style=for-the-badge&logo=docs.rs)](https://github.com/microsoft/onnxruntime/releases) [![Static Badge](https://img.shields.io/badge/CUDA-11.x-green?style=for-the-badge&logo=docs.rs)](https://developer.nvidia.com/cuda-toolkit-archive) [![Static Badge](https://img.shields.io/badge/TRT-8.6.x.x-blue?style=for-the-badge&logo=docs.rs)](https://developer.nvidia.com/tensorrt) +[![Static Badge](https://img.shields.io/badge/Documents-usls-blue?style=for-the-badge&logo=docs.rs)](https://docs.rs/usls) ![Static Badge](https://img.shields.io/crates/d/usls?style=for-the-badge) + + + +A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [SAM](https://github.com/facebookresearch/segment-anything), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [MODNet](https://github.com/ZHKKKe/MODNet) and others. 
+ + +| Segment Anything | +| :------------------------------------------------------: | +| | + +| YOLO + SAM | +| :------------------------------------------------------: | +| | -A Rust library integrated with **ONNXRuntime**, providing a collection of **Computer Vison** and **Vision-Language** models including [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [RTDETR](https://arxiv.org/abs/2304.08069), [CLIP](https://github.com/openai/CLIP), [DINOv2](https://github.com/facebookresearch/dinov2), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [BLIP](https://arxiv.org/abs/2201.12086), [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR), [Depth-Anything](https://github.com/LiheYoung/Depth-Anything), [MODNet](https://github.com/ZHKKKe/MODNet) and others. | Monocular Depth Estimation | | :--------------------------------------------------------------: | @@ -13,9 +26,7 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp | :----------------------------------------------------: | :------------------------------------------------: | | | | -| Portrait Matting | -| :------------------------------------------------------: | -| | + ## Supported Models @@ -30,6 +41,10 @@ A Rust library integrated with **ONNXRuntime**, providing a collection of **Comp | [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | | [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | | [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | +| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | +| [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | +| [EdgeSAM](https://github.com/chongzhou96/EdgeSAM) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | +| [SAM-HQ](https://github.com/SysCV/sam-hq) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | | | [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ | | [DINOv2](https://github.com/facebookresearch/dinov2) | Vision-Self-Supervised | [demo](examples/dinov2) | ✅ | ✅ | ✅ | ✅ | | [CLIP](https://github.com/openai/CLIP) | Vision-Language | [demo](examples/clip) | ✅ | ✅ | ✅ visual<br />❌ textual | ✅ visual<br />❌ textual | @@ -64,103 +79,13 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ... ## Integrate into your own project -### 1. Add `usls` as a dependency to your project's `Cargo.toml` + ```Shell +# Add `usls` as a dependency to your project's `Cargo.toml` cargo add usls -``` Or you can use specific commit -```Shell +# Or you can use a specific commit usls = { git = "https://github.com/jamjamjon/usls", rev = "???sha???"} + ``` - -### 2. Build model - -```Rust -let options = Options::default() - .with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR - .with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb - .with_model("xxxx.onnx")?; -let mut model = YOLO::new(options)?; -``` - -- If you want to run your model with TensorRT or CoreML - - ```Rust - let options = Options::default() - .with_trt(0) // using cuda by default - // .with_coreml(0) - ``` -- If your model has dynamic shapes - - ```Rust - let options = Options::default() - .with_i00((1, 2, 4).into()) // dynamic batch - .with_i02((416, 640, 800).into()) // dynamic height - .with_i03((416, 640, 800).into()) // dynamic width - ``` -- If you want to set a confidence for each category - - ```Rust - let options = Options::default() - .with_confs(&[0.4, 0.15]) // class_0: 0.4, others: 0.15 - ``` -- Go check [Options](src/core/options.rs) for more model options. - -#### 3. Load images - -- Build `DataLoader` to load images - -```Rust -let dl = DataLoader::default() - .with_batch(model.batch.opt as usize) - .load("./assets/")?; - -for (xs, _paths) in dl { - let _y = model.run(&xs)?; -} -``` - -- Or simply read one image - -```Rust -let x = vec![DataLoader::try_read("./assets/bus.jpg")?]; -let y = model.run(&x)?; -``` - -#### 4. Annotate and save - -```Rust -let annotator = Annotator::default().with_saveout("YOLO"); -annotator.annotate(&x, &y); -``` - -#### 5. Get results - -The inference outputs of provided models will be saved to `Vec<Y>`.
- -- You can get detection bboxes with `y.bboxes()`: - - ```Rust - let ys = model.run(&xs)?; - for y in ys { - // bboxes - if let Some(bboxes) = y.bboxes() { - for bbox in bboxes { - println!( - "Bbox: {}, {}, {}, {}, {}, {}", - bbox.xmin(), - bbox.ymin(), - bbox.xmax(), - bbox.ymax(), - bbox.confidence(), - bbox.id(), - ) - } - } - } - ``` - -- Other: [Docs](https://docs.rs/usls/latest/usls/struct.Y.html) diff --git a/assets/dog.jpg b/assets/dog.jpg new file mode 100644 index 0000000..583f69e Binary files /dev/null and b/assets/dog.jpg differ diff --git a/assets/truck.jpg b/assets/truck.jpg new file mode 100644 index 0000000..6b98688 Binary files /dev/null and b/assets/truck.jpg differ diff --git a/benches/yolo.rs b/benches/yolo.rs new file mode 100644 index 0000000..4868ba6 --- /dev/null +++ b/benches/yolo.rs @@ -0,0 +1,96 @@ +use anyhow::Result; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use usls::{coco, models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; + +enum Stage { + Pre, + Run, + Post, + Pipeline, +} + +fn yolo_stage_bench( + model: &mut YOLO, + x: &[image::DynamicImage], + stage: Stage, + n: u64, +) -> std::time::Duration { + let mut t_pre = std::time::Duration::new(0, 0); + let mut t_run = std::time::Duration::new(0, 0); + let mut t_post = std::time::Duration::new(0, 0); + let mut t_pipeline = std::time::Duration::new(0, 0); + for _ in 0..n { + let t0 = std::time::Instant::now(); + let xs = model.preprocess(x).unwrap(); + t_pre += t0.elapsed(); + + let t = std::time::Instant::now(); + let xs = model.inference(xs).unwrap(); + t_run += t.elapsed(); + + let t = std::time::Instant::now(); + let _ys = black_box(model.postprocess(xs, x).unwrap()); + t_post += t.elapsed(); + t_pipeline += t0.elapsed(); + } + match stage { + Stage::Pre => t_pre, + Stage::Run => t_run, + Stage::Post => t_post, + Stage::Pipeline => t_pipeline, + } +} + +pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> { + let mut group = c.benchmark_group(format!("YOLO ({}-{})", w, h)); + group + .significance_level(0.05) + .sample_size(80) + .measurement_time(std::time::Duration::new(20, 0)); + + let options = Options::default() + .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR + .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb + .with_model("yolov8m-dyn.onnx")? 
+ .with_cuda(0) + // .with_cpu() + .with_dry_run(0) + .with_i00((1, 1, 4).into()) + .with_i02((320, h, 1280).into()) + .with_i03((320, w, 1280).into()) + .with_confs(&[0.2, 0.15]) // class_0: 0.2, others: 0.15 + .with_names2(&coco::KEYPOINTS_NAMES_17); + let mut model = YOLO::new(options)?; + + let xs = vec![DataLoader::try_read("./assets/bus.jpg")?]; + + group.bench_function("pre-process", |b| { + b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pre, n)) + }); + + group.bench_function("run", |b| { + b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Run, n)) + }); + + group.bench_function("post-process", |b| { + b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Post, n)) + }); + + group.bench_function("pipeline", |b| { + b.iter_custom(|n| yolo_stage_bench(&mut model, &xs, Stage::Pipeline, n)) + }); + + group.finish(); + Ok(()) +} + +pub fn criterion_benchmark(c: &mut Criterion) { + // benchmark_cuda(c, 416, 416).unwrap(); + benchmark_cuda(c, 640, 640).unwrap(); + benchmark_cuda(c, 448, 768).unwrap(); + // benchmark_cuda(c, 800, 800).unwrap(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..b1335e6 --- /dev/null +++ b/build.rs @@ -0,0 +1,5 @@ +fn main() { + // Need this for CoreML. See: https://ort.pyke.io/perf/execution-providers#coreml + #[cfg(target_os = "macos")] + println!("cargo:rustc-link-arg=-fapple-link-rtlib"); +} diff --git a/examples/sam/README.md b/examples/sam/README.md new file mode 100644 index 0000000..fe5c596 --- /dev/null +++ b/examples/sam/README.md @@ -0,0 +1,21 @@ +## Quick Start + +```Shell + +# SAM +cargo run -r --example sam + +# MobileSAM +cargo run -r --example sam -- --kind mobile-sam + +# EdgeSAM +cargo run -r --example sam -- --kind edge-sam + +# SAM-HQ +cargo run -r --example sam -- --kind sam-hq +``` + + +## Results + +![](./demo.png) diff --git a/examples/sam/demo.png b/examples/sam/demo.png new file mode 100644 index 0000000..f62c6a5 Binary files /dev/null and b/examples/sam/demo.png differ diff --git a/examples/sam/main.rs b/examples/sam/main.rs new file mode 100644 index 0000000..0e6ef4d --- /dev/null +++ b/examples/sam/main.rs @@ -0,0 +1,106 @@ +use clap::Parser; + +use usls::{ + models::{SamKind, SamPrompt, SAM}, + Annotator, DataLoader, Options, +}; + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +pub struct Args { + #[arg(long, value_enum, default_value_t = SamKind::Sam)] + pub kind: SamKind, + + #[arg(long, default_value_t = 0)] + pub device_id: usize, + + #[arg(long)] + pub use_low_res_mask: bool, +} + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let args = Args::parse(); + + // Options + let (options_encoder, options_decoder, saveout) = match args.kind { + SamKind::Sam => { + let options_encoder = Options::default() + // .with_model("sam-vit-b-encoder.onnx")?; + .with_model("sam-vit-b-encoder-u8.onnx")?; + + let options_decoder = Options::default() + .with_i00((1, 1, 1).into()) + .with_i11((1, 1, 1).into()) + .with_i21((1, 1, 1).into()) + .with_sam_kind(SamKind::Sam) + // .with_model("sam-vit-b-decoder.onnx")?; + // .with_model("sam-vit-b-decoder-singlemask.onnx")?; + .with_model("sam-vit-b-decoder-u8.onnx")?; + (options_encoder, options_decoder, "SAM") + } + SamKind::MobileSam => { + let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?; + + let options_decoder = Options::default() + .with_i00((1, 1, 1).into()) + .with_i11((1, 1, 1).into()) + .with_i21((1, 1, 1).into()) + 
.with_sam_kind(SamKind::MobileSam) + .with_model("mobile-sam-vit-t-decoder.onnx")?; + (options_encoder, options_decoder, "Mobile-SAM") + } + SamKind::SamHq => { + let options_encoder = Options::default().with_model("sam-hq-vit-t-encoder.onnx")?; + + let options_decoder = Options::default() + .with_i00((1, 1, 1).into()) + .with_i21((1, 1, 1).into()) + .with_i31((1, 1, 1).into()) + .with_sam_kind(SamKind::SamHq) + .with_model("sam-hq-vit-t-decoder.onnx")?; + (options_encoder, options_decoder, "SAM-HQ") + } + SamKind::EdgeSam => { + let options_encoder = Options::default().with_model("edge-sam-3x-encoder.onnx")?; + let options_decoder = Options::default() + .with_i00((1, 1, 1).into()) + .with_i11((1, 1, 1).into()) + .with_i21((1, 1, 1).into()) + .with_sam_kind(SamKind::EdgeSam) + .with_model("edge-sam-3x-decoder.onnx")?; + (options_encoder, options_decoder, "Edge-SAM") + } + }; + let options_encoder = options_encoder + .with_cuda(args.device_id) + .with_i00((1, 1, 1).into()) + .with_i02((800, 1024, 1024).into()) + .with_i03((800, 1024, 1024).into()); + let options_decoder = options_decoder + .with_cuda(args.device_id) + .use_low_res_mask(args.use_low_res_mask) + .with_find_contours(true); + + // Build model + let mut model = SAM::new(options_encoder, options_decoder)?; + + // Load image + let xs = vec![DataLoader::try_read("./assets/truck.jpg")?]; + + // Build annotator + let annotator = Annotator::default().with_saveout(saveout); + + // Prompt + let prompts = vec![ + SamPrompt::default() + // .with_postive_point(500., 375.), // positive point + // .with_negative_point(774., 366.), // negative point + .with_bbox(215., 297., 643., 459.), // bbox + ]; + + // Run & Annotate + let ys = model.run(&xs, &prompts)?; + annotator.annotate(&xs, &ys); + + Ok(()) +} diff --git a/examples/yolo-sam/demo.png b/examples/yolo-sam/demo.png new file mode 100644 index 0000000..682facd Binary files /dev/null and b/examples/yolo-sam/demo.png differ diff --git a/examples/yolo-sam/main.rs b/examples/yolo-sam/main.rs new file mode 100644 index 0000000..0865b81 --- /dev/null +++ b/examples/yolo-sam/main.rs @@ -0,0 +1,63 @@ +use usls::{ + models::{SamKind, SamPrompt, YOLOTask, YOLOVersion, SAM, YOLO}, + Annotator, DataLoader, Options, Vision, +}; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + // build SAM + let options_encoder = Options::default() + .with_i00((1, 1, 1).into()) + .with_model("mobile-sam-vit-t-encoder.onnx")?; + let options_decoder = Options::default() + .with_i11((1, 1, 1).into()) + .with_i21((1, 1, 1).into()) + .with_find_contours(true) + .with_sam_kind(SamKind::Sam) + .with_model("mobile-sam-vit-t-decoder.onnx")?; + let mut sam = SAM::new(options_encoder, options_decoder)?; + + // build YOLOv8-Det + let options_yolo = Options::default() + .with_yolo_version(YOLOVersion::V8) + .with_yolo_task(YOLOTask::Detect) + .with_model("yolov8m-dyn.onnx")? 
+ .with_cuda(0) + .with_i00((1, 1, 4).into()) + .with_i02((416, 640, 800).into()) + .with_i03((416, 640, 800).into()) + .with_find_contours(false) + .with_confs(&[0.45]); + let mut yolo = YOLO::new(options_yolo)?; + + // load one image + let xs = vec![DataLoader::try_read("./assets/dog.jpg")?]; + + // build annotator + let annotator = Annotator::default() + .with_bboxes_thickness(7) + .without_bboxes_name(true) + .without_bboxes_conf(true) + .without_mbrs(true) + .with_saveout("YOLO+SAM"); + + // run & annotate + let ys_det = yolo.run(&xs)?; + for y_det in ys_det { + if let Some(bboxes) = y_det.bboxes() { + for bbox in bboxes { + let ys_sam = sam.run( + &xs, + &[SamPrompt::default().with_bbox( + bbox.xmin(), + bbox.ymin(), + bbox.xmax(), + bbox.ymax(), + )], + )?; + annotator.annotate(&xs, &ys_sam); + } + } + } + + Ok(()) +} diff --git a/examples/yolo/README.md b/examples/yolo/README.md index 22390f4..638088e 100644 --- a/examples/yolo/README.md +++ b/examples/yolo/README.md @@ -25,29 +25,29 @@ ```Shell # Classify -cargo run -r --example yolo -- --task classify --version v5 # YOLOv5 -cargo run -r --example yolo -- --task classify --version v8 # YOLOv8 +cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5 +cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8 # Detect -cargo run -r --example yolo -- --task detect --version v5 # YOLOv5 -cargo run -r --example yolo -- --task detect --version v6 # YOLOv6 -cargo run -r --example yolo -- --task detect --version v7 # YOLOv7 -cargo run -r --example yolo -- --task detect --version v8 # YOLOv8 -cargo run -r --example yolo -- --task detect --version v9 # YOLOv9 -cargo run -r --example yolo -- --task detect --version v10 # YOLOv10 -cargo run -r --example yolo -- --task detect --version rtdetr # YOLOv8-RTDETR -cargo run -r --example yolo -- --task detect --version v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world +cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5 +cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6 +cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7 +cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8 +cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9 +cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10 +cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR +cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world # Pose -cargo run -r --example yolo -- --task pose --version v8 # YOLOv8-Pose +cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose # Segment -cargo run -r --example yolo -- --task segment --version v5 # YOLOv5-Segment -cargo run -r --example yolo -- --task segment --version v8 # YOLOv8-Segment -cargo run -r --example yolo -- --task segment --version v8 --model FastSAM-s-dyn-f16.onnx # FastSAM +cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment +cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment +cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM # Obb -cargo run -r --example yolo -- --task obb --version v8 # YOLOv8-Obb +cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb ```
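This PR also adds a `--no-contours` switch to the example (wired to `with_find_contours(!args.no_contours)` in `main.rs`). A sketch of combining it with the renamed `--ver` flag, assuming clap's default kebab-case flag naming:

```Shell
# Segment, but skip contour/polygon extraction in post-processing (sketch)
cargo run -r --example yolo -- --task segment --ver v8 --no-contours
```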
@@ -175,7 +175,3 @@ yolo export model=yolov8m-obb.pt format=onnx simplify [Here](https://github.com/THU-MIG/yolov10#export)
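The same toggle is exposed to library users as `Options::with_find_contours` (added in `src/core/options.rs` further down). A minimal sketch, assuming only the builder API shown in this diff; the model filename is a placeholder:

```Rust
// Sketch only: "yolov8m-seg-dyn.onnx" is a placeholder model name.
let options = Options::default()
    .with_model("yolov8m-seg-dyn.onnx")?
    .with_find_contours(false); // keep masks, skip contour-point extraction
let mut model = YOLO::new(options)?;
```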
- - - - diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index f4889a9..97af956 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -16,7 +16,7 @@ pub struct Args { pub task: YOLOTask, #[arg(long, value_enum, default_value_t = YOLOVersion::V8)] - pub version: YOLOVersion, + pub ver: YOLOVersion, #[arg(long, default_value_t = 224)] pub width_min: isize, @@ -59,6 +59,9 @@ pub struct Args { #[arg(long)] pub no_plot: bool, + + #[arg(long)] + pub no_contours: bool, } fn main() -> Result<()> { @@ -68,66 +71,87 @@ fn main() -> Result<()> { let options = Options::default(); // version & task - let options = - match args.version { - YOLOVersion::V5 => { - match args.task { - YOLOTask::Classify => options - .with_model(&args.model.unwrap_or("yolov5n-cls-dyn.onnx".to_string()))?, - YOLOTask::Detect => { - options.with_model(&args.model.unwrap_or("yolov5n-dyn.onnx".to_string()))? - } - YOLOTask::Segment => options - .with_model(&args.model.unwrap_or("yolov5n-seg-dyn.onnx".to_string()))?, - t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), - } - } - YOLOVersion::V6 => match args.task { - YOLOTask::Detect => options + let (options, saveout) = match args.ver { + YOLOVersion::V5 => match args.task { + YOLOTask::Classify => ( + options.with_model(&args.model.unwrap_or("yolov5n-cls-dyn.onnx".to_string()))?, + "YOLOv5-Classify", + ), + YOLOTask::Detect => ( + options.with_model(&args.model.unwrap_or("yolov5n-dyn.onnx".to_string()))?, + "YOLOv5-Detect", + ), + YOLOTask::Segment => ( + options.with_model(&args.model.unwrap_or("yolov5n-seg-dyn.onnx".to_string()))?, + "YOLOv5-Segment", + ), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver), + }, + YOLOVersion::V6 => match args.task { + YOLOTask::Detect => ( + options .with_model(&args.model.unwrap_or("yolov6n-dyn.onnx".to_string()))? .with_nc(args.nc), - t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), - }, - YOLOVersion::V7 => match args.task { - YOLOTask::Detect => options + "YOLOv6-Detect", + ), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver), + }, + YOLOVersion::V7 => match args.task { + YOLOTask::Detect => ( + options .with_model(&args.model.unwrap_or("yolov7-tiny-dyn.onnx".to_string()))? .with_nc(args.nc), - t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), - }, - YOLOVersion::V8 => { - match args.task { - YOLOTask::Classify => options - .with_model(&args.model.unwrap_or("yolov8m-cls-dyn.onnx".to_string()))?, - YOLOTask::Detect => { - options.with_model(&args.model.unwrap_or("yolov8m-dyn.onnx".to_string()))? - } - YOLOTask::Segment => options - .with_model(&args.model.unwrap_or("yolov8m-seg-dyn.onnx".to_string()))?, - YOLOTask::Pose => options - .with_model(&args.model.unwrap_or("yolov8m-pose-dyn.onnx".to_string()))?, - YOLOTask::Obb => options - .with_model(&args.model.unwrap_or("yolov8m-obb-dyn.onnx".to_string()))?, - } - } - YOLOVersion::V9 => match args.task { - YOLOTask::Detect => options - .with_model(&args.model.unwrap_or("yolov9-c-dyn-f16.onnx".to_string()))?, - t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), - }, - YOLOVersion::V10 => match args.task { - YOLOTask::Detect => { - options.with_model(&args.model.unwrap_or("yolov10n.onnx".to_string()))? - } - t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), - }, - YOLOVersion::RTDETR => match args.task { - YOLOTask::Detect => { - options.with_model(&args.model.unwrap_or("rtdetr-l-f16.onnx".to_string()))? 
- } - t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.version), - }, - } - .with_yolo_version(args.version) + "YOLOv7-Detect", + ), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver), + }, + YOLOVersion::V8 => match args.task { + YOLOTask::Classify => ( + options.with_model(&args.model.unwrap_or("yolov8m-cls-dyn.onnx".to_string()))?, + "YOLOv8-Classify", + ), + YOLOTask::Detect => ( + options.with_model(&args.model.unwrap_or("yolov8m-dyn.onnx".to_string()))?, + "YOLOv8-Detect", + ), + YOLOTask::Segment => ( + options.with_model(&args.model.unwrap_or("yolov8m-seg-dyn.onnx".to_string()))?, + "YOLOv8-Segment", + ), + YOLOTask::Pose => ( + options.with_model(&args.model.unwrap_or("yolov8m-pose-dyn.onnx".to_string()))?, + "YOLOv8-Pose", + ), + YOLOTask::Obb => ( + options.with_model(&args.model.unwrap_or("yolov8m-obb-dyn.onnx".to_string()))?, + "YOLOv8-Obb", + ), + }, + YOLOVersion::V9 => match args.task { + YOLOTask::Detect => ( + options.with_model(&args.model.unwrap_or("yolov9-c-dyn-f16.onnx".to_string()))?, + "YOLOv9-Detect", + ), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver), + }, + YOLOVersion::V10 => match args.task { + YOLOTask::Detect => ( + options.with_model(&args.model.unwrap_or("yolov10n.onnx".to_string()))?, + "YOLOv10-Detect", + ), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver), + }, + YOLOVersion::RTDETR => match args.task { + YOLOTask::Detect => ( + options.with_model(&args.model.unwrap_or("rtdetr-l-f16.onnx".to_string()))?, + "RTDETR", + ), + t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver), + }, + }; + + let options = options + .with_yolo_version(args.ver) .with_yolo_task(args.task); // device @@ -152,6 +176,7 @@ fn main() -> Result<()> { .with_confs(&[0.2, 0.15]) // class_0: 0.2, others: 0.15 // .with_names(&coco::NAMES_80) .with_names2(&coco::KEYPOINTS_NAMES_17) + .with_find_contours(!args.no_contours) // find contours or not .with_profile(args.profile); let mut model = YOLO::new(options)?; @@ -163,9 +188,9 @@ fn main() -> Result<()> { // build annotator let annotator = Annotator::default() .with_skeletons(&coco::SKELETONS_16) - .with_bboxes_thickness(7) - .without_masks(true) // No masks plotting. - .with_saveout("YOLO-Series"); + .with_bboxes_thickness(4) + .without_masks(true) // No masks plotting when doing segment task. 
+ .with_saveout(saveout); // run & annotate for (xs, _paths) in dl { diff --git a/src/core/annotator.rs b/src/core/annotator.rs index 42b6cf8..bbe67a4 100644 --- a/src/core/annotator.rs +++ b/src/core/annotator.rs @@ -340,13 +340,6 @@ impl Annotator { } } - // masks - if !self.without_masks { - if let Some(xs) = &y.masks() { - self.plot_masks(&mut img_rgba, xs); - } - } - // bboxes if !self.without_bboxes { if let Some(xs) = &y.bboxes() { @@ -368,6 +361,13 @@ impl Annotator { } } + // masks + if !self.without_masks { + if let Some(xs) = &y.masks() { + self.plot_masks(&mut img_rgba, xs); + } + } + // probs if let Some(xs) = &y.probs() { self.plot_probs(&mut img_rgba, xs); diff --git a/src/core/dataloader.rs b/src/core/dataloader.rs index e7bb444..6f16797 100644 --- a/src/core/dataloader.rs +++ b/src/core/dataloader.rs @@ -53,49 +53,58 @@ impl Default for DataLoader { } impl DataLoader { - pub fn load<P: AsRef<Path>>(&mut self, source: P) -> Result<Self> { - let source = source.as_ref(); - let mut paths = VecDeque::new(); - - match source { - s if s.is_file() => paths.push_back(s.to_path_buf()), - s if s.is_dir() => { - for entry in WalkDir::new(s) - .into_iter() - .filter_entry(|e| !Self::_is_hidden(e)) - { - let entry = entry.unwrap(); - if entry.file_type().is_dir() { - continue; + pub fn load<P: AsRef<Path>>(mut self, source: P) -> Result<Self> { + self.paths = match source.as_ref() { + s if s.is_file() => VecDeque::from([s.to_path_buf()]), + s if s.is_dir() => WalkDir::new(s) + .into_iter() + .filter_entry(|e| !Self::_is_hidden(e)) + .filter_map(|entry| match entry { + Err(_) => None, + Ok(entry) => { + if entry.file_type().is_dir() { + return None; + } + if !self.recursive && entry.depth() > 1 { + return None; + } + Some(entry.path().to_path_buf()) } - if !self.recursive && entry.depth() > 1 { - continue; - } - paths.push_back(entry.path().to_path_buf()); - } - } + }) + .collect::<VecDeque<_>>(), // s if s.starts_with("rtsp://") || s.starts_with("rtmp://") || s.starts_with("http://")|| s.starts_with("https://") => todo!(), s if !s.exists() => bail!("{s:?} Not Exists"), _ => todo!(), - } - let n_new = paths.len(); - self.paths.append(&mut paths); - println!( - "{CHECK_MARK} Found images x{n_new} ({} total)", - self.paths.len() - ); - Ok(Self { - paths: self.paths.to_owned(), - batch: self.batch, - recursive: self.recursive, - }) + }; + println!("{CHECK_MARK} Found file x{}", self.paths.len()); + Ok(self) } pub fn try_read<P: AsRef<Path>>(path: P) -> Result<DynamicImage> { let img = image::ImageReader::open(&path) - .map_err(|_| anyhow!("Failed to open image at {:?}", path.as_ref()))? + .map_err(|err| { + anyhow!( + "Failed to open image at {:?}. Error: {:?}", + path.as_ref(), + err + ) + })? + .with_guessed_format() + .map_err(|err| { + anyhow!( + "Failed to make a format guess based on the content: {:?}. Error: {:?}", + path.as_ref(), + err + ) + })? .decode() - .map_err(|_| anyhow!("Failed to decode image at {:?}", path.as_ref()))? + .map_err(|err| { + anyhow!( + "Failed to decode image at {:?}. Error: {:?}", + path.as_ref(), + err + ) + })? 
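+        // note: the decoded image is forced to 8-bit RGB below, so every DataLoader output shares one pixel layout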
.into_rgb8(); Ok(DynamicImage::from(img)) } diff --git a/src/core/engine.rs b/src/core/engine.rs index 59181b3..4cda4d9 100644 --- a/src/core/engine.rs +++ b/src/core/engine.rs @@ -4,7 +4,6 @@ use human_bytes::human_bytes; use ndarray::{Array, IxDyn}; use ort::{ ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider, - MINOR_VERSION, }; use prost::Message; use std::collections::HashSet; @@ -41,7 +40,7 @@ impl OrtEngine { let model_proto = Self::load_onnx(&config.onnx_path)?; let graph = match &model_proto.graph { Some(graph) => graph, - None => anyhow::bail!("No graph found in this proto"), + None => anyhow::bail!("No graph found in this proto. Failed to parse ONNX model."), }; // model params & mems @@ -101,6 +100,30 @@ impl OrtEngine { (3, 3) => Self::_set_ixx(x, &config.i33, i, ii).unwrap_or(x_default), (3, 4) => Self::_set_ixx(x, &config.i34, i, ii).unwrap_or(x_default), (3, 5) => Self::_set_ixx(x, &config.i35, i, ii).unwrap_or(x_default), + (4, 0) => Self::_set_ixx(x, &config.i40, i, ii).unwrap_or(x_default), + (4, 1) => Self::_set_ixx(x, &config.i41, i, ii).unwrap_or(x_default), + (4, 2) => Self::_set_ixx(x, &config.i42, i, ii).unwrap_or(x_default), + (4, 3) => Self::_set_ixx(x, &config.i43, i, ii).unwrap_or(x_default), + (4, 4) => Self::_set_ixx(x, &config.i44, i, ii).unwrap_or(x_default), + (4, 5) => Self::_set_ixx(x, &config.i45, i, ii).unwrap_or(x_default), + (5, 0) => Self::_set_ixx(x, &config.i50, i, ii).unwrap_or(x_default), + (5, 1) => Self::_set_ixx(x, &config.i51, i, ii).unwrap_or(x_default), + (5, 2) => Self::_set_ixx(x, &config.i52, i, ii).unwrap_or(x_default), + (5, 3) => Self::_set_ixx(x, &config.i53, i, ii).unwrap_or(x_default), + (5, 4) => Self::_set_ixx(x, &config.i54, i, ii).unwrap_or(x_default), + (5, 5) => Self::_set_ixx(x, &config.i55, i, ii).unwrap_or(x_default), + (6, 0) => Self::_set_ixx(x, &config.i60, i, ii).unwrap_or(x_default), + (6, 1) => Self::_set_ixx(x, &config.i61, i, ii).unwrap_or(x_default), + (6, 2) => Self::_set_ixx(x, &config.i62, i, ii).unwrap_or(x_default), + (6, 3) => Self::_set_ixx(x, &config.i63, i, ii).unwrap_or(x_default), + (6, 4) => Self::_set_ixx(x, &config.i64_, i, ii).unwrap_or(x_default), + (6, 5) => Self::_set_ixx(x, &config.i65, i, ii).unwrap_or(x_default), + (7, 0) => Self::_set_ixx(x, &config.i70, i, ii).unwrap_or(x_default), + (7, 1) => Self::_set_ixx(x, &config.i71, i, ii).unwrap_or(x_default), + (7, 2) => Self::_set_ixx(x, &config.i72, i, ii).unwrap_or(x_default), + (7, 3) => Self::_set_ixx(x, &config.i73, i, ii).unwrap_or(x_default), + (7, 4) => Self::_set_ixx(x, &config.i74, i, ii).unwrap_or(x_default), + (7, 5) => Self::_set_ixx(x, &config.i75, i, ii).unwrap_or(x_default), _ => todo!(), }; v_.push(x); @@ -146,7 +169,7 @@ impl OrtEngine { // summary println!( - "{CHECK_MARK} ORT: 1.{MINOR_VERSION}.x | Opset: {} | EP: {:?} | Dtype: {:?} | Parameters: {}", + "{CHECK_MARK} Backend: ONNXRuntime | OpSet: {} | EP: {:?} | DType: {:?} | Params: {}", model_proto.opset_import[0].version, device, inputs_attrs.dtypes, @@ -291,6 +314,12 @@ impl OrtEngine { TensorElementType::Int64 => { ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn() } + TensorElementType::Uint8 => { + ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn() + } + TensorElementType::Int8 => { + ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn() + } _ => todo!(), }; xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_)); @@ -499,14 +528,12 @@ impl OrtEngine { let tensor_type = match 
Self::ort_dtype_from_onnx_dtype_id(tensor_type) { Some(dtype) => dtype, None => continue, - // None => anyhow::bail!("DType not supported"), }; dtypes.push(tensor_type); let shapes = match &tensor.shape { Some(shapes) => shapes, None => continue, - // None => anyhow::bail!("DType has no shapes"), }; let mut shape_: Vec<isize> = Vec::new(); for shape in shapes.dim.iter() { diff --git a/src/core/ops.rs b/src/core/ops.rs index da4341f..bc29e3e 100644 --- a/src/core/ops.rs +++ b/src/core/ops.rs @@ -1,7 +1,6 @@ //! Some processing functions to image and ndarray. use anyhow::Result; -use fast_image_resize as fir; use fast_image_resize::{ images::{CroppedImageMut, Image}, pixels::PixelType, @@ -11,8 +10,6 @@ use image::{DynamicImage, GenericImageView}; use ndarray::{s, Array, Axis, IxDyn}; use rayon::prelude::*; -use crate::X; - pub enum Ops<'a> { Resize(&'a [DynamicImage], u32, u32, &'a str), Letterbox(&'a [DynamicImage], u32, u32, &'a str, u8, &'a str, bool), @@ -26,30 +23,13 @@ pub enum Ops<'a> { } impl Ops<'_> { - pub fn apply(ops: &[Self]) -> Result<X> { - let mut y = X::default(); - - for op in ops { - y = match op { - Self::Resize(xs, h, w, filter) => X::resize(xs, *h, *w, filter)?, - Self::Letterbox(xs, h, w, filter, bg, resize_by, center) => { - X::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)? - } - Self::Normalize(min_, max_) => y.normalize(*min_, *max_)?, - Self::Standardize(mean, std, d) => y.standardize(mean, std, *d)?, - Self::Permute(shape) => y.permute(shape)?, - Self::InsertAxis(d) => y.insert_axis(*d)?, - Self::Nhwc2nchw => y.nhwc2nchw()?, - Self::Nchw2nhwc => y.nchw2nhwc()?, - _ => todo!(), - } - } - Ok(y) - } - pub fn normalize(x: Array<f32, IxDyn>, min: f32, max: f32) -> Result<Array<f32, IxDyn>> { - if min > max { - anyhow::bail!("Input `min` is greater than `max`"); + if min >= max { + anyhow::bail!( + "Invalid range in `normalize`: `min` ({}) must be less than `max` ({}).", + min, + max + ); } Ok((x - min) / (max - min)) } @@ -61,11 +41,11 @@ dim: usize, ) -> Result<Array<f32, IxDyn>> { if mean.len() != std.len() { - anyhow::bail!("The lengths of mean and std are not equal."); + anyhow::bail!("`standardize`: `mean` and `std` lengths are not equal. Mean length: {}, Std length: {}.", mean.len(), std.len()); } let shape = x.shape(); if dim >= shape.len() || shape[dim] != mean.len() { - anyhow::bail!("The specified dimension or mean/std length is inconsistent with the input dimensions."); + anyhow::bail!("`standardize`: Dimension mismatch. `dim` is {} but shape length is {} or `mean` length is {}.", dim, shape.len(), mean.len()); } let mut shape = vec![1; shape.len()]; shape[dim] = mean.len(); @@ -77,11 +57,11 @@ pub fn permute(x: Array<f32, IxDyn>, shape: &[usize]) -> Result<Array<f32, IxDyn>> { if shape.len() != x.shape().len() { anyhow::bail!( - "Shape inconsistent. Target: {:?}, {}, got: {:?}, {}", - x.shape(), + "`permute`: Shape length mismatch. Expected: {}, got: {}. Target shape: {:?}, provided shape: {:?}.", x.shape().len(), - shape, - shape.len() + shape.len(), + x.shape(), + shape ); } Ok(x.permuted_axes(shape.to_vec()).into_dyn()) } @@ -98,7 +78,7 @@ pub fn insert_axis(x: Array<f32, IxDyn>, d: usize) -> Result<Array<f32, IxDyn>> { if x.shape().len() < d { anyhow::bail!( - "The specified axis insertion position {} exceeds the shape's maximum limit of {}.", + "`insert_axis`: The specified axis position {} exceeds the maximum shape length {}.", d, x.shape().len() ); @@ -109,7 +89,7 @@ pub fn norm(xs: Array<f32, IxDyn>, d: usize) -> Result<Array<f32, IxDyn>> { if xs.shape().len() < d { anyhow::bail!( - "The specified axis {} exceeds the shape's maximum limit of {}.", + "`norm`: Specified axis {} exceeds the maximum dimension length {}.", d, xs.shape().len() ); @@ -149,22 +129,22 @@ crop_src: bool, filter: &str, ) -> Result<Vec<f32>> { - let src_mask = fir::images::Image::from_vec_u8( + let src = Image::from_vec_u8( w0 as _, h0 as _, v.iter().flat_map(|x| x.to_le_bytes()).collect(), - fir::PixelType::F32, + PixelType::F32, )?; - let mut dst_mask = fir::images::Image::new(w1 as _, h1 as _, src_mask.pixel_type()); + let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type()); let (mut resizer, mut options) = Self::build_resizer_filter(filter)?; if crop_src { let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _); options = options.crop(0., 0., w.into(), h.into()); }; - resizer.resize(&src_mask, &mut dst_mask, &options)?; + resizer.resize(&src, &mut dst, &options)?; // u8*2 -> f32 - let mask_f32: Vec<f32> = dst_mask + let mask_f32: Vec<f32> = dst .into_vec() .chunks_exact(4) .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) @@ -184,16 +164,15 @@ crop_src: bool, filter: &str, ) -> Result<Vec<u8>> { - let src_mask = - fir::images::Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), fir::PixelType::U8)?; - let mut dst_mask = fir::images::Image::new(w1 as _, h1 as _, src_mask.pixel_type()); + let src = Image::from_vec_u8(w0 as _, h0 as _, v.to_vec(), PixelType::U8)?; + let mut dst = Image::new(w1 as _, h1 as _, src.pixel_type()); let (mut resizer, mut options) = Self::build_resizer_filter(filter)?; if crop_src { let (_, w, h) = Self::scale_wh(w1 as _, h1 as _, w0 as _, h0 as _); options = options.crop(0., 0., w.into(), h.into()); }; - resizer.resize(&src_mask, &mut dst_mask, &options)?; - Ok(dst_mask.into_vec()) + resizer.resize(&src, &mut dst, &options)?; + Ok(dst.into_vec()) } pub fn build_resizer_filter(ty: &str) -> Result<(Resizer, ResizeOptions)> { @@ -205,7 +184,7 @@ "Mitchell" => FilterType::Mitchell, "Gaussian" => FilterType::Gaussian, "Lanczos3" => FilterType::Lanczos3, - _ => anyhow::bail!("Unsupported resize filter type: {ty}"), + _ => anyhow::bail!("Unsupported resizer filter type: {ty}"), }; Ok(( Resizer::new(), @@ -215,22 +194,22 @@ pub fn resize( xs: &[DynamicImage], - height: u32, - width: u32, + th: u32, + tw: u32, filter: &str, ) -> Result<Array<f32, IxDyn>> { - let mut ys = Array::ones((xs.len(), height as usize, width as usize, 3)).into_dyn(); + let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); let (mut resizer, options) = Self::build_resizer_filter(filter)?; for (idx, x) in xs.iter().enumerate() { - let buffer = if x.dimensions() == (width, height) { + let buffer = if x.dimensions() == (tw, th) { x.to_rgb8().into_raw() } else { - let mut dst_image = Image::new(width, height, PixelType::U8x3); - resizer.resize(x, &mut dst_image, &options)?; - dst_image.into_vec() + let mut dst = 
Image::new(tw, th, PixelType::U8x3); + resizer.resize(x, &mut dst, &options)?; + dst.into_vec() }; - let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)? - .mapv(|x| x as f32); + let y_ = + Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32); ys.slice_mut(s![idx, .., .., ..]).assign(&y_); } Ok(ys) @@ -238,55 +217,55 @@ pub fn letterbox( xs: &[DynamicImage], - height: u32, - width: u32, + th: u32, + tw: u32, filter: &str, bg: u8, resize_by: &str, center: bool, ) -> Result<Array<f32, IxDyn>> { - let mut ys = Array::ones((xs.len(), height as usize, width as usize, 3)).into_dyn(); + let mut ys = Array::ones((xs.len(), th as usize, tw as usize, 3)).into_dyn(); let (mut resizer, options) = Self::build_resizer_filter(filter)?; for (idx, x) in xs.iter().enumerate() { let (w0, h0) = x.dimensions(); - let buffer = if w0 == width && h0 == height { + let buffer = if w0 == tw && h0 == th { x.to_rgb8().into_raw() } else { let (w, h) = match resize_by { "auto" => { - let r = (width as f32 / w0 as f32).min(height as f32 / h0 as f32); + let r = (tw as f32 / w0 as f32).min(th as f32 / h0 as f32); ( (w0 as f32 * r).round() as u32, (h0 as f32 * r).round() as u32, ) } - "height" => (height * w0 / h0, height), - "width" => (width, width * h0 / w0), - _ => anyhow::bail!("Option: width, height, auto"), + "height" => (th * w0 / h0, th), + "width" => (tw, tw * h0 / w0), + _ => anyhow::bail!("Options for `letterbox`: width, height, auto"), }; - let mut dst_image = Image::from_vec_u8( - width, - height, - vec![bg; 3 * height as usize * width as usize], + let mut dst = Image::from_vec_u8( + tw, + th, + vec![bg; 3 * th as usize * tw as usize], PixelType::U8x3, )?; let (l, t) = if center { - if w == width { - (0, (height - h) / 2) + if w == tw { + (0, (th - h) / 2) } else { - ((width - w) / 2, 0) + ((tw - w) / 2, 0) } } else { (0, 0) }; - let mut cropped_dst_image = CroppedImageMut::new(&mut dst_image, l, t, w, h)?; - resizer.resize(x, &mut cropped_dst_image, &options)?; - dst_image.into_vec() + let mut dst_cropped = CroppedImageMut::new(&mut dst, l, t, w, h)?; + resizer.resize(x, &mut dst_cropped, &options)?; + dst.into_vec() }; - let y_ = Array::from_shape_vec((height as usize, width as usize, 3), buffer)? 
- .mapv(|x| x as f32); + let y_ = + Array::from_shape_vec((th as usize, tw as usize, 3), buffer)?.mapv(|x| x as f32); ys.slice_mut(s![idx, .., .., ..]).assign(&y_); } Ok(ys) diff --git a/src/core/options.rs b/src/core/options.rs index cb6e866..dc4b00a 100644 --- a/src/core/options.rs +++ b/src/core/options.rs @@ -4,7 +4,7 @@ use anyhow::Result; use crate::{ auto_load, - models::{YOLOPreds, YOLOTask, YOLOVersion}, + models::{SamKind, YOLOPreds, YOLOTask, YOLOVersion}, Device, MinOptMax, }; @@ -39,7 +39,30 @@ pub struct Options { pub i33: Option<MinOptMax>, pub i34: Option<MinOptMax>, pub i35: Option<MinOptMax>, - + pub i40: Option<MinOptMax>, + pub i41: Option<MinOptMax>, + pub i42: Option<MinOptMax>, + pub i43: Option<MinOptMax>, + pub i44: Option<MinOptMax>, + pub i45: Option<MinOptMax>, + pub i50: Option<MinOptMax>, + pub i51: Option<MinOptMax>, + pub i52: Option<MinOptMax>, + pub i53: Option<MinOptMax>, + pub i54: Option<MinOptMax>, + pub i55: Option<MinOptMax>, + pub i60: Option<MinOptMax>, + pub i61: Option<MinOptMax>, + pub i62: Option<MinOptMax>, + pub i63: Option<MinOptMax>, + pub i64_: Option<MinOptMax>, + pub i65: Option<MinOptMax>, + pub i70: Option<MinOptMax>, + pub i71: Option<MinOptMax>, + pub i72: Option<MinOptMax>, + pub i73: Option<MinOptMax>, + pub i74: Option<MinOptMax>, + pub i75: Option<MinOptMax>, // trt related pub trt_engine_cache_enable: bool, pub trt_int8_enable: bool, pub trt_fp16_enable: bool, @@ -63,6 +86,9 @@ pub struct Options { pub yolo_task: Option<YOLOTask>, pub yolo_version: Option<YOLOVersion>, pub yolo_preds: Option<YOLOPreds>, + pub find_contours: bool, + pub sam_kind: Option<SamKind>, + pub use_low_res_mask: Option<bool>, } impl Default for Options { @@ -96,6 +122,30 @@ impl Default for Options { i33: None, i34: None, i35: None, + i40: None, + i41: None, + i42: None, + i43: None, + i44: None, + i45: None, + i50: None, + i51: None, + i52: None, + i53: None, + i54: None, + i55: None, + i60: None, + i61: None, + i62: None, + i63: None, + i64_: None, + i65: None, + i70: None, + i71: None, + i72: None, + i73: None, + i74: None, + i75: None, trt_engine_cache_enable: true, trt_int8_enable: false, trt_fp16_enable: false, @@ -116,6 +166,9 @@ impl Default for Options { yolo_task: None, yolo_version: None, yolo_preds: None, + find_contours: false, + sam_kind: None, + use_low_res_mask: None, } } } @@ -171,6 +224,21 @@ impl Options { self } + pub fn with_find_contours(mut self, x: bool) -> Self { + self.find_contours = x; + self + } + + pub fn with_sam_kind(mut self, x: SamKind) -> Self { + self.sam_kind = Some(x); + self + } + + pub fn use_low_res_mask(mut self, x: bool) -> Self { + self.use_low_res_mask = Some(x); + self + } + pub fn with_names(mut self, names: &[&str]) -> Self { self.names = Some(names.iter().map(|x| x.to_string()).collect::<Vec<String>>()); self } @@ -360,4 +428,124 @@ impl Options { self.i35 = Some(x); self } + + pub fn with_i40(mut self, x: MinOptMax) -> Self { + self.i40 = Some(x); + self + } + + pub fn with_i41(mut self, x: MinOptMax) -> Self { + self.i41 = Some(x); + self + } + + pub fn with_i42(mut self, x: MinOptMax) -> Self { + self.i42 = Some(x); + self + } + + pub fn with_i43(mut self, x: MinOptMax) -> Self { + self.i43 = Some(x); + self + } + + pub fn with_i44(mut self, x: MinOptMax) -> Self { + self.i44 = Some(x); + self + } + + pub fn with_i45(mut self, x: MinOptMax) -> Self { + self.i45 = Some(x); + self + } + + pub fn with_i50(mut self, x: MinOptMax) -> Self { + self.i50 = Some(x); + self + } + + pub fn with_i51(mut self, x: MinOptMax) -> Self { + self.i51 = Some(x); + self + } + + pub fn with_i52(mut self, x: MinOptMax) -> Self { + self.i52 = Some(x); + self + } + + pub fn with_i53(mut self, x: MinOptMax) -> Self { + self.i53 = Some(x); + self + } + + pub fn with_i54(mut self, x: MinOptMax) -> Self { + self.i54 = Some(x); + self + } + + pub fn with_i55(mut self, x: MinOptMax) -> Self { + self.i55 = Some(x); 
+ self + } + + pub fn with_i60(mut self, x: MinOptMax) -> Self { + self.i60 = Some(x); + self + } + + pub fn with_i61(mut self, x: MinOptMax) -> Self { + self.i61 = Some(x); + self + } + + pub fn with_i62(mut self, x: MinOptMax) -> Self { + self.i62 = Some(x); + self + } + + pub fn with_i63(mut self, x: MinOptMax) -> Self { + self.i63 = Some(x); + self + } + + pub fn with_i64(mut self, x: MinOptMax) -> Self { + self.i64_ = Some(x); + self + } + + pub fn with_i65(mut self, x: MinOptMax) -> Self { + self.i65 = Some(x); + self + } + + pub fn with_i70(mut self, x: MinOptMax) -> Self { + self.i70 = Some(x); + self + } + + pub fn with_i71(mut self, x: MinOptMax) -> Self { + self.i71 = Some(x); + self + } + + pub fn with_i72(mut self, x: MinOptMax) -> Self { + self.i72 = Some(x); + self + } + + pub fn with_i73(mut self, x: MinOptMax) -> Self { + self.i73 = Some(x); + self + } + + pub fn with_i74(mut self, x: MinOptMax) -> Self { + self.i74 = Some(x); + self + } + + pub fn with_i75(mut self, x: MinOptMax) -> Self { + self.i75 = Some(x); + self + } } diff --git a/src/core/x.rs b/src/core/x.rs index 72da861..e6b39ba 100644 --- a/src/core/x.rs +++ b/src/core/x.rs @@ -14,6 +14,12 @@ impl From<Array<f32, IxDyn>> for X { } } +impl From<Vec<f32>> for X { + fn from(x: Vec<f32>) -> Self { + Self(Array::from_vec(x).into_dyn().into_owned()) + } +} + impl std::ops::Deref for X { type Target = Array<f32, IxDyn>; @@ -28,7 +34,23 @@ impl X { } pub fn apply(ops: &[Ops]) -> Result<Self> { - Ops::apply(ops) + let mut y = Self::default(); + for op in ops { + y = match op { + Ops::Resize(xs, h, w, filter) => Self::resize(xs, *h, *w, filter)?, + Ops::Letterbox(xs, h, w, filter, bg, resize_by, center) => { + Self::letterbox(xs, *h, *w, filter, *bg, resize_by, *center)? + } + Ops::Normalize(min_, max_) => y.normalize(*min_, *max_)?, + Ops::Standardize(mean, std, d) => y.standardize(mean, std, *d)?, + Ops::Permute(shape) => y.permute(shape)?, + Ops::InsertAxis(d) => y.insert_axis(*d)?, + Ops::Nhwc2nchw => y.nhwc2nchw()?, + Ops::Nchw2nhwc => y.nchw2nhwc()?, + _ => todo!(), + } + } + Ok(y) } pub fn permute(mut self, shape: &[usize]) -> Result<Self> { diff --git a/src/lib.rs b/src/lib.rs index 3b8ea2c..77dd69c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,91 +1,150 @@ -//! A Rust library integrated with ONNXRuntime, providing a collection of Computer Vison and Vision-Language models. +//! A Rust library integrated with ONNXRuntime, providing a collection of **Computer Vision** and **Vision-Language** models. //! -//! [`OrtEngine`] provides ONNX model loading, metadata parsing, dry_run, inference and other functions, supporting EPs such as CUDA, TensorRT, CoreML, etc. You can use it as the ONNXRuntime engine for building models. +//! # Supported Models //! +//! - [YOLOv5](https://github.com/ultralytics/yolov5): Object Detection, Instance Segmentation, Classification +//! - [YOLOv6](https://github.com/meituan/YOLOv6): Object Detection +//! - [YOLOv7](https://github.com/WongKinYiu/yolov7): Object Detection +//! - [YOLOv8](https://github.com/ultralytics/ultralytics): Object Detection, Instance Segmentation, Classification, Oriented Object Detection, Keypoint Detection +//! - [YOLOv9](https://github.com/WongKinYiu/yolov9): Object Detection +//! - [YOLOv10](https://github.com/THU-MIG/yolov10): Object Detection +//! - [RT-DETR](https://arxiv.org/abs/2304.08069): Object Detection +//! - [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM): Instance Segmentation +//! - [SAM](https://github.com/facebookresearch/segment-anything): Segment Anything +//! 
- [MobileSAM](https://github.com/ChaoningZhang/MobileSAM): Segment Anything +//! - [EdgeSAM](https://github.com/chongzhou96/EdgeSAM): Segment Anything +//! - [SAM-HQ](https://github.com/SysCV/sam-hq): Segment Anything +//! - [YOLO-World](https://github.com/AILab-CVC/YOLO-World): Object Detection +//! - [DINOv2](https://github.com/facebookresearch/dinov2): Vision-Self-Supervised +//! - [CLIP](https://github.com/openai/CLIP): Vision-Language +//! - [BLIP](https://github.com/salesforce/BLIP): Vision-Language +//! - [DB](https://arxiv.org/abs/1911.08947): Text Detection +//! - [SVTR](https://arxiv.org/abs/2205.00159): Text Recognition +//! - [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo): Keypoint Detection +//! - [YOLOPv2](https://arxiv.org/abs/2208.11434): Panoptic Driving Perception +//! - [Depth-Anything (v1, v2)](https://github.com/LiheYoung/Depth-Anything): Monocular Depth Estimation +//! - [MODNet](https://github.com/ZHKKKe/MODNet): Image Matting //! -//! -//! - -//! # Supported models -//! | Model | Task / Type | -//! | :---------------------------------------------------------------: | :-------------------------: | -//! | [YOLOv5](https://github.com/ultralytics/yolov5) | Object Detection<br />Instance Segmentation<br />Classification | -//! | [YOLOv6](https://github.com/meituan/YOLOv6) | Object Detection | -//! | [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | -//! | [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br />Instance Segmentation<br />Classification<br />Oriented Object Detection<br />Keypoint Detection | -//! | [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | -//! | [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | -//! | [RT-DETR](https://arxiv.org/abs/2304.08069) | Object Detection | -//! | [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | -//! | [YOLO-World](https://github.com/AILab-CVC/YOLO-World) | Object Detection | -//! | [DINOv2](https://github.com/facebookresearch/dinov2) | Vision-Self-Supervised | -//! | [CLIP](https://github.com/openai/CLIP) | Vision-Language | -//! | [BLIP](https://github.com/salesforce/BLIP) | Vision-Language | -//! | [DB](https://arxiv.org/abs/1911.08947) | Text Detection | -//! | [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | -//! | [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | -//! | [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | -//! | [Depth-Anything<br />
(v1, v2)](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | -//! | [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | - //! # Examples -//! [All Examples Here](https://github.com/jamjamjon/usls/tree/main/examples) - -//! # Use provided models for inference - -//! #### 1. Using provided [`models`] with [`Option`] - -//! ```Rust, no_run +//! +//! [All Demos Here](https://github.com/jamjamjon/usls/tree/main/examples) +//! +//! # Using Provided Models for Inference +//! +//! #### 1. Build Model +//! Using provided [`models`] with [`Options`] +//! +//! ```rust, no_run //! use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision}; //! //! let options = Options::default() //! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR -//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb +//! .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb //! .with_model("xxxx.onnx")?; -//! .with_trt(0) -//! .with_fp16(true) -//! .with_i00((1, 1, 4).into()) -//! .with_i02((224, 640, 800).into()) -//! .with_i03((224, 640, 800).into()) -//! .with_confs(&[0.4, 0.15]) // class_0: 0.4, others: 0.15 -//! .with_profile(false); //! let mut model = YOLO::new(options)?; //! ``` - -//! #### 2. Load images using [`DataLoader`] or [`image::io::Reader`] //! -//! ```Rust, no_run -//! // Load one image +//! - Use `CUDA`, `TensorRT`, or `CoreML` +//! +//! ```rust, no_run +//! let options = Options::default() +//! .with_cuda(0) // using CUDA by default +//! // .with_trt(0) +//! // .with_coreml(0) +//! // .with_cpu(); +//! ``` +//! +//! - Dynamic Input Shapes +//! +//! ```rust, no_run +//! let options = Options::default() +//! .with_i00((1, 2, 4).into()) // dynamic batch +//! .with_i02((416, 640, 800).into()) // dynamic height +//! .with_i03((416, 640, 800).into()); // dynamic width +//! ``` +//! +//! - Set Confidence Thresholds for Each Category +//! +//! ```rust, no_run +//! let options = Options::default() +//! .with_confs(&[0.4, 0.15]); // class_0: 0.4, others: 0.15 +//! ``` +//! +//! - Set Class Names +//! +//! ```rust, no_run +//! let options = Options::default() +//! .with_names(&coco::NAMES_80); +//! ``` +//! +//! More options can be found in the [`Options`] documentation. +//! +//! #### 2. Load Images +//! +//! Ensure that the input image is RGB type. +//! +//! - Using [`image::ImageReader`] or [`DataLoader`] to Load One Image +//! +//! ```rust, no_run //! let x = vec![DataLoader::try_read("./assets/bus.jpg")?]; +//! // or +//! let x = image::ImageReader::open("myimage.png")?.decode()?; +//! ``` //! -//! // Load images with batch_size = 4 +//! - Using [`DataLoader`] to Load a Batch of Images +//! +//! ```rust, no_run //! let dl = DataLoader::default() //! .with_batch(4) //! .load("./assets")?; -//! // Load one image with `image::io::Reader` -//! let x = image::io::Reader::open("myimage.png")?.decode()? //! ``` //! -//! #### 3. Build annotator using [`Annotator`] +//! #### 3. (Optional) Annotate Results with [`Annotator`] //! -//! ```Rust, no_run +//! ```rust, no_run +//! let annotator = Annotator::default(); +//! ``` +//! +//! - Set Saveout Name +//! +//! ```rust, no_run //! let annotator = Annotator::default() -//! .with_bboxes_thickness(4) //! .with_saveout("YOLOs"); //! ``` +//! +//! - Set Bboxes Line Width //! -//! #### 4. Run and annotate +//! ```rust, no_run +//! let annotator = Annotator::default() +//! .with_bboxes_thickness(4); +//! ``` +//! +//! 
- Disable Mask Plotting +//! +//! ```rust, no_run +//! let annotator = Annotator::default() +//! .without_masks(true); +//! ``` +//! +//! More options can be found in the [`Annotator`] documentation. //! -//! #### 4. Run and annotate +//! +//! #### 4. Run and Annotate +//! -//! ```Rust, no_run +//! ```rust, no_run //! for (xs, _paths) in dl { //! let ys = model.run(&xs)?; //! annotator.annotate(&xs, &ys); //! } //! ``` //! -//! #### 5. Parse inference results from [`Vec<Y>`] -//! For example, uou can get detection bboxes with `y.bboxes()`: -//! ```Rust, no_run +//! #### 5. Get Results +//! +//! The inference outputs of provided models will be saved to a [`Vec<Y>`]. +//! +//! - For Example, Get Detection Bboxes with `y.bboxes()` +//! +//! ```rust, no_run //! let ys = model.run(&xs)?; //! for y in ys { //! // bboxes @@ -99,18 +158,17 @@ //! bbox.ymax(), //! bbox.confidence(), //! bbox.id(), -//! ) +//! ); //! } //! } //! } -//! ``` +//! ``` //! +//! # Also, You Can Implement Your Own Model with [`OrtEngine`] and [`Options`] //! -//! # Build your own model with [`OrtEngine`] -//! -//! Refer to [Demo: Depth-Anything](https://github.com/jamjamjon/usls/blob/main/src/models/depth_anything.rs) -//! +//! [`OrtEngine`] provides ONNX model loading, metadata parsing, dry_run, inference, and other functions, supporting EPs such as CUDA, TensorRT, CoreML, etc. You can use it as the ONNXRuntime engine for building models. //! +//! Refer to [Demo: Depth-Anything](https://github.com/jamjamjon/usls/blob/main/src/models/depth_anything.rs) for more details. mod core; pub mod models; diff --git a/src/models/blip.rs b/src/models/blip.rs index 18a93b7..d70280d 100644 --- a/src/models/blip.rs +++ b/src/models/blip.rs @@ -27,7 +27,7 @@ impl Blip { visual.height().to_owned(), visual.width().to_owned(), ); - let tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap(); + let tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap(); let tokenizer = TokenizerStream::new(tokenizer); visual.dry_run()?; textual.dry_run()?; diff --git a/src/models/clip.rs b/src/models/clip.rs index 2fe0523..73802e1 100644 --- a/src/models/clip.rs +++ b/src/models/clip.rs @@ -28,7 +28,7 @@ impl Clip { visual.inputs_minoptmax()[0][2].to_owned(), visual.inputs_minoptmax()[0][3].to_owned(), ); - let mut tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap(); + let mut tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap(); tokenizer.with_padding(Some(PaddingParams { strategy: PaddingStrategy::Fixed(context_length), direction: PaddingDirection::Right, diff --git a/src/models/db.rs b/src/models/db.rs index 2e14b2d..68351fa 100644 --- a/src/models/db.rs +++ b/src/models/db.rs @@ -119,31 +119,31 @@ impl DB { continue; } - let mask = Polygon::default().with_points_imageproc(&contour.points); - let delta = mask.area() * ratio.round() as f64 * self.unclip_ratio as f64 - / mask.perimeter(); + let polygon = Polygon::default().with_points_imageproc(&contour.points); + let delta = polygon.area() * ratio.round() as f64 * self.unclip_ratio as f64 + / polygon.perimeter(); // TODO: optimize - let mask = mask + let polygon = polygon .unclip(delta, image_width as f64, image_height as f64) .resample(50) // .simplify(6e-4) .convex_hull(); - if let Some(bbox) = mask.bbox() { + if let Some(bbox) = polygon.bbox() { if bbox.height() < self.min_height || bbox.width() < self.min_width { continue; } - let confidence = mask.area() as f32 / bbox.area(); + let confidence = polygon.area() as f32 / bbox.area(); if 
diff --git a/src/models/blip.rs b/src/models/blip.rs
index 18a93b7..d70280d 100644
--- a/src/models/blip.rs
+++ b/src/models/blip.rs
@@ -27,7 +27,7 @@ impl Blip {
             visual.height().to_owned(),
             visual.width().to_owned(),
         );
-        let tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap();
+        let tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
         let tokenizer = TokenizerStream::new(tokenizer);
         visual.dry_run()?;
         textual.dry_run()?;
diff --git a/src/models/clip.rs b/src/models/clip.rs
index 2fe0523..73802e1 100644
--- a/src/models/clip.rs
+++ b/src/models/clip.rs
@@ -28,7 +28,7 @@ impl Clip {
             visual.inputs_minoptmax()[0][2].to_owned(),
             visual.inputs_minoptmax()[0][3].to_owned(),
         );
-        let mut tokenizer = Tokenizer::from_file(&options_textual.tokenizer.unwrap()).unwrap();
+        let mut tokenizer = Tokenizer::from_file(options_textual.tokenizer.unwrap()).unwrap();
         tokenizer.with_padding(Some(PaddingParams {
             strategy: PaddingStrategy::Fixed(context_length),
             direction: PaddingDirection::Right,
diff --git a/src/models/db.rs b/src/models/db.rs
index 2e14b2d..68351fa 100644
--- a/src/models/db.rs
+++ b/src/models/db.rs
@@ -119,31 +119,31 @@ impl DB {
                 continue;
             }
 
-            let mask = Polygon::default().with_points_imageproc(&contour.points);
-            let delta = mask.area() * ratio.round() as f64 * self.unclip_ratio as f64
-                / mask.perimeter();
+            let polygon = Polygon::default().with_points_imageproc(&contour.points);
+            let delta = polygon.area() * ratio.round() as f64 * self.unclip_ratio as f64
+                / polygon.perimeter();
 
             // TODO: optimize
-            let mask = mask
+            let polygon = polygon
                 .unclip(delta, image_width as f64, image_height as f64)
                 .resample(50)
                 // .simplify(6e-4)
                 .convex_hull();
 
-            if let Some(bbox) = mask.bbox() {
+            if let Some(bbox) = polygon.bbox() {
                 if bbox.height() < self.min_height || bbox.width() < self.min_width {
                     continue;
                 }
-                let confidence = mask.area() as f32 / bbox.area();
+                let confidence = polygon.area() as f32 / bbox.area();
                 if confidence < self.confs[0] {
                     continue;
                 }
                 y_bbox.push(bbox.with_confidence(confidence).with_id(0));
-                if let Some(mbr) = mask.mbr() {
+                if let Some(mbr) = polygon.mbr() {
                     y_mbrs.push(mbr.with_confidence(confidence).with_id(0));
                 }
-                y_polygons.push(mask.with_id(0));
+                y_polygons.push(polygon.with_id(0));
             } else {
                 continue;
            }
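The `mask` to `polygon` renaming above also makes the unclip step easier to read: the dilation offset is DB's `D = A * r / L` (polygon area times expansion ratio, divided by perimeter). A quick sanity check of that arithmetic, with illustrative values for the ratio terms rather than the crate's defaults:

```rust
// Sanity-check DB's unclip offset D = A * r / L for a 100x100 square.
// `ratio` stands in for the rounded rescale ratio used in db.rs; it and
// `unclip_ratio` are illustrative values here.
fn unclip_delta(area: f64, perimeter: f64, ratio: f64, unclip_ratio: f64) -> f64 {
    area * ratio * unclip_ratio / perimeter
}

fn main() {
    let (w, h) = (100.0_f64, 100.0_f64);
    let delta = unclip_delta(w * h, 2.0 * (w + h), 1.0, 1.5);
    assert!((delta - 37.5).abs() < f64::EPSILON); // expand by 37.5 px before convex_hull
    println!("unclip delta = {delta} px");
}
```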
diff --git a/src/models/mod.rs b/src/models/mod.rs
index df8c9d9..87f7822 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -8,6 +8,7 @@ mod dinov2;
 mod modnet;
 mod rtdetr;
 mod rtmo;
+mod sam;
 mod svtr;
 mod yolo;
 mod yolo_;
@@ -21,10 +22,8 @@ pub use dinov2::Dinov2;
 pub use modnet::MODNet;
 pub use rtdetr::RTDETR;
 pub use rtmo::RTMO;
+pub use sam::{SamKind, SamPrompt, SAM};
 pub use svtr::SVTR;
 pub use yolo::YOLO;
 pub use yolo_::*;
-// {
-//     AnchorsPosition, BoxType, ClssType, KptsType, YOLOFormat, YOLOPreds, YOLOTask, YOLOVersion,
-// };
 pub use yolop::YOLOPv2;
diff --git a/src/models/sam.rs b/src/models/sam.rs
new file mode 100644
index 0000000..d7fe83f
--- /dev/null
+++ b/src/models/sam.rs
@@ -0,0 +1,291 @@
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::{s, Array, Axis};
+use rand::prelude::*;
+
+use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y};
+
+#[derive(Debug, Clone, clap::ValueEnum)]
+pub enum SamKind {
+    Sam,
+    MobileSam,
+    SamHq,
+    EdgeSam,
+}
+
+#[derive(Debug, Default, Clone)]
+pub struct SamPrompt {
+    points: Vec<f32>,
+    labels: Vec<f32>,
+}
+
+impl SamPrompt {
+    pub fn everything() -> Self {
+        todo!()
+    }
+
+    pub fn with_positive_point(mut self, x: f32, y: f32) -> Self {
+        self.points.extend_from_slice(&[x, y]);
+        self.labels.push(1.);
+        self
+    }
+
+    pub fn with_negative_point(mut self, x: f32, y: f32) -> Self {
+        self.points.extend_from_slice(&[x, y]);
+        self.labels.push(0.);
+        self
+    }
+
+    pub fn with_bbox(mut self, x: f32, y: f32, x2: f32, y2: f32) -> Self {
+        self.points.extend_from_slice(&[x, y, x2, y2]);
+        self.labels.extend_from_slice(&[2., 3.]);
+        self
+    }
+
+    pub fn point_coords(&self, r: f32) -> Result<X> {
+        let point_coords = Array::from_shape_vec((1, self.num_points(), 2), self.points.clone())?
+            .into_dyn()
+            .into_owned();
+        Ok(X::from(point_coords * r))
+    }
+
+    pub fn point_labels(&self) -> Result<X> {
+        let point_labels = Array::from_shape_vec((1, self.num_points()), self.labels.clone())?
+            .into_dyn()
+            .into_owned();
+        Ok(X::from(point_labels))
+    }
+
+    pub fn num_points(&self) -> usize {
+        self.points.len() / 2
+    }
+}
+
+#[derive(Debug)]
+pub struct SAM {
+    encoder: OrtEngine,
+    decoder: OrtEngine,
+    height: MinOptMax,
+    width: MinOptMax,
+    batch: MinOptMax,
+    pub conf: DynConf,
+    find_contours: bool,
+    kind: SamKind,
+    use_low_res_mask: bool,
+}
+
+impl SAM {
+    pub fn new(options_encoder: Options, options_decoder: Options) -> Result<Self> {
+        let mut encoder = OrtEngine::new(&options_encoder)?;
+        let mut decoder = OrtEngine::new(&options_decoder)?;
+        let (batch, height, width) = (
+            encoder.inputs_minoptmax()[0][0].to_owned(),
+            encoder.inputs_minoptmax()[0][2].to_owned(),
+            encoder.inputs_minoptmax()[0][3].to_owned(),
+        );
+        let conf = DynConf::new(&options_decoder.confs, 1);
+
+        let kind = match options_decoder.sam_kind {
+            Some(x) => x,
+            None => anyhow::bail!("Error: no clear `SamKind` specified."),
+        };
+        let find_contours = options_decoder.find_contours;
+        let use_low_res_mask = match kind {
+            SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
+                options_decoder.use_low_res_mask.unwrap_or(false)
+            }
+            SamKind::EdgeSam => true,
+        };
+
+        encoder.dry_run()?;
+        decoder.dry_run()?;
+
+        Ok(Self {
+            encoder,
+            decoder,
+            batch,
+            height,
+            width,
+            conf,
+            kind,
+            find_contours,
+            use_low_res_mask,
+        })
+    }
+
+    pub fn run(&mut self, xs: &[DynamicImage], prompts: &[SamPrompt]) -> Result<Vec<Y>> {
+        let ys = self.encode(xs)?;
+        self.decode(ys, xs, prompts)
+    }
+
+    pub fn encode(&mut self, xs: &[DynamicImage]) -> Result<Vec<X>> {
+        let xs_ = X::apply(&[
+            Ops::Letterbox(
+                xs,
+                self.height() as u32,
+                self.width() as u32,
+                "Bilinear",
+                0,
+                "auto",
+                false,
+            ),
+            Ops::Standardize(&[123.675, 116.28, 103.53], &[58.395, 57.12, 57.375], 3),
+            Ops::Nhwc2nchw,
+        ])?;
+
+        self.encoder.run(vec![xs_])
+    }
+
+    pub fn decode(
+        &mut self,
+        xs: Vec<X>,
+        xs0: &[DynamicImage],
+        prompts: &[SamPrompt],
+    ) -> Result<Vec<Y>> {
+        let mut ys: Vec<Y> = Vec::new();
+
+        for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() {
+            let image_width = xs0[idx].width() as f32;
+            let image_height = xs0[idx].height() as f32;
+            let ratio =
+                (self.width() as f32 / image_width).min(self.height() as f32 / image_height);
+            let args = match self.kind {
+                SamKind::Sam | SamKind::MobileSam => {
+                    vec![
+                        X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
+                        prompts[idx].point_coords(ratio)?, // point_coords
+                        prompts[idx].point_labels()?,      // point_labels
+                        X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
+                        X::zeros(&[1]), // has_mask_input
+                        X::from(vec![image_height, image_width]), // orig_im_size
+                    ]
+                }
+                SamKind::SamHq => {
+                    vec![
+                        X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, // image_embedding
+                        X::from(xs[1].slice(s![idx, .., .., ..]).into_dyn().into_owned())
+                            .insert_axis(0)?
+                            .insert_axis(0)?, // intern_embedding
+                        prompts[idx].point_coords(ratio)?, // point_coords
+                        prompts[idx].point_labels()?,      // point_labels
+                        X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input
+                        X::zeros(&[1]), // has_mask_input
+                        X::from(vec![image_height, image_width]), // orig_im_size
+                    ]
+                }
+                SamKind::EdgeSam => {
+                    vec![
+                        X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?,
+                        prompts[idx].point_coords(ratio)?,
+                        prompts[idx].point_labels()?,
+                    ]
+                }
+            };
+
+            let ys_ = self.decoder.run(args)?;
+
+            let mut y_masks: Vec<Mask> = Vec::new();
+            let mut y_polygons: Vec<Polygon> = Vec::new();
+
+            // masks & confs
+            let (masks, confs) = match self.kind {
+                SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => {
+                    if !self.use_low_res_mask {
+                        (&ys_[0], &ys_[1])
+                    } else {
+                        (&ys_[2], &ys_[1])
+                    }
+                }
+                SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) {
+                    (2, 4) => (&ys_[1], &ys_[0]),
+                    (4, 2) => (&ys_[0], &ys_[1]),
+                    _ => anyhow::bail!("Cannot parse the outputs of the decoder."),
+                },
+            };
+
+            for (mask, iou) in masks.axis_iter(Axis(0)).zip(confs.axis_iter(Axis(0))) {
+                let (i, conf) = match iou
+                    .to_owned()
+                    .into_raw_vec()
+                    .into_iter()
+                    .enumerate()
+                    .max_by(|a, b| a.1.total_cmp(&b.1))
+                {
+                    Some((i, c)) => (i, c),
+                    None => continue,
+                };
+
+                if conf < self.conf[0] {
+                    continue;
+                }
+                let mask = mask.slice(s![i, .., ..]);
+                let (h, w) = mask.dim();
+                let luma = if self.use_low_res_mask {
+                    Ops::resize_lumaf32_vec(
+                        &mask.to_owned().into_raw_vec(),
+                        w as _,
+                        h as _,
+                        image_width as _,
+                        image_height as _,
+                        true,
+                        "Bilinear",
+                    )?
+                } else {
+                    mask.mapv(|x| if x > 0. { 255u8 } else { 0u8 })
+                        .into_raw_vec()
+                };
+
+                let luma: image::ImageBuffer<image::Luma<u8>, Vec<_>> =
+                    match image::ImageBuffer::from_raw(image_width as _, image_height as _, luma) {
+                        None => continue,
+                        Some(x) => x,
+                    };
+
+                // contours
+                let mut rng = thread_rng();
+                let id = rng.gen_range(0..20);
+                if self.find_contours {
+                    let contours: Vec<imageproc::contours::Contour<i32>> =
+                        imageproc::contours::find_contours_with_threshold(&luma, 0);
+                    for c in contours.iter() {
+                        let polygon = Polygon::default().with_points_imageproc(&c.points);
+                        y_polygons.push(polygon.with_confidence(iou[0]).with_id(id));
+                    }
+                }
+                y_masks.push(Mask::default().with_mask(luma).with_id(id));
+            }
+
+            let mut y = Y::default();
+            if !y_masks.is_empty() {
+                y = y.with_masks(&y_masks);
+            }
+            if !y_polygons.is_empty() {
+                y = y.with_polygons(&y_polygons);
+            }
+
+            ys.push(y);
+        }
+
+        Ok(ys)
+    }
+
+    pub fn width_low_res(&self) -> usize {
+        self.width() as usize / 4
+    }
+
+    pub fn height_low_res(&self) -> usize {
+        self.height() as usize / 4
+    }
+
+    pub fn batch(&self) -> isize {
+        self.batch.opt
+    }
+
+    pub fn width(&self) -> isize {
+        self.width.opt
+    }
+
+    pub fn height(&self) -> isize {
+        self.height.opt
+    }
+}
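A sketch of how the new `SAM` wrapper and `SamPrompt` fit together, in the fragment style of the crate docs above. The model paths are placeholders, and `with_sam_kind` is an assumed `Options` setter: only the `sam_kind` field it would populate appears in this diff.

```rust
use usls::{models::{SamKind, SamPrompt, SAM}, DataLoader, Options};

// Placeholder paths for the encoder/decoder ONNX files.
let options_encoder = Options::default().with_model("sam-encoder.onnx")?;
let options_decoder = Options::default()
    .with_sam_kind(SamKind::MobileSam) // assumed setter for the `sam_kind` field
    .with_model("sam-decoder.onnx")?;
let mut model = SAM::new(options_encoder, options_decoder)?;

// One prompt per image: a positive point plus a box. `point_coords()` scales
// the pixel coordinates by the letterbox ratio internally.
let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];
let prompts = vec![SamPrompt::default()
    .with_positive_point(500., 375.)     // label 1.
    .with_bbox(100., 200., 600., 700.)]; // corner labels 2. and 3.
let ys = model.run(&xs, &prompts)?;
```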
diff --git a/src/models/yolo.rs b/src/models/yolo.rs
index d44a7c0..56d5554 100644
--- a/src/models/yolo.rs
+++ b/src/models/yolo.rs
@@ -24,6 +24,7 @@ pub struct YOLO {
     names_kpt: Option<Vec<String>>,
     task: YOLOTask,
     layout: YOLOPreds,
+    find_contours: bool,
     version: Option<YOLOVersion>,
 }
 
@@ -153,6 +154,7 @@ impl Vision for YOLO {
             names_kpt,
             layout,
             version,
+            find_contours: options.find_contours,
        })
    }
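Since contour extraction is now stored as a per-model flag, mask-only segmentation runs can skip `find_contours_with_threshold` entirely (see the gating in the hunks below). Enabling it from `Options` would look roughly like this fragment; `with_find_contours` is a hypothetical setter name, as the diff only shows the `find_contours` field being read at construction time.

```rust
// Hypothetical builder name; only the `find_contours` field is visible here.
let options = Options::default()
    .with_yolo_task(YOLOTask::Segment)
    .with_find_contours(true) // also emit polygons, not just masks
    .with_model("xxxx.onnx")?;
```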
@@ -417,7 +419,6 @@ impl Vision for YOLO {
                     .into_par_iter()
                     .filter_map(|bbox| {
                         let coefs = coefs.slice(s![bbox.id_born(), ..]).to_vec();
-
                         let proto = protos.as_ref()?.slice(s![idx, .., .., ..]);
 
                         let (nm, mh, mw) = proto.dim();
@@ -461,10 +462,9 @@
                         }
 
                         // Find contours
-                        let contours: Vec<imageproc::contours::Contour<i32>> =
-                            imageproc::contours::find_contours_with_threshold(&mask, 0);
-
-                        Some((
+                        let polygons = if self.find_contours {
+                            let contours: Vec<imageproc::contours::Contour<i32>> =
+                                imageproc::contours::find_contours_with_threshold(&mask, 0);
                             contours
                                 .into_par_iter()
                                 .map(|x| {
@@ -473,7 +473,13 @@
                                         .with_points_imageproc(&x.points)
                                         .with_name(bbox.name().cloned())
                                 })
-                                .max_by(|x, y| x.area().total_cmp(&y.area()))?,
+                                .max_by(|x, y| x.area().total_cmp(&y.area()))?
+                        } else {
+                            Polygon::default()
+                        };
+
+                        Some((
+                            polygons,
                             Mask::default()
                                 .with_mask(mask)
                                 .with_id(bbox.id())
@@ -482,7 +488,12 @@
                         })
                         .collect::<(Vec<_>, Vec<_>)>();
 
-                    y = y.with_polygons(&y_polygons).with_masks(&y_masks);
+                    if !y_polygons.is_empty() {
+                        y = y.with_polygons(&y_polygons);
+                    }
+                    if !y_masks.is_empty() {
+                        y = y.with_masks(&y_masks);
+                    }
                 }
             }
diff --git a/src/ys/polygon.rs b/src/ys/polygon.rs
index 2f8fef4..f4df61d 100644
--- a/src/ys/polygon.rs
+++ b/src/ys/polygon.rs
@@ -64,6 +64,11 @@ impl Polygon {
         self
     }
 
+    pub fn with_confidence(mut self, x: f32) -> Self {
+        self.confidence = x;
+        self
+    }
+
     pub fn id(&self) -> isize {
         self.id
    }
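The new `with_confidence` setter rounds out `Polygon`'s builder chain, letting callers such as `db.rs` and `sam.rs` attach the detection score to the polygons they emit. A minimal sketch, assuming `Polygon` is reachable from the crate root like the other `ys` types:

```rust
use imageproc::point::Point;
use usls::Polygon;

// Mirrors the call sites above: build a polygon from contour-style points,
// then attach a confidence score and a class id with the new setter.
let points: Vec<Point<i32>> = vec![
    Point::new(0, 0),
    Point::new(100, 0),
    Point::new(100, 100),
    Point::new(0, 100),
];
let polygon = Polygon::default()
    .with_points_imageproc(&points)
    .with_confidence(0.83) // illustrative score
    .with_id(0);
```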