Compare commits
10 Commits
6ace97f09f
...
4e932c4910
| Author | SHA1 | Date |
|---|---|---|
|
|
4e932c4910 | |
|
|
2785b090c6 | |
|
|
57db14ce5d | |
|
|
447889028e | |
|
|
1d596383de | |
|
|
0102c15687 | |
|
|
64dc804a13 | |
|
|
0609dd1f1d | |
|
|
2cb9e57fc4 | |
|
|
f2c4593672 |
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "usls"
|
name = "usls"
|
||||||
version = "0.0.16"
|
version = "0.0.20"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
||||||
repository = "https://github.com/jamjamjon/usls"
|
repository = "https://github.com/jamjamjon/usls"
|
||||||
|
|
@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
ndarray = { version = "0.16.1", features = ["rayon"] }
|
ndarray = { version = "0.16.1", features = ["rayon"] }
|
||||||
ort = { version = "2.0.0-rc.5", default-features = false}
|
ort = { version = "2.0.0-rc.9", default-features = false }
|
||||||
anyhow = { version = "1.0.75" }
|
anyhow = { version = "1.0.75" }
|
||||||
regex = { version = "1.5.4" }
|
regex = { version = "1.5.4" }
|
||||||
rand = { version = "0.8.5" }
|
rand = { version = "0.8.5" }
|
||||||
|
|
@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
|
||||||
ab_glyph = "0.2.23"
|
ab_glyph = "0.2.23"
|
||||||
geo = "0.28.0"
|
geo = "0.28.0"
|
||||||
prost = "0.12.4"
|
prost = "0.12.4"
|
||||||
fast_image_resize = { version = "4.2.1", features = ["image"]}
|
fast_image_resize = { version = "4.2.1", features = ["image"] }
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
tempfile = "3.12.0"
|
tempfile = "3.12.0"
|
||||||
|
|
@ -50,7 +50,6 @@ default = [
|
||||||
"ort/cuda",
|
"ort/cuda",
|
||||||
"ort/tensorrt",
|
"ort/tensorrt",
|
||||||
"ort/coreml",
|
"ort/coreml",
|
||||||
"ort/operator-libraries"
|
|
||||||
]
|
]
|
||||||
auto = ["ort/download-binaries"]
|
auto = ["ort/download-binaries"]
|
||||||
|
|
||||||
|
|
|
||||||
145
README.md
145
README.md
|
|
@ -3,7 +3,7 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
| <a href="https://docs.rs/usls"><strong>Documentation</strong></a> |
|
<a href="https://docs.rs/usls"><strong>Documentation</strong></a>
|
||||||
<br>
|
<br>
|
||||||
<br>
|
<br>
|
||||||
<a href='https://github.com/microsoft/onnxruntime/releases'>
|
<a href='https://github.com/microsoft/onnxruntime/releases'>
|
||||||
|
|
@ -34,9 +34,9 @@
|
||||||
|
|
||||||
**`usls`** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including:
|
**`usls`** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including:
|
||||||
|
|
||||||
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
|
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLOv11](https://github.com/ultralytics/ultralytics)
|
||||||
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
|
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
|
||||||
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
|
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro)
|
||||||
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
|
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|
@ -51,7 +51,8 @@
|
||||||
| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
| [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
|
| [YOLOv11](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
| [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
|
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
|
||||||
|
|
@ -67,11 +68,12 @@
|
||||||
| [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ |
|
| [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ |
|
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ |
|
||||||
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
|
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
|
| [Depth-Anything v1 & v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
|
||||||
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
|
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
|
||||||
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
|
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
|
||||||
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
|
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
|
||||||
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | |
|
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | |
|
||||||
|
| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | | |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -80,7 +82,8 @@
|
||||||
|
|
||||||
## ⛳️ ONNXRuntime Linking
|
## ⛳️ ONNXRuntime Linking
|
||||||
|
|
||||||
You have two options to link the ONNXRuntime library
|
<details>
|
||||||
|
<summary>You have two options to link the ONNXRuntime library</summary>
|
||||||
|
|
||||||
- ### Option 1: Manual Linking
|
- ### Option 1: Manual Linking
|
||||||
|
|
||||||
|
|
@ -99,6 +102,7 @@ You have two options to link the ONNXRuntime library
|
||||||
cargo run -r --example yolo --features auto
|
cargo run -r --example yolo --features auto
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
## 🎈 Demo
|
## 🎈 Demo
|
||||||
|
|
||||||
|
|
@ -123,70 +127,95 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ...
|
||||||
- Build model with the provided `models` and `Options`
|
- Build model with the provided `models` and `Options`
|
||||||
- Load images, video and stream with `DataLoader`
|
- Load images, video and stream with `DataLoader`
|
||||||
- Do inference
|
- Do inference
|
||||||
- Annotate inference results with `Annotator`
|
|
||||||
- Retrieve inference results from `Vec<Y>`
|
- Retrieve inference results from `Vec<Y>`
|
||||||
|
- Annotate inference results with `Annotator`
|
||||||
|
- Display images and write them to video with `Viewer`
|
||||||
|
|
||||||
```rust
|
<br/>
|
||||||
use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion};
|
<details>
|
||||||
|
<summary>example code</summary>
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
```rust
|
||||||
// Build model with Options
|
use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion};
|
||||||
let options = Options::new()
|
|
||||||
.with_trt(0)
|
|
||||||
.with_model("yolo/v8-m-dyn.onnx")?
|
|
||||||
.with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
|
|
||||||
.with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
|
|
||||||
.with_i00((1, 2, 4).into())
|
|
||||||
.with_i02((0, 640, 640).into())
|
|
||||||
.with_i03((0, 640, 640).into())
|
|
||||||
.with_confs(&[0.2]);
|
|
||||||
let mut model = YOLO::new(options)?;
|
|
||||||
|
|
||||||
// Build DataLoader to load image(s), video, stream
|
fn main() -> anyhow::Result<()> {
|
||||||
let dl = DataLoader::new(
|
// Build model with Options
|
||||||
// "./assets/bus.jpg", // local image
|
let options = Options::new()
|
||||||
// "images/bus.jpg", // remote image
|
.with_trt(0)
|
||||||
// "../images-folder", // local images (from folder)
|
.with_model("yolo/v8-m-dyn.onnx")?
|
||||||
// "../demo.mp4", // local video
|
.with_yolo_version(YOLOVersion::V8)  // YOLOVersion: V5, V6, V7, V8, V9, V10, V11, RTDETR
|
||||||
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video
|
.with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
|
||||||
"rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream
|
.with_ixx(0, 0, (1, 2, 4).into())
|
||||||
)?
|
.with_ixx(0, 2, (0, 640, 640).into())
|
||||||
.with_batch(2) // iterate with batch_size = 2
|
.with_ixx(0, 3, (0, 640, 640).into())
|
||||||
.build()?;
|
.with_confs(&[0.2]);
|
||||||
|
let mut model = YOLO::new(options)?;
|
||||||
|
|
||||||
// Build annotator
|
// Build DataLoader to load image(s), video, stream
|
||||||
let annotator = Annotator::new()
|
let dl = DataLoader::new(
|
||||||
.with_bboxes_thickness(4)
|
// "./assets/bus.jpg", // local image
|
||||||
.with_saveout("YOLO-DataLoader");
|
// "images/bus.jpg", // remote image
|
||||||
|
// "../images-folder", // local images (from folder)
|
||||||
|
// "../demo.mp4", // local video
|
||||||
|
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video
|
||||||
|
"rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream
|
||||||
|
)?
|
||||||
|
.with_batch(2) // iterate with batch_size = 2
|
||||||
|
.build()?;
|
||||||
|
|
||||||
// Run and annotate results
|
// Build annotator
|
||||||
for (xs, _) in dl {
|
let annotator = Annotator::new()
|
||||||
let ys = model.forward(&xs, false)?;
|
.with_bboxes_thickness(4)
|
||||||
annotator.annotate(&xs, &ys);
|
.with_saveout("YOLO-DataLoader");
|
||||||
|
|
||||||
// Retrieve inference results
|
// Build viewer
|
||||||
for y in ys {
|
let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true);
|
||||||
// bboxes
|
|
||||||
if let Some(bboxes) = y.bboxes() {
|
// Run and annotate results
|
||||||
for bbox in bboxes {
|
for (xs, _) in dl {
|
||||||
println!(
|
let ys = model.forward(&xs, false)?;
|
||||||
"Bbox: {}, {}, {}, {}, {}, {}",
|
// annotator.annotate(&xs, &ys);
|
||||||
bbox.xmin(),
|
let images_plotted = annotator.plot(&xs, &ys, false)?;
|
||||||
bbox.ymin(),
|
|
||||||
bbox.xmax(),
|
// show image
|
||||||
bbox.ymax(),
|
viewer.imshow(&images_plotted)?;
|
||||||
bbox.confidence(),
|
|
||||||
bbox.id(),
|
// check out window and key event
|
||||||
);
|
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
|
||||||
}
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// write video
|
||||||
|
viewer.write_batch(&images_plotted)?;
|
||||||
|
|
||||||
|
// Retrieve inference results
|
||||||
|
for y in ys {
|
||||||
|
// bboxes
|
||||||
|
if let Some(bboxes) = y.bboxes() {
|
||||||
|
for bbox in bboxes {
|
||||||
|
println!(
|
||||||
|
"Bbox: {}, {}, {}, {}, {}, {}",
|
||||||
|
bbox.xmin(),
|
||||||
|
bbox.ymin(),
|
||||||
|
bbox.xmax(),
|
||||||
|
bbox.ymax(),
|
||||||
|
bbox.confidence(),
|
||||||
|
bbox.id(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
```
|
|
||||||
|
|
||||||
|
// finish video write
|
||||||
|
viewer.finish_write()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
<br/>
|
||||||
|
|
||||||
## 📌 License
|
## 📌 License
|
||||||
This project is licensed under [LICENSE](LICENSE).
|
This project is licensed under [LICENSE](LICENSE).
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
use usls::{models::DepthPro, Annotator, DataLoader, Options};
|
||||||
|
|
||||||
|
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// options
|
||||||
|
let options = Options::default()
|
||||||
|
.with_model("depth-pro/q4f16.onnx")? // bnb4, f16
|
||||||
|
.with_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
|
||||||
|
.with_ixx(0, 1, 3.into()) // channel
|
||||||
|
.with_ixx(0, 2, 1536.into()) // height
|
||||||
|
.with_ixx(0, 3, 1536.into()); // width
|
||||||
|
let mut model = DepthPro::new(options)?;
|
||||||
|
|
||||||
|
// load
|
||||||
|
let x = [DataLoader::try_read("images/street.jpg")?];
|
||||||
|
|
||||||
|
// run
|
||||||
|
let y = model.run(&x)?;
|
||||||
|
|
||||||
|
// annotate
|
||||||
|
let annotator = Annotator::default()
|
||||||
|
.with_colormap("Turbo")
|
||||||
|
.with_saveout("Depth-Pro");
|
||||||
|
annotator.annotate(&x, &y);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
@ -24,45 +24,41 @@
|
||||||
## Quick Start
|
## Quick Start
|
||||||
```Shell
|
```Shell
|
||||||
|
|
||||||
|
# customized
|
||||||
|
cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8
|
||||||
|
|
||||||
# Classify
|
# Classify
|
||||||
cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5
|
cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
|
||||||
cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8
|
cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
|
||||||
|
cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11
|
||||||
|
|
||||||
# Detect
|
# Detect
|
||||||
cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5
|
cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
|
||||||
cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6
|
cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
|
||||||
cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7
|
cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
|
||||||
cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8
|
cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
|
||||||
cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9
|
cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
|
||||||
cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10
|
cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
|
||||||
cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR
|
cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
|
||||||
cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
|
cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
|
||||||
|
cargo run -r --example yolo -- --task detect --ver v8 --model yolo/v8-s-world-v2-shoes.onnx # YOLOv8-world
|
||||||
|
|
||||||
# Pose
|
# Pose
|
||||||
cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose
|
cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
|
||||||
|
cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose
|
||||||
|
|
||||||
# Segment
|
# Segment
|
||||||
cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment
|
cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
|
||||||
cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment
|
cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
|
||||||
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
|
cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv11-Segment
|
||||||
|
cargo run -r --example yolo -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM
|
||||||
|
|
||||||
# Obb
|
# Obb
|
||||||
cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
|
cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
|
||||||
|
cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
|
||||||
```
|
```
|
||||||
|
|
||||||
<details close>
|
**`cargo run -r --example yolo -- --help` for more options**
|
||||||
<summary>other options</summary>
|
|
||||||
|
|
||||||
`--source` to specify the input images
|
|
||||||
`--model` to specify the ONNX model
|
|
||||||
`--width --height` to specify the input resolution
|
|
||||||
`--nc` to specify the number of model's classes
|
|
||||||
`--plot` to annotate with inference results
|
|
||||||
`--profile` to profile
|
|
||||||
`--cuda --trt --coreml --device_id` to select device
|
|
||||||
`--half` to use float16 when using TensorRT EP
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
## YOLOs configs with `Options`
|
## YOLOs configs with `Options`
|
||||||
|
|
@ -96,6 +92,8 @@ let options = Options::default()
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
// .with_nc(80)
|
||||||
|
// .with_names(&COCO_CLASS_NAMES_80)
|
||||||
.with_model("xxxx.onnx")?;
|
.with_model("xxxx.onnx")?;
|
||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
@ -140,7 +138,7 @@ let options = Options::default()
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details close>
|
<details close>
|
||||||
<summary>YOLOv8</summary>
|
<summary>YOLOv8, YOLOv11</summary>
|
||||||
|
|
||||||
```Shell
|
```Shell
|
||||||
pip install -U ultralytics
|
pip install -U ultralytics
|
||||||
|
|
|
||||||
|
|
@ -2,188 +2,169 @@ use anyhow::Result;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use usls::{
|
use usls::{
|
||||||
models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
|
models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
|
||||||
COCO_SKELETONS_16,
|
YOLOVersion, COCO_SKELETONS_16,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Parser, Clone)]
|
#[derive(Parser, Clone)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
pub struct Args {
|
pub struct Args {
|
||||||
|
/// Path to the model
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub model: Option<String>,
|
pub model: Option<String>,
|
||||||
|
|
||||||
|
/// Input source path
|
||||||
#[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
|
#[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
|
||||||
pub source: String,
|
pub source: String,
|
||||||
|
|
||||||
|
/// YOLO Task
|
||||||
#[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
|
#[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
|
||||||
pub task: YOLOTask,
|
pub task: YOLOTask,
|
||||||
|
|
||||||
|
/// YOLO Version
|
||||||
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
|
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
|
||||||
pub ver: YOLOVersion,
|
pub ver: YOLOVersion,
|
||||||
|
|
||||||
|
/// YOLO Scale
|
||||||
|
#[arg(long, value_enum, default_value_t = YOLOScale::N)]
|
||||||
|
pub scale: YOLOScale,
|
||||||
|
|
||||||
|
/// Batch size
|
||||||
#[arg(long, default_value_t = 1)]
|
#[arg(long, default_value_t = 1)]
|
||||||
pub batch_size: usize,
|
pub batch_size: usize,
|
||||||
|
|
||||||
|
/// Minimum input width
|
||||||
#[arg(long, default_value_t = 224)]
|
#[arg(long, default_value_t = 224)]
|
||||||
pub width_min: isize,
|
pub width_min: isize,
|
||||||
|
|
||||||
|
/// Input width
|
||||||
#[arg(long, default_value_t = 640)]
|
#[arg(long, default_value_t = 640)]
|
||||||
pub width: isize,
|
pub width: isize,
|
||||||
|
|
||||||
#[arg(long, default_value_t = 800)]
|
/// Maximum input width
|
||||||
|
#[arg(long, default_value_t = 1024)]
|
||||||
pub width_max: isize,
|
pub width_max: isize,
|
||||||
|
|
||||||
|
/// Minimum input height
|
||||||
#[arg(long, default_value_t = 224)]
|
#[arg(long, default_value_t = 224)]
|
||||||
pub height_min: isize,
|
pub height_min: isize,
|
||||||
|
|
||||||
|
/// Input height
|
||||||
#[arg(long, default_value_t = 640)]
|
#[arg(long, default_value_t = 640)]
|
||||||
pub height: isize,
|
pub height: isize,
|
||||||
|
|
||||||
#[arg(long, default_value_t = 800)]
|
/// Maximum input height
|
||||||
|
#[arg(long, default_value_t = 1024)]
|
||||||
pub height_max: isize,
|
pub height_max: isize,
|
||||||
|
|
||||||
|
/// Number of classes
|
||||||
#[arg(long, default_value_t = 80)]
|
#[arg(long, default_value_t = 80)]
|
||||||
pub nc: usize,
|
pub nc: usize,
|
||||||
|
|
||||||
|
/// Class confidence
|
||||||
|
#[arg(long)]
|
||||||
|
pub confs: Vec<f32>,
|
||||||
|
|
||||||
|
/// Enable TensorRT support
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub trt: bool,
|
pub trt: bool,
|
||||||
|
|
||||||
|
/// Enable CUDA support
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub cuda: bool,
|
pub cuda: bool,
|
||||||
|
|
||||||
#[arg(long)]
|
/// Enable CoreML support
|
||||||
pub half: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub coreml: bool,
|
pub coreml: bool,
|
||||||
|
|
||||||
|
/// Use TensorRT half precision
|
||||||
|
#[arg(long)]
|
||||||
|
pub half: bool,
|
||||||
|
|
||||||
|
/// Device ID to use
|
||||||
#[arg(long, default_value_t = 0)]
|
#[arg(long, default_value_t = 0)]
|
||||||
pub device_id: usize,
|
pub device_id: usize,
|
||||||
|
|
||||||
|
/// Enable performance profiling
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub profile: bool,
|
pub profile: bool,
|
||||||
|
|
||||||
#[arg(long)]
|
/// Disable contour finding
|
||||||
pub no_plot: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
pub no_contours: bool,
|
pub no_contours: bool,
|
||||||
|
|
||||||
|
/// Show result
|
||||||
|
#[arg(long)]
|
||||||
|
pub view: bool,
|
||||||
|
|
||||||
|
/// Do not save output
|
||||||
|
#[arg(long)]
|
||||||
|
pub nosave: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
// build options
|
// model path
|
||||||
let options = Options::default();
|
let path = match &args.model {
|
||||||
|
None => format!(
|
||||||
// version & task
|
"yolo/{}-{}-{}.onnx",
|
||||||
let (options, saveout) = match args.ver {
|
args.ver.name(),
|
||||||
YOLOVersion::V5 => match args.task {
|
args.scale.name(),
|
||||||
YOLOTask::Classify => (
|
args.task.name()
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-cls-dyn.onnx".to_string()))?,
|
),
|
||||||
"YOLOv5-Classify",
|
Some(x) => x.to_string(),
|
||||||
),
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv5-Detect",
|
|
||||||
),
|
|
||||||
YOLOTask::Segment => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-seg-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv5-Segment",
|
|
||||||
),
|
|
||||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
|
||||||
},
|
|
||||||
YOLOVersion::V6 => match args.task {
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options
|
|
||||||
.with_model(&args.model.unwrap_or("yolo/v6-n-dyn.onnx".to_string()))?
|
|
||||||
.with_nc(args.nc),
|
|
||||||
"YOLOv6-Detect",
|
|
||||||
),
|
|
||||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
|
||||||
},
|
|
||||||
YOLOVersion::V7 => match args.task {
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options
|
|
||||||
.with_model(&args.model.unwrap_or("yolo/v7-tiny-dyn.onnx".to_string()))?
|
|
||||||
.with_nc(args.nc),
|
|
||||||
"YOLOv7-Detect",
|
|
||||||
),
|
|
||||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
|
||||||
},
|
|
||||||
YOLOVersion::V8 => match args.task {
|
|
||||||
YOLOTask::Classify => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-cls-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv8-Classify",
|
|
||||||
),
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv8-Detect",
|
|
||||||
),
|
|
||||||
YOLOTask::Segment => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-seg-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv8-Segment",
|
|
||||||
),
|
|
||||||
YOLOTask::Pose => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-pose-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv8-Pose",
|
|
||||||
),
|
|
||||||
YOLOTask::Obb => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-obb-dyn.onnx".to_string()))?,
|
|
||||||
"YOLOv8-Obb",
|
|
||||||
),
|
|
||||||
},
|
|
||||||
YOLOVersion::V9 => match args.task {
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v9-c-dyn-f16.onnx".to_string()))?,
|
|
||||||
"YOLOv9-Detect",
|
|
||||||
),
|
|
||||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
|
||||||
},
|
|
||||||
YOLOVersion::V10 => match args.task {
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/v10-n.onnx".to_string()))?,
|
|
||||||
"YOLOv10-Detect",
|
|
||||||
),
|
|
||||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
|
||||||
},
|
|
||||||
YOLOVersion::RTDETR => match args.task {
|
|
||||||
YOLOTask::Detect => (
|
|
||||||
options.with_model(&args.model.unwrap_or("yolo/rtdetr-l-f16.onnx".to_string()))?,
|
|
||||||
"RTDETR",
|
|
||||||
),
|
|
||||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let options = options
|
// saveout
|
||||||
.with_yolo_version(args.ver)
|
let saveout = match &args.model {
|
||||||
.with_yolo_task(args.task);
|
None => format!(
|
||||||
|
"{}-{}-{}",
|
||||||
|
args.ver.name(),
|
||||||
|
args.scale.name(),
|
||||||
|
args.task.name()
|
||||||
|
),
|
||||||
|
Some(x) => {
|
||||||
|
let p = std::path::PathBuf::from(&x);
|
||||||
|
p.file_stem().unwrap().to_str().unwrap().to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// device
|
// device
|
||||||
let options = if args.cuda {
|
let device = if args.cuda {
|
||||||
options.with_cuda(args.device_id)
|
Device::Cuda(args.device_id)
|
||||||
} else if args.trt {
|
} else if args.trt {
|
||||||
let options = options.with_trt(args.device_id);
|
Device::Trt(args.device_id)
|
||||||
if args.half {
|
|
||||||
options.with_trt_fp16(true)
|
|
||||||
} else {
|
|
||||||
options
|
|
||||||
}
|
|
||||||
} else if args.coreml {
|
} else if args.coreml {
|
||||||
options.with_coreml(args.device_id)
|
Device::CoreML(args.device_id)
|
||||||
} else {
|
} else {
|
||||||
options.with_cpu()
|
Device::Cpu(args.device_id)
|
||||||
};
|
};
|
||||||
let options = options
|
|
||||||
|
// build options
|
||||||
|
let options = Options::new()
|
||||||
|
.with_model(&path)?
|
||||||
|
.with_yolo_version(args.ver)
|
||||||
|
.with_yolo_task(args.task)
|
||||||
|
.with_device(device)
|
||||||
|
.with_trt_fp16(args.half)
|
||||||
.with_ixx(0, 0, (1, args.batch_size as _, 4).into())
|
.with_ixx(0, 0, (1, args.batch_size as _, 4).into())
|
||||||
.with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
|
.with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
|
||||||
.with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
|
.with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
|
||||||
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
|
.with_confs(if args.confs.is_empty() {
|
||||||
|
&[0.2, 0.15]
|
||||||
|
} else {
|
||||||
|
&args.confs
|
||||||
|
})
|
||||||
|
.with_nc(args.nc)
|
||||||
// .with_names(&COCO_CLASS_NAMES_80)
|
// .with_names(&COCO_CLASS_NAMES_80)
|
||||||
.with_names2(&COCO_KEYPOINTS_17)
|
// .with_names2(&COCO_KEYPOINTS_17)
|
||||||
.with_find_contours(!args.no_contours) // find contours or not
|
.with_find_contours(!args.no_contours) // find contours or not
|
||||||
|
.exclude_classes(&[0])
|
||||||
|
// .retain_classes(&[0, 5])
|
||||||
.with_profile(args.profile);
|
.with_profile(args.profile);
|
||||||
|
|
||||||
|
// build model
|
||||||
let mut model = YOLO::new(options)?;
|
let mut model = YOLO::new(options)?;
|
||||||
|
|
||||||
// build dataloader
|
// build dataloader
|
||||||
|
|
@ -194,16 +175,54 @@ fn main() -> Result<()> {
|
||||||
// build annotator
|
// build annotator
|
||||||
let annotator = Annotator::default()
|
let annotator = Annotator::default()
|
||||||
.with_skeletons(&COCO_SKELETONS_16)
|
.with_skeletons(&COCO_SKELETONS_16)
|
||||||
.with_bboxes_thickness(4)
|
|
||||||
.without_masks(true) // No masks plotting when doing segment task.
|
.without_masks(true) // No masks plotting when doing segment task.
|
||||||
.with_saveout(saveout);
|
.with_bboxes_thickness(3)
|
||||||
|
.with_keypoints_name(false) // Enable keypoints names
|
||||||
|
.with_saveout_subs(&["YOLO"])
|
||||||
|
.with_saveout(&saveout);
|
||||||
|
|
||||||
|
// build viewer
|
||||||
|
let mut viewer = if args.view {
|
||||||
|
Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
// run & annotate
|
// run & annotate
|
||||||
for (xs, _paths) in dl {
|
for (xs, _paths) in dl {
|
||||||
// let ys = model.run(&xs)?; // way one
|
// let ys = model.run(&xs)?; // way one
|
||||||
let ys = model.forward(&xs, args.profile)?; // way two
|
let ys = model.forward(&xs, args.profile)?; // way two
|
||||||
if !args.no_plot {
|
let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
|
||||||
annotator.annotate(&xs, &ys);
|
|
||||||
|
// show image
|
||||||
|
match &mut viewer {
|
||||||
|
Some(viewer) => viewer.imshow(&images_plotted)?,
|
||||||
|
None => continue,
|
||||||
|
}
|
||||||
|
|
||||||
|
// check out window and key event
|
||||||
|
match &mut viewer {
|
||||||
|
Some(viewer) => {
|
||||||
|
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => continue,
|
||||||
|
}
|
||||||
|
|
||||||
|
// write video
|
||||||
|
if !args.nosave {
|
||||||
|
match &mut viewer {
|
||||||
|
Some(viewer) => viewer.write_batch(&images_plotted)?,
|
||||||
|
None => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// finish video write
|
||||||
|
if !args.nosave {
|
||||||
|
if let Some(viewer) = &mut viewer {
|
||||||
|
viewer.finish_write()?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,8 @@ pub struct Options {
|
||||||
pub sam_kind: Option<SamKind>,
|
pub sam_kind: Option<SamKind>,
|
||||||
pub use_low_res_mask: Option<bool>,
|
pub use_low_res_mask: Option<bool>,
|
||||||
pub sapiens_task: Option<SapiensTask>,
|
pub sapiens_task: Option<SapiensTask>,
|
||||||
|
pub classes_excluded: Vec<isize>,
|
||||||
|
pub classes_retained: Vec<isize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Options {
|
impl Default for Options {
|
||||||
|
|
@ -88,6 +90,8 @@ impl Default for Options {
|
||||||
use_low_res_mask: None,
|
use_low_res_mask: None,
|
||||||
sapiens_task: None,
|
sapiens_task: None,
|
||||||
task: Task::Untitled,
|
task: Task::Untitled,
|
||||||
|
classes_excluded: vec![],
|
||||||
|
classes_retained: vec![],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -276,4 +280,16 @@ impl Options {
|
||||||
self.iiixs.push(Iiix::from((i, ii, x)));
|
self.iiixs.push(Iiix::from((i, ii, x)));
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Adds `xs` to the list of excluded class ids and returns the updated options.
///
/// Exclusion and retention are mutually exclusive: calling this empties the
/// retained list. Ids from earlier `exclude_classes` calls are kept.
pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
    self.classes_excluded.extend(xs.iter().copied());
    // Drop any retain-list so only one of the two filters is ever populated.
    self.classes_retained.clear();
    self
}
|
||||||
|
|
||||||
|
/// Adds `xs` to the list of retained class ids and returns the updated options.
///
/// Retention and exclusion are mutually exclusive: calling this empties the
/// excluded list. Ids from earlier `retain_classes` calls are kept.
pub fn retain_classes(mut self, xs: &[isize]) -> Self {
    self.classes_retained.extend(xs.iter().copied());
    // Drop any exclude-list so only one of the two filters is ever populated.
    self.classes_excluded.clear();
    self
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ use anyhow::Result;
|
||||||
use half::f16;
|
use half::f16;
|
||||||
use ndarray::{Array, IxDyn};
|
use ndarray::{Array, IxDyn};
|
||||||
use ort::{
|
use ort::{
|
||||||
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
|
execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
|
||||||
|
session::{builder::SessionBuilder, Session},
|
||||||
|
tensor::TensorElementType,
|
||||||
};
|
};
|
||||||
use prost::Message;
|
use prost::Message;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
@ -88,14 +90,14 @@ impl OrtEngine {
|
||||||
|
|
||||||
// build
|
// build
|
||||||
ort::init().commit()?;
|
ort::init().commit()?;
|
||||||
let builder = Session::builder()?;
|
let mut builder = Session::builder()?;
|
||||||
let mut device = config.device.to_owned();
|
let mut device = config.device.to_owned();
|
||||||
match device {
|
match device {
|
||||||
Device::Trt(device_id) => {
|
Device::Trt(device_id) => {
|
||||||
Self::build_trt(
|
Self::build_trt(
|
||||||
&inputs_attrs.names,
|
&inputs_attrs.names,
|
||||||
&inputs_minoptmax,
|
&inputs_minoptmax,
|
||||||
&builder,
|
&mut builder,
|
||||||
device_id,
|
device_id,
|
||||||
config.trt_int8_enable,
|
config.trt_int8_enable,
|
||||||
config.trt_fp16_enable,
|
config.trt_fp16_enable,
|
||||||
|
|
@ -103,23 +105,23 @@ impl OrtEngine {
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
Device::Cuda(device_id) => {
|
Device::Cuda(device_id) => {
|
||||||
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
|
Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
|
||||||
tracing::warn!("{err}, Using cpu");
|
tracing::warn!("{err}, Using cpu");
|
||||||
device = Device::Cpu(0);
|
device = Device::Cpu(0);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
|
Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
|
||||||
tracing::warn!("{err}, Using cpu");
|
tracing::warn!("{err}, Using cpu");
|
||||||
device = Device::Cpu(0);
|
device = Device::Cpu(0);
|
||||||
}),
|
}),
|
||||||
Device::Cpu(_) => {
|
Device::Cpu(_) => {
|
||||||
Self::build_cpu(&builder)?;
|
Self::build_cpu(&mut builder)?;
|
||||||
}
|
}
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
}
|
}
|
||||||
|
|
||||||
let session = builder
|
let session = builder
|
||||||
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
|
||||||
.commit_from_file(&config.onnx_path)?;
|
.commit_from_file(&config.onnx_path)?;
|
||||||
|
|
||||||
// summary
|
// summary
|
||||||
|
|
@ -149,7 +151,7 @@ impl OrtEngine {
|
||||||
fn build_trt(
|
fn build_trt(
|
||||||
names: &[String],
|
names: &[String],
|
||||||
inputs_minoptmax: &[Vec<MinOptMax>],
|
inputs_minoptmax: &[Vec<MinOptMax>],
|
||||||
builder: &SessionBuilder,
|
builder: &mut SessionBuilder,
|
||||||
device_id: usize,
|
device_id: usize,
|
||||||
int8_enable: bool,
|
int8_enable: bool,
|
||||||
fp16_enable: bool,
|
fp16_enable: bool,
|
||||||
|
|
@ -205,8 +207,9 @@ impl OrtEngine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
|
fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
|
||||||
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
|
let ep = ort::execution_providers::CUDAExecutionProvider::default()
|
||||||
|
.with_device_id(device_id as i32);
|
||||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -214,8 +217,8 @@ impl OrtEngine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_coreml(builder: &SessionBuilder) -> Result<()> {
|
fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
|
||||||
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -223,8 +226,8 @@ impl OrtEngine {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cpu(builder: &SessionBuilder) -> Result<()> {
|
fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
|
||||||
let ep = ort::CPUExecutionProvider::default();
|
let ep = ort::execution_providers::CPUExecutionProvider::default();
|
||||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -292,28 +295,28 @@ impl OrtEngine {
|
||||||
let t_pre = std::time::Instant::now();
|
let t_pre = std::time::Instant::now();
|
||||||
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
|
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
|
||||||
let x_ = match &idtype {
|
let x_ = match &idtype {
|
||||||
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
|
TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
|
||||||
TensorElementType::Float16 => {
|
TensorElementType::Float16 => {
|
||||||
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Int32 => {
|
TensorElementType::Int32 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Int64 => {
|
TensorElementType::Int64 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Uint8 => {
|
TensorElementType::Uint8 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Int8 => {
|
TensorElementType::Int8 => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
TensorElementType::Bool => {
|
TensorElementType::Bool => {
|
||||||
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
||||||
}
|
}
|
||||||
_ => todo!(),
|
_ => todo!(),
|
||||||
};
|
};
|
||||||
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
|
xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
|
||||||
}
|
}
|
||||||
let t_pre = t_pre.elapsed();
|
let t_pre = t_pre.elapsed();
|
||||||
self.ts.add_or_push(0, t_pre);
|
self.ts.add_or_push(0, t_pre);
|
||||||
|
|
@ -451,45 +454,45 @@ impl OrtEngine {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
|
fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
|
||||||
match x {
|
match x {
|
||||||
ort::TensorElementType::Float64
|
ort::tensor::TensorElementType::Float64
|
||||||
| ort::TensorElementType::Uint64
|
| ort::tensor::TensorElementType::Uint64
|
||||||
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
|
| ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||||
ort::TensorElementType::Float32
|
ort::tensor::TensorElementType::Float32
|
||||||
| ort::TensorElementType::Uint32
|
| ort::tensor::TensorElementType::Uint32
|
||||||
| ort::TensorElementType::Int32
|
| ort::tensor::TensorElementType::Int32
|
||||||
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
| ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||||
ort::TensorElementType::Float16
|
ort::tensor::TensorElementType::Float16
|
||||||
| ort::TensorElementType::Bfloat16
|
| ort::tensor::TensorElementType::Bfloat16
|
||||||
| ort::TensorElementType::Int16
|
| ort::tensor::TensorElementType::Int16
|
||||||
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
| ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||||
ort::TensorElementType::Uint8
|
ort::tensor::TensorElementType::Uint8
|
||||||
| ort::TensorElementType::Int8
|
| ort::tensor::TensorElementType::Int8
|
||||||
| ort::TensorElementType::Bool => 1, // u8, i8, bool
|
| ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
|
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
|
||||||
match value {
|
match value {
|
||||||
0 => None,
|
0 => None,
|
||||||
1 => Some(ort::TensorElementType::Float32),
|
1 => Some(ort::tensor::TensorElementType::Float32),
|
||||||
2 => Some(ort::TensorElementType::Uint8),
|
2 => Some(ort::tensor::TensorElementType::Uint8),
|
||||||
3 => Some(ort::TensorElementType::Int8),
|
3 => Some(ort::tensor::TensorElementType::Int8),
|
||||||
4 => Some(ort::TensorElementType::Uint16),
|
4 => Some(ort::tensor::TensorElementType::Uint16),
|
||||||
5 => Some(ort::TensorElementType::Int16),
|
5 => Some(ort::tensor::TensorElementType::Int16),
|
||||||
6 => Some(ort::TensorElementType::Int32),
|
6 => Some(ort::tensor::TensorElementType::Int32),
|
||||||
7 => Some(ort::TensorElementType::Int64),
|
7 => Some(ort::tensor::TensorElementType::Int64),
|
||||||
8 => Some(ort::TensorElementType::String),
|
8 => Some(ort::tensor::TensorElementType::String),
|
||||||
9 => Some(ort::TensorElementType::Bool),
|
9 => Some(ort::tensor::TensorElementType::Bool),
|
||||||
10 => Some(ort::TensorElementType::Float16),
|
10 => Some(ort::tensor::TensorElementType::Float16),
|
||||||
11 => Some(ort::TensorElementType::Float64),
|
11 => Some(ort::tensor::TensorElementType::Float64),
|
||||||
12 => Some(ort::TensorElementType::Uint32),
|
12 => Some(ort::tensor::TensorElementType::Uint32),
|
||||||
13 => Some(ort::TensorElementType::Uint64),
|
13 => Some(ort::tensor::TensorElementType::Uint64),
|
||||||
14 => None, // COMPLEX64
|
14 => None, // COMPLEX64
|
||||||
15 => None, // COMPLEX128
|
15 => None, // COMPLEX128
|
||||||
16 => Some(ort::TensorElementType::Bfloat16),
|
16 => Some(ort::tensor::TensorElementType::Bfloat16),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -499,7 +502,7 @@ impl OrtEngine {
|
||||||
value_info: &[onnx::ValueInfoProto],
|
value_info: &[onnx::ValueInfoProto],
|
||||||
) -> Result<OrtTensorAttr> {
|
) -> Result<OrtTensorAttr> {
|
||||||
let mut dimss: Vec<Vec<usize>> = Vec::new();
|
let mut dimss: Vec<Vec<usize>> = Vec::new();
|
||||||
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
|
let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
|
||||||
let mut names: Vec<String> = Vec::new();
|
let mut names: Vec<String> = Vec::new();
|
||||||
for v in value_info.iter() {
|
for v in value_info.iter() {
|
||||||
if initializer_names.contains(v.name.as_str()) {
|
if initializer_names.contains(v.name.as_str()) {
|
||||||
|
|
@ -569,7 +572,7 @@ impl OrtEngine {
|
||||||
&self.outputs_attrs.names
|
&self.outputs_attrs.names
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
|
pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||||
&self.outputs_attrs.dtypes
|
&self.outputs_attrs.dtypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -585,7 +588,7 @@ impl OrtEngine {
|
||||||
&self.inputs_attrs.names
|
&self.inputs_attrs.names
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
|
pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||||
&self.inputs_attrs.dtypes
|
&self.inputs_attrs.dtypes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,86 @@
|
||||||
|
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
|
||||||
|
use anyhow::Result;
|
||||||
|
use image::DynamicImage;
|
||||||
|
use ndarray::Axis;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DepthPro {
|
||||||
|
engine: OrtEngine,
|
||||||
|
height: MinOptMax,
|
||||||
|
width: MinOptMax,
|
||||||
|
batch: MinOptMax,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DepthPro {
|
||||||
|
pub fn new(options: Options) -> Result<Self> {
|
||||||
|
let mut engine = OrtEngine::new(&options)?;
|
||||||
|
let (batch, height, width) = (
|
||||||
|
engine.batch().clone(),
|
||||||
|
engine.height().clone(),
|
||||||
|
engine.width().clone(),
|
||||||
|
);
|
||||||
|
engine.dry_run()?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
engine,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
batch,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||||
|
let xs_ = X::apply(&[
|
||||||
|
Ops::Resize(
|
||||||
|
xs,
|
||||||
|
self.height.opt() as u32,
|
||||||
|
self.width.opt() as u32,
|
||||||
|
"Bilinear",
|
||||||
|
),
|
||||||
|
Ops::Normalize(0., 255.),
|
||||||
|
Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3),
|
||||||
|
Ops::Nhwc2nchw,
|
||||||
|
])?;
|
||||||
|
let ys = self.engine.run(Xs::from(xs_))?;
|
||||||
|
|
||||||
|
self.postprocess(ys, xs)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||||
|
let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
|
||||||
|
let predicted_depth = predicted_depth.mapv(|x| 1. / x);
|
||||||
|
|
||||||
|
let mut ys: Vec<Y> = Vec::new();
|
||||||
|
for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
|
||||||
|
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
|
||||||
|
let v = luma.into_owned().into_raw_vec_and_offset().0;
|
||||||
|
let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
|
||||||
|
let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
|
||||||
|
let v = v
|
||||||
|
.iter()
|
||||||
|
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let luma = Ops::resize_luma8_u8(
|
||||||
|
&v,
|
||||||
|
self.width.opt() as _,
|
||||||
|
self.height.opt() as _,
|
||||||
|
w1 as _,
|
||||||
|
h1 as _,
|
||||||
|
false,
|
||||||
|
"Bilinear",
|
||||||
|
)?;
|
||||||
|
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
|
||||||
|
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
|
||||||
|
None => continue,
|
||||||
|
Some(x) => x,
|
||||||
|
};
|
||||||
|
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
|
||||||
|
}
|
||||||
|
Ok(ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn batch(&self) -> isize {
|
||||||
|
self.batch.opt() as _
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,6 +4,7 @@ mod blip;
|
||||||
mod clip;
|
mod clip;
|
||||||
mod db;
|
mod db;
|
||||||
mod depth_anything;
|
mod depth_anything;
|
||||||
|
mod depth_pro;
|
||||||
mod dinov2;
|
mod dinov2;
|
||||||
mod florence2;
|
mod florence2;
|
||||||
mod grounding_dino;
|
mod grounding_dino;
|
||||||
|
|
@ -20,6 +21,7 @@ pub use blip::Blip;
|
||||||
pub use clip::Clip;
|
pub use clip::Clip;
|
||||||
pub use db::DB;
|
pub use db::DB;
|
||||||
pub use depth_anything::DepthAnything;
|
pub use depth_anything::DepthAnything;
|
||||||
|
pub use depth_pro::DepthPro;
|
||||||
pub use dinov2::Dinov2;
|
pub use dinov2::Dinov2;
|
||||||
pub use florence2::Florence2;
|
pub use florence2::Florence2;
|
||||||
pub use grounding_dino::GroundingDINO;
|
pub use grounding_dino::GroundingDINO;
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,14 @@ pub struct YOLO {
|
||||||
confs: DynConf,
|
confs: DynConf,
|
||||||
kconfs: DynConf,
|
kconfs: DynConf,
|
||||||
iou: f32,
|
iou: f32,
|
||||||
names: Option<Vec<String>>,
|
names: Vec<String>,
|
||||||
names_kpt: Option<Vec<String>>,
|
names_kpt: Vec<String>,
|
||||||
task: YOLOTask,
|
task: YOLOTask,
|
||||||
layout: YOLOPreds,
|
layout: YOLOPreds,
|
||||||
find_contours: bool,
|
find_contours: bool,
|
||||||
version: Option<YOLOVersion>,
|
version: Option<YOLOVersion>,
|
||||||
|
classes_excluded: Vec<isize>,
|
||||||
|
classes_retained: Vec<isize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Vision for YOLO {
|
impl Vision for YOLO {
|
||||||
|
|
@ -64,27 +66,26 @@ impl Vision for YOLO {
|
||||||
Some(task) => match task {
|
Some(task) => match task {
|
||||||
YOLOTask::Classify => match ver {
|
YOLOTask::Classify => match ver {
|
||||||
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
|
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
|
||||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()),
|
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()),
|
||||||
x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||||
}
|
}
|
||||||
YOLOTask::Detect => match ver {
|
YOLOTask::Detect => match ver {
|
||||||
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()),
|
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()),
|
||||||
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
|
YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()),
|
||||||
YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
|
YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
|
||||||
YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
|
YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
|
||||||
YOLOVersion::RTDETR => (Some(ver),YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
|
|
||||||
}
|
}
|
||||||
YOLOTask::Pose => match ver {
|
YOLOTask::Pose => match ver {
|
||||||
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()),
|
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()),
|
||||||
x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||||
}
|
}
|
||||||
YOLOTask::Segment => match ver {
|
YOLOTask::Segment => match ver {
|
||||||
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
|
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
|
||||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
|
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
|
||||||
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||||
}
|
}
|
||||||
YOLOTask::Obb => match ver {
|
YOLOTask::Obb => match ver {
|
||||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
|
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
|
||||||
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -97,49 +98,75 @@ impl Vision for YOLO {
|
||||||
|
|
||||||
let task = task.unwrap_or(layout.task());
|
let task = task.unwrap_or(layout.task());
|
||||||
|
|
||||||
// The number of classes & Class names
|
// Class names: user-defined.or(parsed)
|
||||||
let mut names = options.names.or(Self::fetch_names(&engine));
|
let names_parsed = Self::fetch_names(&engine);
|
||||||
let nc = match options.nc {
|
let names = match names_parsed {
|
||||||
Some(nc) => {
|
Some(names_parsed) => match options.names {
|
||||||
match &names {
|
Some(names) => {
|
||||||
None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()),
|
if names.len() == names_parsed.len() {
|
||||||
Some(names) => {
|
Some(names)
|
||||||
assert_eq!(
|
} else {
|
||||||
nc,
|
anyhow::bail!(
|
||||||
|
"The lengths of parsed class names: {} and user-defined class names: {} do not match.",
|
||||||
|
names_parsed.len(),
|
||||||
names.len(),
|
names.len(),
|
||||||
"The length of `nc` and `class names` is not equal."
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
nc
|
None => Some(names_parsed),
|
||||||
}
|
},
|
||||||
None => match &names {
|
None => options.names,
|
||||||
Some(names) => names.len(),
|
};
|
||||||
None => panic!(
|
|
||||||
"Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`"
|
// nc: names.len().or(options.nc)
|
||||||
|
let nc = match &names {
|
||||||
|
Some(names) => names.len(),
|
||||||
|
None => match options.nc {
|
||||||
|
Some(nc) => nc,
|
||||||
|
None => anyhow::bail!(
|
||||||
|
"Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
|
||||||
),
|
),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Class names
|
||||||
|
let names = match names {
|
||||||
|
None => Self::n2s(nc),
|
||||||
|
Some(names) => names,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Keypoint names & nk
|
||||||
|
let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
|
||||||
|
None => (0, vec![]),
|
||||||
|
Some(nk) => match options.names2 {
|
||||||
|
Some(names) => {
|
||||||
|
if names.len() == nk {
|
||||||
|
(nk, names)
|
||||||
|
} else {
|
||||||
|
anyhow::bail!(
|
||||||
|
"The lengths of user-defined keypoint names: {} and nk: {} do not match.",
|
||||||
|
names.len(),
|
||||||
|
nk,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => (nk, Self::n2s(nk)),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Keypoints names
|
// Confs & Iou
|
||||||
let names_kpt = options.names2;
|
|
||||||
|
|
||||||
// The number of keypoints
|
|
||||||
let nk = engine
|
|
||||||
.try_fetch("kpt_shape")
|
|
||||||
.map(|kpt_string| {
|
|
||||||
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
|
||||||
let caps = re.captures(&kpt_string).unwrap();
|
|
||||||
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
|
|
||||||
})
|
|
||||||
.unwrap_or(0_usize);
|
|
||||||
let confs = DynConf::new(&options.confs, nc);
|
let confs = DynConf::new(&options.confs, nc);
|
||||||
let kconfs = DynConf::new(&options.kconfs, nk);
|
let kconfs = DynConf::new(&options.kconfs, nk);
|
||||||
let iou = options.iou.unwrap_or(0.45);
|
let iou = options.iou.unwrap_or(0.45);
|
||||||
|
|
||||||
|
// Classes excluded and retained
|
||||||
|
let classes_excluded = options.classes_excluded;
|
||||||
|
let classes_retained = options.classes_retained;
|
||||||
|
|
||||||
// Summary
|
// Summary
|
||||||
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
|
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
|
||||||
|
|
||||||
|
// dry run
|
||||||
engine.dry_run()?;
|
engine.dry_run()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
|
@ -158,6 +185,8 @@ impl Vision for YOLO {
|
||||||
layout,
|
layout,
|
||||||
version,
|
version,
|
||||||
find_contours: options.find_contours,
|
find_contours: options.find_contours,
|
||||||
|
classes_excluded,
|
||||||
|
classes_retained,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -219,10 +248,8 @@ impl Vision for YOLO {
|
||||||
slice_clss.into_owned()
|
slice_clss.into_owned()
|
||||||
};
|
};
|
||||||
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
|
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
|
||||||
if let Some(names) = &self.names {
|
probs = probs
|
||||||
probs =
|
.with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
|
||||||
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
|
|
||||||
}
|
|
||||||
|
|
||||||
return Some(y.with_probs(&probs));
|
return Some(y.with_probs(&probs));
|
||||||
}
|
}
|
||||||
|
|
@ -257,7 +284,19 @@ impl Vision for YOLO {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// filtering
|
// filtering by class id
|
||||||
|
if !self.classes_excluded.is_empty()
|
||||||
|
&& self.classes_excluded.contains(&(class_id as isize))
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
if !self.classes_retained.is_empty()
|
||||||
|
&& !self.classes_retained.contains(&(class_id as isize))
|
||||||
|
{
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// filtering by conf
|
||||||
if confidence < self.confs[class_id] {
|
if confidence < self.confs[class_id] {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
@ -325,9 +364,7 @@ impl Vision for YOLO {
|
||||||
)
|
)
|
||||||
.with_confidence(confidence)
|
.with_confidence(confidence)
|
||||||
.with_id(class_id as isize);
|
.with_id(class_id as isize);
|
||||||
if let Some(names) = &self.names {
|
mbr = mbr.with_name(&self.names[class_id]);
|
||||||
mbr = mbr.with_name(&names[class_id]);
|
|
||||||
}
|
|
||||||
|
|
||||||
(None, Some(mbr))
|
(None, Some(mbr))
|
||||||
}
|
}
|
||||||
|
|
@ -337,9 +374,7 @@ impl Vision for YOLO {
|
||||||
.with_confidence(confidence)
|
.with_confidence(confidence)
|
||||||
.with_id(class_id as isize)
|
.with_id(class_id as isize)
|
||||||
.with_id_born(i as isize);
|
.with_id_born(i as isize);
|
||||||
if let Some(names) = &self.names {
|
bbox = bbox.with_name(&self.names[class_id]);
|
||||||
bbox = bbox.with_name(&names[class_id]);
|
|
||||||
}
|
|
||||||
|
|
||||||
(Some(bbox), None)
|
(Some(bbox), None)
|
||||||
}
|
}
|
||||||
|
|
@ -394,9 +429,7 @@ impl Vision for YOLO {
|
||||||
ky.max(0.0f32).min(image_height),
|
ky.max(0.0f32).min(image_height),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(names) = &self.names_kpt {
|
kpt = kpt.with_name(&self.names_kpt[i]);
|
||||||
kpt = kpt.with_name(&names[i]);
|
|
||||||
}
|
|
||||||
kpt
|
kpt
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -505,16 +538,16 @@ impl Vision for YOLO {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl YOLO {
|
impl YOLO {
|
||||||
pub fn batch(&self) -> isize {
|
pub fn batch(&self) -> usize {
|
||||||
self.batch.opt() as _
|
self.batch.opt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn width(&self) -> isize {
|
pub fn width(&self) -> usize {
|
||||||
self.width.opt() as _
|
self.width.opt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn height(&self) -> isize {
|
pub fn height(&self) -> usize {
|
||||||
self.height.opt() as _
|
self.height.opt()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn version(&self) -> Option<&YOLOVersion> {
|
pub fn version(&self) -> Option<&YOLOVersion> {
|
||||||
|
|
@ -541,4 +574,16 @@ impl YOLO {
|
||||||
names_
|
names_
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
|
||||||
|
engine.try_fetch("kpt_shape").map(|s| {
|
||||||
|
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
||||||
|
let caps = re.captures(&s).unwrap();
|
||||||
|
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn n2s(n: usize) -> Vec<String> {
|
||||||
|
(0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,28 @@ pub enum YOLOTask {
|
||||||
Obb,
|
Obb,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl YOLOTask {
|
||||||
|
pub fn name(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::Classify => "cls".to_string(),
|
||||||
|
Self::Detect => "det".to_string(),
|
||||||
|
Self::Pose => "pose".to_string(),
|
||||||
|
Self::Segment => "seg".to_string(),
|
||||||
|
Self::Obb => "obb".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn name_detailed(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::Classify => "image-classification".to_string(),
|
||||||
|
Self::Detect => "object-detection".to_string(),
|
||||||
|
Self::Pose => "pose-estimation".to_string(),
|
||||||
|
Self::Segment => "instance-segment".to_string(),
|
||||||
|
Self::Obb => "oriented-object-detection".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
|
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
|
||||||
pub enum YOLOVersion {
|
pub enum YOLOVersion {
|
||||||
V5,
|
V5,
|
||||||
|
|
@ -17,9 +39,54 @@ pub enum YOLOVersion {
|
||||||
V8,
|
V8,
|
||||||
V9,
|
V9,
|
||||||
V10,
|
V10,
|
||||||
|
V11,
|
||||||
RTDETR,
|
RTDETR,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl YOLOVersion {
|
||||||
|
pub fn name(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::V5 => "v5".to_string(),
|
||||||
|
Self::V6 => "v6".to_string(),
|
||||||
|
Self::V7 => "v7".to_string(),
|
||||||
|
Self::V8 => "v8".to_string(),
|
||||||
|
Self::V9 => "v9".to_string(),
|
||||||
|
Self::V10 => "v10".to_string(),
|
||||||
|
Self::V11 => "v11".to_string(),
|
||||||
|
Self::RTDETR => "rtdetr".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
|
||||||
|
pub enum YOLOScale {
|
||||||
|
N,
|
||||||
|
T,
|
||||||
|
B,
|
||||||
|
S,
|
||||||
|
M,
|
||||||
|
L,
|
||||||
|
C,
|
||||||
|
E,
|
||||||
|
X,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl YOLOScale {
|
||||||
|
pub fn name(&self) -> String {
|
||||||
|
match self {
|
||||||
|
Self::N => "n".to_string(),
|
||||||
|
Self::T => "t".to_string(),
|
||||||
|
Self::S => "s".to_string(),
|
||||||
|
Self::B => "b".to_string(),
|
||||||
|
Self::M => "m".to_string(),
|
||||||
|
Self::L => "l".to_string(),
|
||||||
|
Self::C => "c".to_string(),
|
||||||
|
Self::E => "e".to_string(),
|
||||||
|
Self::X => "x".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub enum BoxType {
|
pub enum BoxType {
|
||||||
/// 1
|
/// 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue