Compare commits

...

10 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| jamjamjon | 4e932c4910 | Bump the version to 0.0.20 | 2024-12-03 19:37:34 +08:00 |
| Collide | 2785b090c6 | upgrade ort to v2.0.0-rc.9 (#52) | 2024-12-03 19:16:23 +08:00 |
| Jamjamjon | 57db14ce5d | Update README.md | 2024-10-10 00:30:52 +08:00 |
| Jamjamjon | 447889028e | Add Apple ml-depth-pro model | 2024-10-10 00:26:26 +08:00 |
| Jamjamjon | 1d596383de | Add support for restricting detection classes (#45) | 2024-10-05 17:49:08 +08:00 |
| Jamjamjon | 0102c15687 | Minor fixes | 2024-10-01 09:37:46 +08:00 |
| Jamjamjon | 64dc804a13 | Update README.md | 2024-09-30 22:48:07 +08:00 |
| Jamjamjon | 0609dd1f1d | Add YOLOv11 | 2024-09-30 22:43:34 +08:00 |
| Jamjamjon | 2cb9e57fc4 | Update README.md | 2024-09-28 10:49:06 +08:00 |
| Jamjamjon | f2c4593672 | Update README.md | 2024-09-28 10:10:05 +08:00 |
11 changed files with 612 additions and 322 deletions

Cargo.toml

@@ -1,6 +1,6 @@
 [package]
 name = "usls"
-version = "0.0.16"
+version = "0.0.20"
 edition = "2021"
 description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
 repository = "https://github.com/jamjamjon/usls"
@@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
 [dependencies]
 clap = { version = "4.2.4", features = ["derive"] }
 ndarray = { version = "0.16.1", features = ["rayon"] }
-ort = { version = "2.0.0-rc.5", default-features = false}
+ort = { version = "2.0.0-rc.9", default-features = false }
 anyhow = { version = "1.0.75" }
 regex = { version = "1.5.4" }
 rand = { version = "0.8.5" }
@@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
 ab_glyph = "0.2.23"
 geo = "0.28.0"
 prost = "0.12.4"
-fast_image_resize = { version = "4.2.1", features = ["image"]}
+fast_image_resize = { version = "4.2.1", features = ["image"] }
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 tempfile = "3.12.0"
@@ -50,7 +50,6 @@ default = [
     "ort/cuda",
     "ort/tensorrt",
     "ort/coreml",
-    "ort/operator-libraries"
 ]
 auto = ["ort/download-binaries"]

README.md

@@ -3,7 +3,7 @@
 </p>
 <p align="center">
 | <a href="https://docs.rs/usls"><strong>Documentation</strong></a>
 <br>
 <br>
 <a href='https://github.com/microsoft/onnxruntime/releases'>
@@ -34,9 +34,9 @@
 **`usls`** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including:
-- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
+- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLOv11](https://github.com/ultralytics/ultralytics)
 - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
-- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
+- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro)
 - **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
 <details>
@@ -51,7 +51,8 @@
 | [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
 | [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
 | [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
 | [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
+| [YOLOv11](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
 | [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
 | [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
 | [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
@@ -67,11 +68,12 @@
 | [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ |
 | [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ |
 | [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
-| [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
+| [Depth-Anything v1 & v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
 | [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
 | [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
 | [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
 | [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | |
+| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | | |
@@ -80,7 +82,8 @@
 ## ⛳️ ONNXRuntime Linking
-You have two options to link the ONNXRuntime library
+<details>
+<summary>You have two options to link the ONNXRuntime library</summary>

 - ### Option 1: Manual Linking
@@ -99,6 +102,7 @@ You have two options to link the ONNXRuntime library
 cargo run -r --example yolo --features auto
 ```
+</details>

 ## 🎈 Demo
@@ -123,70 +127,95 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ...
 - Build model with the provided `models` and `Options`
 - Load images, video and stream with `DataLoader`
 - Do inference
-- Annotate inference results with `Annotator`
 - Retrieve inference results from `Vec<Y>`
+- Annotate inference results with `Annotator`
+- Display images and write them to video with `Viewer`
+
+<br/>
+<details>
+<summary>example code</summary>

 ```rust
 use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion};

 fn main() -> anyhow::Result<()> {
     // Build model with Options
     let options = Options::new()
         .with_trt(0)
         .with_model("yolo/v8-m-dyn.onnx")?
         .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
         .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
-        .with_i00((1, 2, 4).into())
-        .with_i02((0, 640, 640).into())
-        .with_i03((0, 640, 640).into())
+        .with_ixx(0, 0, (1, 2, 4).into())
+        .with_ixx(0, 2, (0, 640, 640).into())
+        .with_ixx(0, 3, (0, 640, 640).into())
         .with_confs(&[0.2]);
     let mut model = YOLO::new(options)?;

     // Build DataLoader to load image(s), video, stream
     let dl = DataLoader::new(
         // "./assets/bus.jpg", // local image
         // "images/bus.jpg", // remote image
         // "../images-folder", // local images (from folder)
         // "../demo.mp4", // local video
         // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video
         "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream
     )?
     .with_batch(2) // iterate with batch_size = 2
     .build()?;

     // Build annotator
     let annotator = Annotator::new()
         .with_bboxes_thickness(4)
         .with_saveout("YOLO-DataLoader");

+    // Build viewer
+    let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true);
+
     // Run and annotate results
     for (xs, _) in dl {
         let ys = model.forward(&xs, false)?;
-        annotator.annotate(&xs, &ys);
+        // annotator.annotate(&xs, &ys);
+        let images_plotted = annotator.plot(&xs, &ys, false)?;
+
+        // show image
+        viewer.imshow(&images_plotted)?;
+
+        // check out window and key event
+        if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
+            break;
+        }
+
+        // write video
+        viewer.write_batch(&images_plotted)?;
+
         // Retrieve inference results
         for y in ys {
             // bboxes
             if let Some(bboxes) = y.bboxes() {
                 for bbox in bboxes {
                     println!(
                         "Bbox: {}, {}, {}, {}, {}, {}",
                         bbox.xmin(),
                         bbox.ymin(),
                         bbox.xmax(),
                         bbox.ymax(),
                         bbox.confidence(),
                         bbox.id(),
                     );
                 }
             }
         }
     }
+
+    // finish video write
+    viewer.finish_write()?;
     Ok(())
 }
 ```
+</details>
+</br>

 ## 📌 License
 This project is licensed under [LICENSE](LICENSE).

examples/depth-pro/main.rs (new file)

@@ -0,0 +1,26 @@
+use usls::{models::DepthPro, Annotator, DataLoader, Options};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // options
+    let options = Options::default()
+        .with_model("depth-pro/q4f16.onnx")? // bnb4, f16
+        .with_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
+        .with_ixx(0, 1, 3.into()) // channel
+        .with_ixx(0, 2, 1536.into()) // height
+        .with_ixx(0, 3, 1536.into()); // width
+    let mut model = DepthPro::new(options)?;
+
+    // load
+    let x = [DataLoader::try_read("images/street.jpg")?];
+
+    // run
+    let y = model.run(&x)?;
+
+    // annotate
+    let annotator = Annotator::default()
+        .with_colormap("Turbo")
+        .with_saveout("Depth-Pro");
+    annotator.annotate(&x, &y);
+
+    Ok(())
+}
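
Assuming the repository's usual Cargo example layout (the README table above links this demo as examples/depth-pro), the new example would be launched with `cargo run -r --example depth-pro`.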

examples/yolo/README.md

@@ -24,45 +24,41 @@
 ## Quick Start

 ```Shell
+# customized
+cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8
+
 # Classify
-cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5
-cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8
+cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
+cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
+cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11

 # Detect
-cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5
-cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6
-cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7
-cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8
-cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9
-cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10
-cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR
-cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
+cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
+cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
+cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
+cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
+cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
+cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
+cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
+cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
+cargo run -r --example yolo -- --task detect --ver v8 --model yolo/v8-s-world-v2-shoes.onnx # YOLOv8-world

 # Pose
-cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose
+cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
+cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose

 # Segment
-cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment
-cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment
-cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
+cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
+cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
+cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv8-Segment
+cargo run -r --example yolo -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM

 # Obb
-cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
+cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
+cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
 ```

-<details close>
-<summary>other options</summary>
-
-`--source` to specify the input images
-`--model` to specify the ONNX model
-`--width --height` to specify the input resolution
-`--nc` to specify the number of model's classes
-`--plot` to annotate with inference results
-`--profile` to profile
-`--cuda --trt --coreml --device_id` to select device
-`--half` to use float16 when using TensorRT EP
-</details>
+**`cargo run -r --example yolo -- --help` for more options**

 ## YOLOs configs with `Options`

@@ -96,6 +92,8 @@ let options = Options::default()
             ..Default::default()
         }
     )
+    // .with_nc(80)
+    // .with_names(&COCO_CLASS_NAMES_80)
     .with_model("xxxx.onnx")?;
 ```
 </details>
@@ -140,7 +138,7 @@ let options = Options::default()
 </details>

 <details close>
-<summary>YOLOv8</summary>
+<summary>YOLOv8, YOLOv11</summary>

 ```Shell
 pip install -U ultralytics

examples/yolo/main.rs

@@ -2,188 +2,169 @@ use anyhow::Result;
 use clap::Parser;
 use usls::{
-    models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
-    COCO_SKELETONS_16,
+    models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
+    YOLOVersion, COCO_SKELETONS_16,
 };

 #[derive(Parser, Clone)]
 #[command(author, version, about, long_about = None)]
 pub struct Args {
+    /// Path to the model
     #[arg(long)]
     pub model: Option<String>,

+    /// Input source path
     #[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
     pub source: String,

+    /// YOLO Task
     #[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
     pub task: YOLOTask,

+    /// YOLO Version
     #[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
     pub ver: YOLOVersion,

+    /// YOLO Scale
+    #[arg(long, value_enum, default_value_t = YOLOScale::N)]
+    pub scale: YOLOScale,
+
+    /// Batch size
     #[arg(long, default_value_t = 1)]
     pub batch_size: usize,

+    /// Minimum input width
     #[arg(long, default_value_t = 224)]
     pub width_min: isize,

+    /// Input width
     #[arg(long, default_value_t = 640)]
     pub width: isize,

-    #[arg(long, default_value_t = 800)]
+    /// Maximum input width
+    #[arg(long, default_value_t = 1024)]
     pub width_max: isize,

+    /// Minimum input height
     #[arg(long, default_value_t = 224)]
     pub height_min: isize,

+    /// Input height
     #[arg(long, default_value_t = 640)]
     pub height: isize,

-    #[arg(long, default_value_t = 800)]
+    /// Maximum input height
+    #[arg(long, default_value_t = 1024)]
     pub height_max: isize,

+    /// Number of classes
     #[arg(long, default_value_t = 80)]
     pub nc: usize,

+    /// Class confidence
+    #[arg(long)]
+    pub confs: Vec<f32>,
+
+    /// Enable TensorRT support
     #[arg(long)]
     pub trt: bool,

+    /// Enable CUDA support
     #[arg(long)]
     pub cuda: bool,

-    #[arg(long)]
-    pub half: bool,
-
+    /// Enable CoreML support
     #[arg(long)]
     pub coreml: bool,

+    /// Use TensorRT half precision
+    #[arg(long)]
+    pub half: bool,
+
+    /// Device ID to use
     #[arg(long, default_value_t = 0)]
     pub device_id: usize,

+    /// Enable performance profiling
     #[arg(long)]
     pub profile: bool,

-    #[arg(long)]
-    pub no_plot: bool,
-
+    /// Disable contour drawing
     #[arg(long)]
     pub no_contours: bool,
+
+    /// Show result
+    #[arg(long)]
+    pub view: bool,
+
+    /// Do not save output
+    #[arg(long)]
+    pub nosave: bool,
 }

 fn main() -> Result<()> {
     let args = Args::parse();

-    // build options
-    let options = Options::default();
-
-    // version & task
-    let (options, saveout) = match args.ver {
-        YOLOVersion::V5 => match args.task {
-            YOLOTask::Classify => (
-                options.with_model(&args.model.unwrap_or("yolo/v5-n-cls-dyn.onnx".to_string()))?,
-                "YOLOv5-Classify",
-            ),
-            YOLOTask::Detect => (
-                options.with_model(&args.model.unwrap_or("yolo/v5-n-dyn.onnx".to_string()))?,
-                "YOLOv5-Detect",
-            ),
-            YOLOTask::Segment => (
-                options.with_model(&args.model.unwrap_or("yolo/v5-n-seg-dyn.onnx".to_string()))?,
-                "YOLOv5-Segment",
-            ),
-            t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
-        },
-        YOLOVersion::V6 => match args.task {
-            YOLOTask::Detect => (
-                options
-                    .with_model(&args.model.unwrap_or("yolo/v6-n-dyn.onnx".to_string()))?
-                    .with_nc(args.nc),
-                "YOLOv6-Detect",
-            ),
-            t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
-        },
-        YOLOVersion::V7 => match args.task {
-            YOLOTask::Detect => (
-                options
-                    .with_model(&args.model.unwrap_or("yolo/v7-tiny-dyn.onnx".to_string()))?
-                    .with_nc(args.nc),
-                "YOLOv7-Detect",
-            ),
-            t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
-        },
-        YOLOVersion::V8 => match args.task {
-            YOLOTask::Classify => (
-                options.with_model(&args.model.unwrap_or("yolo/v8-m-cls-dyn.onnx".to_string()))?,
-                "YOLOv8-Classify",
-            ),
-            YOLOTask::Detect => (
-                options.with_model(&args.model.unwrap_or("yolo/v8-m-dyn.onnx".to_string()))?,
-                "YOLOv8-Detect",
-            ),
-            YOLOTask::Segment => (
-                options.with_model(&args.model.unwrap_or("yolo/v8-m-seg-dyn.onnx".to_string()))?,
-                "YOLOv8-Segment",
-            ),
-            YOLOTask::Pose => (
-                options.with_model(&args.model.unwrap_or("yolo/v8-m-pose-dyn.onnx".to_string()))?,
-                "YOLOv8-Pose",
-            ),
-            YOLOTask::Obb => (
-                options.with_model(&args.model.unwrap_or("yolo/v8-m-obb-dyn.onnx".to_string()))?,
-                "YOLOv8-Obb",
-            ),
-        },
-        YOLOVersion::V9 => match args.task {
-            YOLOTask::Detect => (
-                options.with_model(&args.model.unwrap_or("yolo/v9-c-dyn-f16.onnx".to_string()))?,
-                "YOLOv9-Detect",
-            ),
-            t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
-        },
-        YOLOVersion::V10 => match args.task {
-            YOLOTask::Detect => (
-                options.with_model(&args.model.unwrap_or("yolo/v10-n.onnx".to_string()))?,
-                "YOLOv10-Detect",
-            ),
-            t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
-        },
-        YOLOVersion::RTDETR => match args.task {
-            YOLOTask::Detect => (
-                options.with_model(&args.model.unwrap_or("yolo/rtdetr-l-f16.onnx".to_string()))?,
-                "RTDETR",
-            ),
-            t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
-        },
-    };
-
-    let options = options
-        .with_yolo_version(args.ver)
-        .with_yolo_task(args.task);
+    // model path
+    let path = match &args.model {
+        None => format!(
+            "yolo/{}-{}-{}.onnx",
+            args.ver.name(),
+            args.scale.name(),
+            args.task.name()
+        ),
+        Some(x) => x.to_string(),
+    };
+
+    // saveout
+    let saveout = match &args.model {
+        None => format!(
+            "{}-{}-{}",
+            args.ver.name(),
+            args.scale.name(),
+            args.task.name()
+        ),
+        Some(x) => {
+            let p = std::path::PathBuf::from(&x);
+            p.file_stem().unwrap().to_str().unwrap().to_string()
+        }
+    };

     // device
-    let options = if args.cuda {
-        options.with_cuda(args.device_id)
+    let device = if args.cuda {
+        Device::Cuda(args.device_id)
     } else if args.trt {
-        let options = options.with_trt(args.device_id);
-        if args.half {
-            options.with_trt_fp16(true)
-        } else {
-            options
-        }
+        Device::Trt(args.device_id)
     } else if args.coreml {
-        options.with_coreml(args.device_id)
+        Device::CoreML(args.device_id)
     } else {
-        options.with_cpu()
+        Device::Cpu(args.device_id)
     };

-    let options = options
+    // build options
+    let options = Options::new()
+        .with_model(&path)?
+        .with_yolo_version(args.ver)
+        .with_yolo_task(args.task)
+        .with_device(device)
+        .with_trt_fp16(args.half)
         .with_ixx(0, 0, (1, args.batch_size as _, 4).into())
         .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
         .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
-        .with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
+        .with_confs(if args.confs.is_empty() {
+            &[0.2, 0.15]
+        } else {
+            &args.confs
+        })
+        .with_nc(args.nc)
         // .with_names(&COCO_CLASS_NAMES_80)
-        .with_names2(&COCO_KEYPOINTS_17)
+        // .with_names2(&COCO_KEYPOINTS_17)
         .with_find_contours(!args.no_contours) // find contours or not
+        .exclude_classes(&[0])
+        // .retain_classes(&[0, 5])
         .with_profile(args.profile);
+
+    // build model
     let mut model = YOLO::new(options)?;

     // build dataloader
@@ -194,16 +175,54 @@ fn main() -> Result<()> {
     // build annotator
     let annotator = Annotator::default()
         .with_skeletons(&COCO_SKELETONS_16)
-        .with_bboxes_thickness(4)
         .without_masks(true) // No masks plotting when doing segment task.
-        .with_saveout(saveout);
+        .with_bboxes_thickness(3)
+        .with_keypoints_name(false) // Enable keypoints names
+        .with_saveout_subs(&["YOLO"])
+        .with_saveout(&saveout);
+
+    // build viewer
+    let mut viewer = if args.view {
+        Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
+    } else {
+        None
+    };

     // run & annotate
     for (xs, _paths) in dl {
         // let ys = model.run(&xs)?; // way one
         let ys = model.forward(&xs, args.profile)?; // way two
-        if !args.no_plot {
-            annotator.annotate(&xs, &ys);
-        }
+        let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
+
+        // show image
+        match &mut viewer {
+            Some(viewer) => viewer.imshow(&images_plotted)?,
+            None => continue,
+        }
+
+        // check out window and key event
+        match &mut viewer {
+            Some(viewer) => {
+                if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
+                    break;
+                }
+            }
+            None => continue,
+        }
+
+        // write video
+        if !args.nosave {
+            match &mut viewer {
+                Some(viewer) => viewer.write_batch(&images_plotted)?,
+                None => continue,
+            }
+        }
     }
+
+    // finish video write
+    if !args.nosave {
+        if let Some(viewer) = &mut viewer {
+            viewer.finish_write()?;
+        }
+    }
 }
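
Given the reworked flags above, a representative invocation would be `cargo run -r --example yolo -- --task detect --ver v11 --scale n --view`: `--view` opens the optional `Viewer` window, and `--nosave` skips annotated output and video writing (flag spellings here follow the clap `Args` struct in this diff).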

@@ -48,6 +48,8 @@ pub struct Options {
     pub sam_kind: Option<SamKind>,
     pub use_low_res_mask: Option<bool>,
     pub sapiens_task: Option<SapiensTask>,
+    pub classes_excluded: Vec<isize>,
+    pub classes_retained: Vec<isize>,
 }

 impl Default for Options {
@@ -88,6 +90,8 @@ impl Default for Options {
             use_low_res_mask: None,
             sapiens_task: None,
             task: Task::Untitled,
+            classes_excluded: vec![],
+            classes_retained: vec![],
         }
     }
 }
@@ -276,4 +280,16 @@ impl Options {
         self.iiixs.push(Iiix::from((i, ii, x)));
         self
     }
+
+    pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
+        self.classes_retained.clear();
+        self.classes_excluded.extend_from_slice(xs);
+        self
+    }
+
+    pub fn retain_classes(mut self, xs: &[isize]) -> Self {
+        self.classes_excluded.clear();
+        self.classes_retained.extend_from_slice(xs);
+        self
+    }
 }
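
The two new builder methods are deliberately mutually exclusive: setting one clears the other, so the last call wins. A minimal sketch of how they might be combined with the rest of the `Options` builder, mirroring the commented-out `retain_classes(&[0, 5])` call in examples/yolo/main.rs (the model path below is a placeholder):

```rust
use usls::{models::YOLO, Options, Vision};

fn main() -> anyhow::Result<()> {
    // Keep only classes 0 and 5; a later `exclude_classes` call
    // would clear this retained set, and vice versa.
    let options = Options::default()
        .with_model("yolo/v8-m-dyn.onnx")? // placeholder model path
        .retain_classes(&[0, 5])
        .with_confs(&[0.2]);
    let _model = YOLO::new(options)?;
    Ok(())
}
```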

@@ -2,7 +2,9 @@ use anyhow::Result;
 use half::f16;
 use ndarray::{Array, IxDyn};
 use ort::{
-    ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
+    execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
+    session::{builder::SessionBuilder, Session},
+    tensor::TensorElementType,
 };
 use prost::Message;
 use std::collections::HashSet;
@@ -88,14 +90,14 @@ impl OrtEngine {
         // build
         ort::init().commit()?;
-        let builder = Session::builder()?;
+        let mut builder = Session::builder()?;
         let mut device = config.device.to_owned();
         match device {
             Device::Trt(device_id) => {
                 Self::build_trt(
                     &inputs_attrs.names,
                     &inputs_minoptmax,
-                    &builder,
+                    &mut builder,
                     device_id,
                     config.trt_int8_enable,
                     config.trt_fp16_enable,
@@ -103,23 +105,23 @@ impl OrtEngine {
                 )?;
             }
             Device::Cuda(device_id) => {
-                Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
+                Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
                     tracing::warn!("{err}, Using cpu");
                     device = Device::Cpu(0);
                 })
             }
-            Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
+            Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
                 tracing::warn!("{err}, Using cpu");
                 device = Device::Cpu(0);
             }),
             Device::Cpu(_) => {
-                Self::build_cpu(&builder)?;
+                Self::build_cpu(&mut builder)?;
             }
             _ => todo!(),
         }
         let session = builder
-            .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
+            .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
             .commit_from_file(&config.onnx_path)?;

         // summary
@@ -149,7 +151,7 @@ impl OrtEngine {
     fn build_trt(
         names: &[String],
         inputs_minoptmax: &[Vec<MinOptMax>],
-        builder: &SessionBuilder,
+        builder: &mut SessionBuilder,
         device_id: usize,
         int8_enable: bool,
         fp16_enable: bool,
@@ -205,8 +207,9 @@ impl OrtEngine {
         }
     }

-    fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
-        let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
+    fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
+        let ep = ort::execution_providers::CUDAExecutionProvider::default()
+            .with_device_id(device_id as i32);
         if ep.is_available()? && ep.register(builder).is_ok() {
             Ok(())
         } else {
@@ -214,8 +217,8 @@ impl OrtEngine {
         }
     }

-    fn build_coreml(builder: &SessionBuilder) -> Result<()> {
-        let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
+    fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
+        let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
         if ep.is_available()? && ep.register(builder).is_ok() {
             Ok(())
         } else {
@@ -223,8 +226,8 @@ impl OrtEngine {
         }
     }

-    fn build_cpu(builder: &SessionBuilder) -> Result<()> {
-        let ep = ort::CPUExecutionProvider::default();
+    fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
+        let ep = ort::execution_providers::CPUExecutionProvider::default();
         if ep.is_available()? && ep.register(builder).is_ok() {
             Ok(())
         } else {
@@ -292,28 +295,28 @@ impl OrtEngine {
         let t_pre = std::time::Instant::now();
         for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
             let x_ = match &idtype {
-                TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
+                TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
                 TensorElementType::Float16 => {
-                    ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
+                    ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
                 }
                 TensorElementType::Int32 => {
-                    ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
+                    ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
                 }
                 TensorElementType::Int64 => {
-                    ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
+                    ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
                 }
                 TensorElementType::Uint8 => {
-                    ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
+                    ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
                 }
                 TensorElementType::Int8 => {
-                    ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
+                    ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
                 }
                 TensorElementType::Bool => {
-                    ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
+                    ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
                 }
                 _ => todo!(),
             };
-            xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
+            xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
         }
         let t_pre = t_pre.elapsed();
         self.ts.add_or_push(0, t_pre);
@@ -451,45 +454,45 @@ impl OrtEngine {
     }

     #[allow(dead_code)]
-    fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
+    fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
         match x {
-            ort::TensorElementType::Float64
-            | ort::TensorElementType::Uint64
-            | ort::TensorElementType::Int64 => 8, // i64, f64, u64
-            ort::TensorElementType::Float32
-            | ort::TensorElementType::Uint32
-            | ort::TensorElementType::Int32
-            | ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
-            ort::TensorElementType::Float16
-            | ort::TensorElementType::Bfloat16
-            | ort::TensorElementType::Int16
-            | ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
-            ort::TensorElementType::Uint8
-            | ort::TensorElementType::Int8
-            | ort::TensorElementType::Bool => 1, // u8, i8, bool
+            ort::tensor::TensorElementType::Float64
+            | ort::tensor::TensorElementType::Uint64
+            | ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
+            ort::tensor::TensorElementType::Float32
+            | ort::tensor::TensorElementType::Uint32
+            | ort::tensor::TensorElementType::Int32
+            | ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
+            ort::tensor::TensorElementType::Float16
+            | ort::tensor::TensorElementType::Bfloat16
+            | ort::tensor::TensorElementType::Int16
+            | ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
+            ort::tensor::TensorElementType::Uint8
+            | ort::tensor::TensorElementType::Int8
+            | ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
         }
     }

     #[allow(dead_code)]
-    fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
+    fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
         match value {
             0 => None,
-            1 => Some(ort::TensorElementType::Float32),
-            2 => Some(ort::TensorElementType::Uint8),
-            3 => Some(ort::TensorElementType::Int8),
-            4 => Some(ort::TensorElementType::Uint16),
-            5 => Some(ort::TensorElementType::Int16),
-            6 => Some(ort::TensorElementType::Int32),
-            7 => Some(ort::TensorElementType::Int64),
-            8 => Some(ort::TensorElementType::String),
-            9 => Some(ort::TensorElementType::Bool),
-            10 => Some(ort::TensorElementType::Float16),
-            11 => Some(ort::TensorElementType::Float64),
-            12 => Some(ort::TensorElementType::Uint32),
-            13 => Some(ort::TensorElementType::Uint64),
+            1 => Some(ort::tensor::TensorElementType::Float32),
+            2 => Some(ort::tensor::TensorElementType::Uint8),
+            3 => Some(ort::tensor::TensorElementType::Int8),
+            4 => Some(ort::tensor::TensorElementType::Uint16),
+            5 => Some(ort::tensor::TensorElementType::Int16),
+            6 => Some(ort::tensor::TensorElementType::Int32),
+            7 => Some(ort::tensor::TensorElementType::Int64),
+            8 => Some(ort::tensor::TensorElementType::String),
+            9 => Some(ort::tensor::TensorElementType::Bool),
+            10 => Some(ort::tensor::TensorElementType::Float16),
+            11 => Some(ort::tensor::TensorElementType::Float64),
+            12 => Some(ort::tensor::TensorElementType::Uint32),
+            13 => Some(ort::tensor::TensorElementType::Uint64),
             14 => None, // COMPLEX64
             15 => None, // COMPLEX128
-            16 => Some(ort::TensorElementType::Bfloat16),
+            16 => Some(ort::tensor::TensorElementType::Bfloat16),
             _ => None,
         }
     }
@@ -499,7 +502,7 @@ impl OrtEngine {
         value_info: &[onnx::ValueInfoProto],
     ) -> Result<OrtTensorAttr> {
         let mut dimss: Vec<Vec<usize>> = Vec::new();
-        let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
+        let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
         let mut names: Vec<String> = Vec::new();
         for v in value_info.iter() {
             if initializer_names.contains(v.name.as_str()) {
@@ -569,7 +572,7 @@ impl OrtEngine {
         &self.outputs_attrs.names
     }

-    pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
+    pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
         &self.outputs_attrs.dtypes
     }
@@ -585,7 +588,7 @@ impl OrtEngine {
         &self.inputs_attrs.names
     }

-    pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
+    pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
         &self.inputs_attrs.dtypes
     }
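
Most of this file's churn is mechanical fallout from the ort upgrade: rc.9 moved types from the crate root into `execution_providers`, `session`, `tensor`, and `value` submodules, and execution providers now register against a `&mut SessionBuilder`. A condensed sketch of the same migration pattern, using only APIs visible in this diff:

```rust
use anyhow::Result;
// ort 2.0.0-rc.5 (old) imported these from the crate root;
// rc.9 (new) re-homes them into submodules.
use ort::{
    execution_providers::{CUDAExecutionProvider, ExecutionProvider},
    session::{builder::GraphOptimizationLevel, Session},
};

fn build_session(onnx_path: &str, device_id: i32) -> Result<Session> {
    ort::init().commit()?;
    // rc.9 requires a mutable builder: `register` now takes `&mut SessionBuilder`.
    let mut builder = Session::builder()?;
    let ep = CUDAExecutionProvider::default().with_device_id(device_id);
    if !(ep.is_available()? && ep.register(&mut builder).is_ok()) {
        eprintln!("CUDA EP unavailable; falling back to the default CPU provider");
    }
    Ok(builder
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .commit_from_file(onnx_path)?)
}
```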

src/models/depth_pro.rs (new file)

@@ -0,0 +1,86 @@
+use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::Axis;
+
+#[derive(Debug)]
+pub struct DepthPro {
+    engine: OrtEngine,
+    height: MinOptMax,
+    width: MinOptMax,
+    batch: MinOptMax,
+}
+
+impl DepthPro {
+    pub fn new(options: Options) -> Result<Self> {
+        let mut engine = OrtEngine::new(&options)?;
+        let (batch, height, width) = (
+            engine.batch().clone(),
+            engine.height().clone(),
+            engine.width().clone(),
+        );
+        engine.dry_run()?;
+
+        Ok(Self {
+            engine,
+            height,
+            width,
+            batch,
+        })
+    }
+
+    pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
+        let xs_ = X::apply(&[
+            Ops::Resize(
+                xs,
+                self.height.opt() as u32,
+                self.width.opt() as u32,
+                "Bilinear",
+            ),
+            Ops::Normalize(0., 255.),
+            Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3),
+            Ops::Nhwc2nchw,
+        ])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
+        self.postprocess(ys, xs)
+    }
+
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+        let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
+        let predicted_depth = predicted_depth.mapv(|x| 1. / x);
+
+        let mut ys: Vec<Y> = Vec::new();
+        for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
+            let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
+            let v = luma.into_owned().into_raw_vec_and_offset().0;
+            let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
+            let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
+            let v = v
+                .iter()
+                .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
+                .collect::<Vec<_>>();
+
+            let luma = Ops::resize_luma8_u8(
+                &v,
+                self.width.opt() as _,
+                self.height.opt() as _,
+                w1 as _,
+                h1 as _,
+                false,
+                "Bilinear",
+            )?;
+            let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
+                match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
+                    None => continue,
+                    Some(x) => x,
+                };
+            ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
+        }
+
+        Ok(ys)
+    }
+
+    pub fn batch(&self) -> isize {
+        self.batch.opt() as _
+    }
+}
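
One detail worth noting in `postprocess`: the model's raw output is metric depth, which is inverted (`1. / x`) and then min-max normalized into an 8-bit gray image, so nearer pixels render brighter. A toy check of that mapping (the depth values below are illustrative only):

```rust
fn main() {
    // Inverse depth: nearer points (smaller depth) get larger values.
    let depth = [1.0f32, 2.0, 4.0]; // meters
    let inv: Vec<f32> = depth.iter().map(|d| 1.0 / d).collect();
    let (min, max) = (0.25f32, 1.0f32); // min/max of `inv`
    let gray: Vec<u8> = inv
        .iter()
        .map(|x| (((x - min) / (max - min)) * 255.0).clamp(0.0, 255.0) as u8)
        .collect();
    assert_eq!(gray, vec![255, 85, 0]); // 1 m -> 255 (bright), 4 m -> 0 (dark)
}
```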

src/models/mod.rs

@@ -4,6 +4,7 @@ mod blip;
 mod clip;
 mod db;
 mod depth_anything;
+mod depth_pro;
 mod dinov2;
 mod florence2;
 mod grounding_dino;
@@ -20,6 +21,7 @@ pub use blip::Blip;
 pub use clip::Clip;
 pub use db::DB;
 pub use depth_anything::DepthAnything;
+pub use depth_pro::DepthPro;
 pub use dinov2::Dinov2;
 pub use florence2::Florence2;
 pub use grounding_dino::GroundingDINO;

src/models/yolo.rs

@@ -20,12 +20,14 @@ pub struct YOLO {
     confs: DynConf,
     kconfs: DynConf,
     iou: f32,
-    names: Option<Vec<String>>,
-    names_kpt: Option<Vec<String>>,
+    names: Vec<String>,
+    names_kpt: Vec<String>,
     task: YOLOTask,
     layout: YOLOPreds,
     find_contours: bool,
     version: Option<YOLOVersion>,
+    classes_excluded: Vec<isize>,
+    classes_retained: Vec<isize>,
 }

 impl Vision for YOLO {
@@ -64,27 +66,26 @@ impl Vision for YOLO {
             Some(task) => match task {
                 YOLOTask::Classify => match ver {
                     YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
-                    YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()),
+                    YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()),
                     x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
                 }
                 YOLOTask::Detect => match ver {
-                    YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()),
-                    YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
-                    YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
-                    YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
-                    YOLOVersion::RTDETR => (Some(ver),YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
+                    YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()),
+                    YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()),
+                    YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
+                    YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
                 }
                 YOLOTask::Pose => match ver {
-                    YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()),
+                    YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()),
                     x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
                 }
                 YOLOTask::Segment => match ver {
                     YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
-                    YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
+                    YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
                     x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
                 }
                 YOLOTask::Obb => match ver {
-                    YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
+                    YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
                     x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
                 }
             }
@@ -97,49 +98,75 @@ impl Vision for YOLO {
         let task = task.unwrap_or(layout.task());

-        // The number of classes & Class names
-        let mut names = options.names.or(Self::fetch_names(&engine));
-        let nc = match options.nc {
-            Some(nc) => {
-                match &names {
-                    None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()),
-                    Some(names) => {
-                        assert_eq!(
-                            nc,
-                            names.len(),
-                            "The length of `nc` and `class names` is not equal."
-                        );
-                    }
-                }
-                nc
-            }
-            None => match &names {
-                Some(names) => names.len(),
-                None => panic!(
-                    "Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`"
-                ),
-            },
-        };
-
-        // Keypoints names
-        let names_kpt = options.names2;
-
-        // The number of keypoints
-        let nk = engine
-            .try_fetch("kpt_shape")
-            .map(|kpt_string| {
-                let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
-                let caps = re.captures(&kpt_string).unwrap();
-                caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
-            })
-            .unwrap_or(0_usize);
+        // Class names: user-defined.or(parsed)
+        let names_parsed = Self::fetch_names(&engine);
+        let names = match names_parsed {
+            Some(names_parsed) => match options.names {
+                Some(names) => {
+                    if names.len() == names_parsed.len() {
+                        Some(names)
+                    } else {
+                        anyhow::bail!(
+                            "The lengths of parsed class names: {} and user-defined class names: {} do not match.",
+                            names_parsed.len(),
+                            names.len(),
+                        );
+                    }
+                }
+                None => Some(names_parsed),
+            },
+            None => options.names,
+        };
+
+        // nc: names.len().or(options.nc)
+        let nc = match &names {
+            Some(names) => names.len(),
+            None => match options.nc {
+                Some(nc) => nc,
+                None => anyhow::bail!(
+                    "Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
+                ),
+            },
+        };
+
+        // Class names
+        let names = match names {
+            None => Self::n2s(nc),
+            Some(names) => names,
+        };
+
+        // Keypoint names & nk
+        let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
+            None => (0, vec![]),
+            Some(nk) => match options.names2 {
+                Some(names) => {
+                    if names.len() == nk {
+                        (nk, names)
+                    } else {
+                        anyhow::bail!(
+                            "The lengths of user-defined keypoint names: {} and nk: {} do not match.",
+                            names.len(),
+                            nk,
+                        );
+                    }
+                }
+                None => (nk, Self::n2s(nk)),
+            },
+        };

+        // Confs & Iou
         let confs = DynConf::new(&options.confs, nc);
         let kconfs = DynConf::new(&options.kconfs, nk);
         let iou = options.iou.unwrap_or(0.45);

+        // Classes excluded and retained
+        let classes_excluded = options.classes_excluded;
+        let classes_retained = options.classes_retained;
+
         // Summary
         tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);

+        // dry run
         engine.dry_run()?;

         Ok(Self {
@@ -158,6 +185,8 @@ impl Vision for YOLO {
             layout,
             version,
             find_contours: options.find_contours,
+            classes_excluded,
+            classes_retained,
         })
     }

@@ -219,10 +248,8 @@ impl Vision for YOLO {
                         slice_clss.into_owned()
                     };
                     let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
-                    if let Some(names) = &self.names {
-                        probs =
-                            probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
-                    }
+                    probs = probs
+                        .with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());

                     return Some(y.with_probs(&probs));
                 }
@@ -257,7 +284,19 @@ impl Vision for YOLO {
                         }
                     };

-                    // filtering
+                    // filtering by class id
+                    if !self.classes_excluded.is_empty()
+                        && self.classes_excluded.contains(&(class_id as isize))
+                    {
+                        return None;
+                    }
+                    if !self.classes_retained.is_empty()
+                        && !self.classes_retained.contains(&(class_id as isize))
+                    {
+                        return None;
+                    }
+
+                    // filtering by conf
                     if confidence < self.confs[class_id] {
                         return None;
                     }
@@ -325,9 +364,7 @@ impl Vision for YOLO {
                             )
                             .with_confidence(confidence)
                             .with_id(class_id as isize);
-                            if let Some(names) = &self.names {
-                                mbr = mbr.with_name(&names[class_id]);
-                            }
+                            mbr = mbr.with_name(&self.names[class_id]);

                             (None, Some(mbr))
                         }
@@ -337,9 +374,7 @@ impl Vision for YOLO {
                                 .with_confidence(confidence)
                                 .with_id(class_id as isize)
                                 .with_id_born(i as isize);
-                            if let Some(names) = &self.names {
-                                bbox = bbox.with_name(&names[class_id]);
-                            }
+                            bbox = bbox.with_name(&self.names[class_id]);

                             (Some(bbox), None)
                         }
@@ -394,9 +429,7 @@ impl Vision for YOLO {
                                             ky.max(0.0f32).min(image_height),
                                         );
-                                        if let Some(names) = &self.names_kpt {
-                                            kpt = kpt.with_name(&names[i]);
-                                        }
+                                        kpt = kpt.with_name(&self.names_kpt[i]);
                                         kpt
                                     }
                                 })
@@ -505,16 +538,16 @@ impl Vision for YOLO {
 }

 impl YOLO {
-    pub fn batch(&self) -> isize {
-        self.batch.opt() as _
+    pub fn batch(&self) -> usize {
+        self.batch.opt()
     }

-    pub fn width(&self) -> isize {
-        self.width.opt() as _
+    pub fn width(&self) -> usize {
+        self.width.opt()
     }

-    pub fn height(&self) -> isize {
-        self.height.opt() as _
+    pub fn height(&self) -> usize {
+        self.height.opt()
     }

     pub fn version(&self) -> Option<&YOLOVersion> {
@@ -541,4 +574,16 @@ impl YOLO {
             names_
         })
     }
+
+    fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
+        engine.try_fetch("kpt_shape").map(|s| {
+            let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
+            let caps = re.captures(&s).unwrap();
+            caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
+        })
+    }
+
+    fn n2s(n: usize) -> Vec<String> {
+        (0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>()
+    }
 }

@@ -9,6 +9,28 @@ pub enum YOLOTask {
     Obb,
 }

+impl YOLOTask {
+    pub fn name(&self) -> String {
+        match self {
+            Self::Classify => "cls".to_string(),
+            Self::Detect => "det".to_string(),
+            Self::Pose => "pose".to_string(),
+            Self::Segment => "seg".to_string(),
+            Self::Obb => "obb".to_string(),
+        }
+    }
+
+    pub fn name_detailed(&self) -> String {
+        match self {
+            Self::Classify => "image-classification".to_string(),
+            Self::Detect => "object-detection".to_string(),
+            Self::Pose => "pose-estimation".to_string(),
+            Self::Segment => "instance-segment".to_string(),
+            Self::Obb => "oriented-object-detection".to_string(),
+        }
+    }
+}
+
 #[derive(Debug, Copy, Clone, clap::ValueEnum)]
 pub enum YOLOVersion {
     V5,
@@ -17,9 +39,54 @@ pub enum YOLOVersion {
     V8,
     V9,
     V10,
+    V11,
     RTDETR,
 }

+impl YOLOVersion {
+    pub fn name(&self) -> String {
+        match self {
+            Self::V5 => "v5".to_string(),
+            Self::V6 => "v6".to_string(),
+            Self::V7 => "v7".to_string(),
+            Self::V8 => "v8".to_string(),
+            Self::V9 => "v9".to_string(),
+            Self::V10 => "v10".to_string(),
+            Self::V11 => "v11".to_string(),
+            Self::RTDETR => "rtdetr".to_string(),
+        }
+    }
+}
+
+#[derive(Debug, Copy, Clone, clap::ValueEnum)]
+pub enum YOLOScale {
+    N,
+    T,
+    B,
+    S,
+    M,
+    L,
+    C,
+    E,
+    X,
+}
+
+impl YOLOScale {
+    pub fn name(&self) -> String {
+        match self {
+            Self::N => "n".to_string(),
+            Self::T => "t".to_string(),
+            Self::S => "s".to_string(),
+            Self::B => "b".to_string(),
+            Self::M => "m".to_string(),
+            Self::L => "l".to_string(),
+            Self::C => "c".to_string(),
+            Self::E => "e".to_string(),
+            Self::X => "x".to_string(),
+        }
+    }
+}
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum BoxType {
     /// 1
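
These `name()` helpers are what the reworked yolo example uses to derive default model paths and save names. A small sketch of the resulting naming scheme, mirroring the `format!` calls in examples/yolo/main.rs above:

```rust
use usls::{YOLOScale, YOLOTask, YOLOVersion};

fn main() {
    let (ver, scale, task) = (YOLOVersion::V11, YOLOScale::N, YOLOTask::Detect);
    // Default model path, as composed in the yolo example:
    let path = format!("yolo/{}-{}-{}.onnx", ver.name(), scale.name(), task.name());
    assert_eq!(path, "yolo/v11-n-det.onnx");
    assert_eq!(task.name_detailed(), "object-detection");
}
```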