Compare commits
10 Commits
6ace97f09f
...
4e932c4910
| Author | SHA1 | Date |
|---|---|---|
|
|
4e932c4910 | |
|
|
2785b090c6 | |
|
|
57db14ce5d | |
|
|
447889028e | |
|
|
1d596383de | |
|
|
0102c15687 | |
|
|
64dc804a13 | |
|
|
0609dd1f1d | |
|
|
2cb9e57fc4 | |
|
|
f2c4593672 |
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "usls"
|
||||
version = "0.0.16"
|
||||
version = "0.0.20"
|
||||
edition = "2021"
|
||||
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
|
||||
repository = "https://github.com/jamjamjon/usls"
|
||||
|
|
@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
|
|||
[dependencies]
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
ndarray = { version = "0.16.1", features = ["rayon"] }
|
||||
ort = { version = "2.0.0-rc.5", default-features = false}
|
||||
ort = { version = "2.0.0-rc.9", default-features = false }
|
||||
anyhow = { version = "1.0.75" }
|
||||
regex = { version = "1.5.4" }
|
||||
rand = { version = "0.8.5" }
|
||||
|
|
@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
|
|||
ab_glyph = "0.2.23"
|
||||
geo = "0.28.0"
|
||||
prost = "0.12.4"
|
||||
fast_image_resize = { version = "4.2.1", features = ["image"]}
|
||||
fast_image_resize = { version = "4.2.1", features = ["image"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tempfile = "3.12.0"
|
||||
|
|
@ -50,7 +50,6 @@ default = [
|
|||
"ort/cuda",
|
||||
"ort/tensorrt",
|
||||
"ort/coreml",
|
||||
"ort/operator-libraries"
|
||||
]
|
||||
auto = ["ort/download-binaries"]
|
||||
|
||||
|
|
|
|||
157
README.md
157
README.md
|
|
@ -3,7 +3,7 @@
|
|||
</p>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.rs/usls"><strong>Documentation</strong></a> |
|
||||
<a href="https://docs.rs/usls"><strong>Documentation</strong></a>
|
||||
<br>
|
||||
<br>
|
||||
<a href='https://github.com/microsoft/onnxruntime/releases'>
|
||||
|
|
@ -34,9 +34,9 @@
|
|||
|
||||
**`usls`** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including:
|
||||
|
||||
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
|
||||
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLOv11](https://github.com/ultralytics/ultralytics)
|
||||
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
|
||||
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
|
||||
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro)
|
||||
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
|
||||
|
||||
<details>
|
||||
|
|
@ -51,7 +51,8 @@
|
|||
| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [YOLOv11](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
|
||||
|
|
@ -67,11 +68,12 @@
|
|||
| [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ |
|
||||
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
|
||||
| [Depth-Anything v1 & v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
|
||||
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
|
||||
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
|
||||
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
|
||||
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | |
|
||||
| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | | |
|
||||
|
||||
|
||||
|
||||
|
|
@ -80,7 +82,8 @@
|
|||
|
||||
## ⛳️ ONNXRuntime Linking
|
||||
|
||||
You have two options to link the ONNXRuntime library
|
||||
<details>
|
||||
<summary>You have two options to link the ONNXRuntime library</summary>
|
||||
|
||||
- ### Option 1: Manual Linking
|
||||
|
||||
|
|
@ -99,6 +102,7 @@ You have two options to link the ONNXRuntime library
|
|||
cargo run -r --example yolo --features auto
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🎈 Demo
|
||||
|
||||
|
|
@ -118,75 +122,100 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ...
|
|||
[dependencies]
|
||||
usls = { git = "https://github.com/jamjamjon/usls", rev = "commit-sha" }
|
||||
```
|
||||
|
||||
|
||||
- #### Follow the pipeline
|
||||
- Build model with the provided `models` and `Options`
|
||||
- Load images, video and stream with `DataLoader`
|
||||
- Do inference
|
||||
- Annotate inference results with `Annotator`
|
||||
- Retrieve inference results from `Vec<Y>`
|
||||
|
||||
```rust
|
||||
use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion};
|
||||
- Annotate inference results with `Annotator`
|
||||
- Display images and write them to video with `Viewer`
|
||||
|
||||
<br/>
|
||||
<details>
|
||||
<summary>example code</summary>
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Build model with Options
|
||||
let options = Options::new()
|
||||
.with_trt(0)
|
||||
.with_model("yolo/v8-m-dyn.onnx")?
|
||||
.with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
|
||||
.with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
|
||||
.with_i00((1, 2, 4).into())
|
||||
.with_i02((0, 640, 640).into())
|
||||
.with_i03((0, 640, 640).into())
|
||||
.with_confs(&[0.2]);
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// Build DataLoader to load image(s), video, stream
|
||||
let dl = DataLoader::new(
|
||||
// "./assets/bus.jpg", // local image
|
||||
// "images/bus.jpg", // remote image
|
||||
// "../images-folder", // local images (from folder)
|
||||
// "../demo.mp4", // local video
|
||||
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video
|
||||
"rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream
|
||||
)?
|
||||
.with_batch(2) // iterate with batch_size = 2
|
||||
.build()?;
|
||||
|
||||
// Build annotator
|
||||
let annotator = Annotator::new()
|
||||
.with_bboxes_thickness(4)
|
||||
.with_saveout("YOLO-DataLoader");
|
||||
|
||||
// Run and annotate results
|
||||
for (xs, _) in dl {
|
||||
let ys = model.forward(&xs, false)?;
|
||||
annotator.annotate(&xs, &ys);
|
||||
|
||||
// Retrieve inference results
|
||||
for y in ys {
|
||||
// bboxes
|
||||
if let Some(bboxes) = y.bboxes() {
|
||||
for bbox in bboxes {
|
||||
println!(
|
||||
"Bbox: {}, {}, {}, {}, {}, {}",
|
||||
bbox.xmin(),
|
||||
bbox.ymin(),
|
||||
bbox.xmax(),
|
||||
bbox.ymax(),
|
||||
bbox.confidence(),
|
||||
bbox.id(),
|
||||
);
|
||||
}
|
||||
```rust
|
||||
use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion};
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Build model with Options
|
||||
let options = Options::new()
|
||||
.with_trt(0)
|
||||
.with_model("yolo/v8-m-dyn.onnx")?
|
||||
.with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
|
||||
.with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
|
||||
.with_ixx(0, 0, (1, 2, 4).into())
|
||||
.with_ixx(0, 2, (0, 640, 640).into())
|
||||
.with_ixx(0, 3, (0, 640, 640).into())
|
||||
.with_confs(&[0.2]);
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// Build DataLoader to load image(s), video, stream
|
||||
let dl = DataLoader::new(
|
||||
// "./assets/bus.jpg", // local image
|
||||
// "images/bus.jpg", // remote image
|
||||
// "../images-folder", // local images (from folder)
|
||||
// "../demo.mp4", // local video
|
||||
// "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video
|
||||
"rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream
|
||||
)?
|
||||
.with_batch(2) // iterate with batch_size = 2
|
||||
.build()?;
|
||||
|
||||
// Build annotator
|
||||
let annotator = Annotator::new()
|
||||
.with_bboxes_thickness(4)
|
||||
.with_saveout("YOLO-DataLoader");
|
||||
|
||||
// Build viewer
|
||||
let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true);
|
||||
|
||||
// Run and annotate results
|
||||
for (xs, _) in dl {
|
||||
let ys = model.forward(&xs, false)?;
|
||||
// annotator.annotate(&xs, &ys);
|
||||
let images_plotted = annotator.plot(&xs, &ys, false)?;
|
||||
|
||||
// show image
|
||||
viewer.imshow(&images_plotted)?;
|
||||
|
||||
// check out window and key event
|
||||
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
|
||||
break;
|
||||
}
|
||||
|
||||
// write video
|
||||
viewer.write_batch(&images_plotted)?;
|
||||
|
||||
// Retrieve inference results
|
||||
for y in ys {
|
||||
// bboxes
|
||||
if let Some(bboxes) = y.bboxes() {
|
||||
for bbox in bboxes {
|
||||
println!(
|
||||
"Bbox: {}, {}, {}, {}, {}, {}",
|
||||
bbox.xmin(),
|
||||
bbox.ymin(),
|
||||
bbox.xmax(),
|
||||
bbox.ymax(),
|
||||
bbox.confidence(),
|
||||
bbox.id(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
// finish video write
|
||||
viewer.finish_write()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
</br>
|
||||
|
||||
## 📌 License
|
||||
This project is licensed under [LICENSE](LICENSE).
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
use usls::{models::DepthPro, Annotator, DataLoader, Options};
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// options
|
||||
let options = Options::default()
|
||||
.with_model("depth-pro/q4f16.onnx")? // bnb4, f16
|
||||
.with_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
|
||||
.with_ixx(0, 1, 3.into()) // channel
|
||||
.with_ixx(0, 2, 1536.into()) // height
|
||||
.with_ixx(0, 3, 1536.into()); // width
|
||||
let mut model = DepthPro::new(options)?;
|
||||
|
||||
// load
|
||||
let x = [DataLoader::try_read("images/street.jpg")?];
|
||||
|
||||
// run
|
||||
let y = model.run(&x)?;
|
||||
|
||||
// annotate
|
||||
let annotator = Annotator::default()
|
||||
.with_colormap("Turbo")
|
||||
.with_saveout("Depth-Pro");
|
||||
annotator.annotate(&x, &y);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -24,45 +24,41 @@
|
|||
## Quick Start
|
||||
```Shell
|
||||
|
||||
# customized
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8
|
||||
|
||||
# Classify
|
||||
cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5
|
||||
cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8
|
||||
cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
|
||||
cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
|
||||
cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11
|
||||
|
||||
# Detect
|
||||
cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5
|
||||
cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6
|
||||
cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7
|
||||
cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8
|
||||
cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9
|
||||
cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10
|
||||
cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
|
||||
cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
|
||||
cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
|
||||
cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
|
||||
cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
|
||||
cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
|
||||
cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
|
||||
cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
|
||||
cargo run -r --example yolo -- --task detect --ver v8 --model yolo/v8-s-world-v2-shoes.onnx # YOLOv8-world
|
||||
|
||||
# Pose
|
||||
cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose
|
||||
cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
|
||||
cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose
|
||||
|
||||
# Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
|
||||
cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv8-Segment
|
||||
cargo run -r --example yolo -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM
|
||||
|
||||
# Obb
|
||||
cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
|
||||
cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
|
||||
cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
|
||||
```
|
||||
|
||||
<details close>
|
||||
<summary>other options</summary>
|
||||
|
||||
`--source` to specify the input images
|
||||
`--model` to specify the ONNX model
|
||||
`--width --height` to specify the input resolution
|
||||
`--nc` to specify the number of model's classes
|
||||
`--plot` to annotate with inference results
|
||||
`--profile` to profile
|
||||
`--cuda --trt --coreml --device_id` to select device
|
||||
`--half` to use float16 when using TensorRT EP
|
||||
|
||||
</details>
|
||||
**`cargo run -r --example yolo -- --help` for more options**
|
||||
|
||||
|
||||
## YOLOs configs with `Options`
|
||||
|
|
@ -96,6 +92,8 @@ let options = Options::default()
|
|||
..Default::default()
|
||||
}
|
||||
)
|
||||
// .with_nc(80)
|
||||
// .with_names(&COCO_CLASS_NAMES_80)
|
||||
.with_model("xxxx.onnx")?;
|
||||
```
|
||||
</details>
|
||||
|
|
@ -140,7 +138,7 @@ let options = Options::default()
|
|||
</details>
|
||||
|
||||
<details close>
|
||||
<summary>YOLOv8</summary>
|
||||
<summary>YOLOv8, YOLOv11</summary>
|
||||
|
||||
```Shell
|
||||
pip install -U ultralytics
|
||||
|
|
|
|||
|
|
@ -2,188 +2,169 @@ use anyhow::Result;
|
|||
use clap::Parser;
|
||||
|
||||
use usls::{
|
||||
models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
|
||||
COCO_SKELETONS_16,
|
||||
models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
|
||||
YOLOVersion, COCO_SKELETONS_16,
|
||||
};
|
||||
|
||||
#[derive(Parser, Clone)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
pub struct Args {
|
||||
/// Path to the model
|
||||
#[arg(long)]
|
||||
pub model: Option<String>,
|
||||
|
||||
/// Input source path
|
||||
#[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
|
||||
pub source: String,
|
||||
|
||||
/// YOLO Task
|
||||
#[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
|
||||
pub task: YOLOTask,
|
||||
|
||||
/// YOLO Version
|
||||
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
|
||||
pub ver: YOLOVersion,
|
||||
|
||||
/// YOLO Scale
|
||||
#[arg(long, value_enum, default_value_t = YOLOScale::N)]
|
||||
pub scale: YOLOScale,
|
||||
|
||||
/// Batch size
|
||||
#[arg(long, default_value_t = 1)]
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Minimum input width
|
||||
#[arg(long, default_value_t = 224)]
|
||||
pub width_min: isize,
|
||||
|
||||
/// Input width
|
||||
#[arg(long, default_value_t = 640)]
|
||||
pub width: isize,
|
||||
|
||||
#[arg(long, default_value_t = 800)]
|
||||
/// Maximum input width
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
pub width_max: isize,
|
||||
|
||||
/// Minimum input height
|
||||
#[arg(long, default_value_t = 224)]
|
||||
pub height_min: isize,
|
||||
|
||||
/// Input height
|
||||
#[arg(long, default_value_t = 640)]
|
||||
pub height: isize,
|
||||
|
||||
#[arg(long, default_value_t = 800)]
|
||||
/// Maximum input height
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
pub height_max: isize,
|
||||
|
||||
/// Number of classes
|
||||
#[arg(long, default_value_t = 80)]
|
||||
pub nc: usize,
|
||||
|
||||
/// Class confidence
|
||||
#[arg(long)]
|
||||
pub confs: Vec<f32>,
|
||||
|
||||
/// Enable TensorRT support
|
||||
#[arg(long)]
|
||||
pub trt: bool,
|
||||
|
||||
/// Enable CUDA support
|
||||
#[arg(long)]
|
||||
pub cuda: bool,
|
||||
|
||||
#[arg(long)]
|
||||
pub half: bool,
|
||||
|
||||
/// Enable CoreML support
|
||||
#[arg(long)]
|
||||
pub coreml: bool,
|
||||
|
||||
/// Use TensorRT half precision
|
||||
#[arg(long)]
|
||||
pub half: bool,
|
||||
|
||||
/// Device ID to use
|
||||
#[arg(long, default_value_t = 0)]
|
||||
pub device_id: usize,
|
||||
|
||||
/// Enable performance profiling
|
||||
#[arg(long)]
|
||||
pub profile: bool,
|
||||
|
||||
#[arg(long)]
|
||||
pub no_plot: bool,
|
||||
|
||||
/// Disable contour drawing
|
||||
#[arg(long)]
|
||||
pub no_contours: bool,
|
||||
|
||||
/// Show result
|
||||
#[arg(long)]
|
||||
pub view: bool,
|
||||
|
||||
/// Do not save output
|
||||
#[arg(long)]
|
||||
pub nosave: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// build options
|
||||
let options = Options::default();
|
||||
|
||||
// version & task
|
||||
let (options, saveout) = match args.ver {
|
||||
YOLOVersion::V5 => match args.task {
|
||||
YOLOTask::Classify => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-cls-dyn.onnx".to_string()))?,
|
||||
"YOLOv5-Classify",
|
||||
),
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-dyn.onnx".to_string()))?,
|
||||
"YOLOv5-Detect",
|
||||
),
|
||||
YOLOTask::Segment => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v5-n-seg-dyn.onnx".to_string()))?,
|
||||
"YOLOv5-Segment",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V6 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options
|
||||
.with_model(&args.model.unwrap_or("yolo/v6-n-dyn.onnx".to_string()))?
|
||||
.with_nc(args.nc),
|
||||
"YOLOv6-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V7 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options
|
||||
.with_model(&args.model.unwrap_or("yolo/v7-tiny-dyn.onnx".to_string()))?
|
||||
.with_nc(args.nc),
|
||||
"YOLOv7-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V8 => match args.task {
|
||||
YOLOTask::Classify => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-cls-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Classify",
|
||||
),
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Detect",
|
||||
),
|
||||
YOLOTask::Segment => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-seg-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Segment",
|
||||
),
|
||||
YOLOTask::Pose => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-pose-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Pose",
|
||||
),
|
||||
YOLOTask::Obb => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v8-m-obb-dyn.onnx".to_string()))?,
|
||||
"YOLOv8-Obb",
|
||||
),
|
||||
},
|
||||
YOLOVersion::V9 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v9-c-dyn-f16.onnx".to_string()))?,
|
||||
"YOLOv9-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::V10 => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/v10-n.onnx".to_string()))?,
|
||||
"YOLOv10-Detect",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
YOLOVersion::RTDETR => match args.task {
|
||||
YOLOTask::Detect => (
|
||||
options.with_model(&args.model.unwrap_or("yolo/rtdetr-l-f16.onnx".to_string()))?,
|
||||
"RTDETR",
|
||||
),
|
||||
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
|
||||
},
|
||||
// model path
|
||||
let path = match &args.model {
|
||||
None => format!(
|
||||
"yolo/{}-{}-{}.onnx",
|
||||
args.ver.name(),
|
||||
args.scale.name(),
|
||||
args.task.name()
|
||||
),
|
||||
Some(x) => x.to_string(),
|
||||
};
|
||||
|
||||
let options = options
|
||||
.with_yolo_version(args.ver)
|
||||
.with_yolo_task(args.task);
|
||||
// saveout
|
||||
let saveout = match &args.model {
|
||||
None => format!(
|
||||
"{}-{}-{}",
|
||||
args.ver.name(),
|
||||
args.scale.name(),
|
||||
args.task.name()
|
||||
),
|
||||
Some(x) => {
|
||||
let p = std::path::PathBuf::from(&x);
|
||||
p.file_stem().unwrap().to_str().unwrap().to_string()
|
||||
}
|
||||
};
|
||||
|
||||
// device
|
||||
let options = if args.cuda {
|
||||
options.with_cuda(args.device_id)
|
||||
let device = if args.cuda {
|
||||
Device::Cuda(args.device_id)
|
||||
} else if args.trt {
|
||||
let options = options.with_trt(args.device_id);
|
||||
if args.half {
|
||||
options.with_trt_fp16(true)
|
||||
} else {
|
||||
options
|
||||
}
|
||||
Device::Trt(args.device_id)
|
||||
} else if args.coreml {
|
||||
options.with_coreml(args.device_id)
|
||||
Device::CoreML(args.device_id)
|
||||
} else {
|
||||
options.with_cpu()
|
||||
Device::Cpu(args.device_id)
|
||||
};
|
||||
let options = options
|
||||
|
||||
// build options
|
||||
let options = Options::new()
|
||||
.with_model(&path)?
|
||||
.with_yolo_version(args.ver)
|
||||
.with_yolo_task(args.task)
|
||||
.with_device(device)
|
||||
.with_trt_fp16(args.half)
|
||||
.with_ixx(0, 0, (1, args.batch_size as _, 4).into())
|
||||
.with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
|
||||
.with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
|
||||
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
|
||||
.with_confs(if args.confs.is_empty() {
|
||||
&[0.2, 0.15]
|
||||
} else {
|
||||
&args.confs
|
||||
})
|
||||
.with_nc(args.nc)
|
||||
// .with_names(&COCO_CLASS_NAMES_80)
|
||||
.with_names2(&COCO_KEYPOINTS_17)
|
||||
// .with_names2(&COCO_KEYPOINTS_17)
|
||||
.with_find_contours(!args.no_contours) // find contours or not
|
||||
.exclude_classes(&[0])
|
||||
// .retain_classes(&[0, 5])
|
||||
.with_profile(args.profile);
|
||||
|
||||
// build model
|
||||
let mut model = YOLO::new(options)?;
|
||||
|
||||
// build dataloader
|
||||
|
|
@ -194,16 +175,54 @@ fn main() -> Result<()> {
|
|||
// build annotator
|
||||
let annotator = Annotator::default()
|
||||
.with_skeletons(&COCO_SKELETONS_16)
|
||||
.with_bboxes_thickness(4)
|
||||
.without_masks(true) // No masks plotting when doing segment task.
|
||||
.with_saveout(saveout);
|
||||
.with_bboxes_thickness(3)
|
||||
.with_keypoints_name(false) // Enable keypoints names
|
||||
.with_saveout_subs(&["YOLO"])
|
||||
.with_saveout(&saveout);
|
||||
|
||||
// build viewer
|
||||
let mut viewer = if args.view {
|
||||
Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// run & annotate
|
||||
for (xs, _paths) in dl {
|
||||
// let ys = model.run(&xs)?; // way one
|
||||
let ys = model.forward(&xs, args.profile)?; // way two
|
||||
if !args.no_plot {
|
||||
annotator.annotate(&xs, &ys);
|
||||
let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
|
||||
|
||||
// show image
|
||||
match &mut viewer {
|
||||
Some(viewer) => viewer.imshow(&images_plotted)?,
|
||||
None => continue,
|
||||
}
|
||||
|
||||
// check out window and key event
|
||||
match &mut viewer {
|
||||
Some(viewer) => {
|
||||
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
|
||||
// write video
|
||||
if !args.nosave {
|
||||
match &mut viewer {
|
||||
Some(viewer) => viewer.write_batch(&images_plotted)?,
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// finish video write
|
||||
if !args.nosave {
|
||||
if let Some(viewer) = &mut viewer {
|
||||
viewer.finish_write()?;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -48,6 +48,8 @@ pub struct Options {
|
|||
pub sam_kind: Option<SamKind>,
|
||||
pub use_low_res_mask: Option<bool>,
|
||||
pub sapiens_task: Option<SapiensTask>,
|
||||
pub classes_excluded: Vec<isize>,
|
||||
pub classes_retained: Vec<isize>,
|
||||
}
|
||||
|
||||
impl Default for Options {
|
||||
|
|
@ -88,6 +90,8 @@ impl Default for Options {
|
|||
use_low_res_mask: None,
|
||||
sapiens_task: None,
|
||||
task: Task::Untitled,
|
||||
classes_excluded: vec![],
|
||||
classes_retained: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -276,4 +280,16 @@ impl Options {
|
|||
self.iiixs.push(Iiix::from((i, ii, x)));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
|
||||
self.classes_retained.clear();
|
||||
self.classes_excluded.extend_from_slice(xs);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn retain_classes(mut self, xs: &[isize]) -> Self {
|
||||
self.classes_excluded.clear();
|
||||
self.classes_retained.extend_from_slice(xs);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ use anyhow::Result;
|
|||
use half::f16;
|
||||
use ndarray::{Array, IxDyn};
|
||||
use ort::{
|
||||
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
|
||||
execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
|
||||
session::{builder::SessionBuilder, Session},
|
||||
tensor::TensorElementType,
|
||||
};
|
||||
use prost::Message;
|
||||
use std::collections::HashSet;
|
||||
|
|
@ -88,14 +90,14 @@ impl OrtEngine {
|
|||
|
||||
// build
|
||||
ort::init().commit()?;
|
||||
let builder = Session::builder()?;
|
||||
let mut builder = Session::builder()?;
|
||||
let mut device = config.device.to_owned();
|
||||
match device {
|
||||
Device::Trt(device_id) => {
|
||||
Self::build_trt(
|
||||
&inputs_attrs.names,
|
||||
&inputs_minoptmax,
|
||||
&builder,
|
||||
&mut builder,
|
||||
device_id,
|
||||
config.trt_int8_enable,
|
||||
config.trt_fp16_enable,
|
||||
|
|
@ -103,23 +105,23 @@ impl OrtEngine {
|
|||
)?;
|
||||
}
|
||||
Device::Cuda(device_id) => {
|
||||
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
|
||||
Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
|
||||
tracing::warn!("{err}, Using cpu");
|
||||
device = Device::Cpu(0);
|
||||
})
|
||||
}
|
||||
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
|
||||
Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
|
||||
tracing::warn!("{err}, Using cpu");
|
||||
device = Device::Cpu(0);
|
||||
}),
|
||||
Device::Cpu(_) => {
|
||||
Self::build_cpu(&builder)?;
|
||||
Self::build_cpu(&mut builder)?;
|
||||
}
|
||||
_ => todo!(),
|
||||
}
|
||||
|
||||
let session = builder
|
||||
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
||||
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
|
||||
.commit_from_file(&config.onnx_path)?;
|
||||
|
||||
// summary
|
||||
|
|
@ -149,7 +151,7 @@ impl OrtEngine {
|
|||
fn build_trt(
|
||||
names: &[String],
|
||||
inputs_minoptmax: &[Vec<MinOptMax>],
|
||||
builder: &SessionBuilder,
|
||||
builder: &mut SessionBuilder,
|
||||
device_id: usize,
|
||||
int8_enable: bool,
|
||||
fp16_enable: bool,
|
||||
|
|
@ -205,8 +207,9 @@ impl OrtEngine {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
|
||||
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
|
||||
fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
|
||||
let ep = ort::execution_providers::CUDAExecutionProvider::default()
|
||||
.with_device_id(device_id as i32);
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
|
|
@ -214,8 +217,8 @@ impl OrtEngine {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_coreml(builder: &SessionBuilder) -> Result<()> {
|
||||
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||
fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
|
||||
let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
|
|
@ -223,8 +226,8 @@ impl OrtEngine {
|
|||
}
|
||||
}
|
||||
|
||||
fn build_cpu(builder: &SessionBuilder) -> Result<()> {
|
||||
let ep = ort::CPUExecutionProvider::default();
|
||||
fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
|
||||
let ep = ort::execution_providers::CPUExecutionProvider::default();
|
||||
if ep.is_available()? && ep.register(builder).is_ok() {
|
||||
Ok(())
|
||||
} else {
|
||||
|
|
@ -292,28 +295,28 @@ impl OrtEngine {
|
|||
let t_pre = std::time::Instant::now();
|
||||
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
|
||||
let x_ = match &idtype {
|
||||
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
|
||||
TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
|
||||
TensorElementType::Float16 => {
|
||||
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int32 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int64 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Uint8 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Int8 => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
|
||||
}
|
||||
TensorElementType::Bool => {
|
||||
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
||||
ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
|
||||
}
|
||||
_ => todo!(),
|
||||
};
|
||||
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
|
||||
xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
|
||||
}
|
||||
let t_pre = t_pre.elapsed();
|
||||
self.ts.add_or_push(0, t_pre);
|
||||
|
|
@ -451,45 +454,45 @@ impl OrtEngine {
|
|||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
|
||||
fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
|
||||
match x {
|
||||
ort::TensorElementType::Float64
|
||||
| ort::TensorElementType::Uint64
|
||||
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||
ort::TensorElementType::Float32
|
||||
| ort::TensorElementType::Uint32
|
||||
| ort::TensorElementType::Int32
|
||||
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||
ort::TensorElementType::Float16
|
||||
| ort::TensorElementType::Bfloat16
|
||||
| ort::TensorElementType::Int16
|
||||
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||
ort::TensorElementType::Uint8
|
||||
| ort::TensorElementType::Int8
|
||||
| ort::TensorElementType::Bool => 1, // u8, i8, bool
|
||||
ort::tensor::TensorElementType::Float64
|
||||
| ort::tensor::TensorElementType::Uint64
|
||||
| ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
|
||||
ort::tensor::TensorElementType::Float32
|
||||
| ort::tensor::TensorElementType::Uint32
|
||||
| ort::tensor::TensorElementType::Int32
|
||||
| ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
|
||||
ort::tensor::TensorElementType::Float16
|
||||
| ort::tensor::TensorElementType::Bfloat16
|
||||
| ort::tensor::TensorElementType::Int16
|
||||
| ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
|
||||
ort::tensor::TensorElementType::Uint8
|
||||
| ort::tensor::TensorElementType::Int8
|
||||
| ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
|
||||
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
|
||||
match value {
|
||||
0 => None,
|
||||
1 => Some(ort::TensorElementType::Float32),
|
||||
2 => Some(ort::TensorElementType::Uint8),
|
||||
3 => Some(ort::TensorElementType::Int8),
|
||||
4 => Some(ort::TensorElementType::Uint16),
|
||||
5 => Some(ort::TensorElementType::Int16),
|
||||
6 => Some(ort::TensorElementType::Int32),
|
||||
7 => Some(ort::TensorElementType::Int64),
|
||||
8 => Some(ort::TensorElementType::String),
|
||||
9 => Some(ort::TensorElementType::Bool),
|
||||
10 => Some(ort::TensorElementType::Float16),
|
||||
11 => Some(ort::TensorElementType::Float64),
|
||||
12 => Some(ort::TensorElementType::Uint32),
|
||||
13 => Some(ort::TensorElementType::Uint64),
|
||||
1 => Some(ort::tensor::TensorElementType::Float32),
|
||||
2 => Some(ort::tensor::TensorElementType::Uint8),
|
||||
3 => Some(ort::tensor::TensorElementType::Int8),
|
||||
4 => Some(ort::tensor::TensorElementType::Uint16),
|
||||
5 => Some(ort::tensor::TensorElementType::Int16),
|
||||
6 => Some(ort::tensor::TensorElementType::Int32),
|
||||
7 => Some(ort::tensor::TensorElementType::Int64),
|
||||
8 => Some(ort::tensor::TensorElementType::String),
|
||||
9 => Some(ort::tensor::TensorElementType::Bool),
|
||||
10 => Some(ort::tensor::TensorElementType::Float16),
|
||||
11 => Some(ort::tensor::TensorElementType::Float64),
|
||||
12 => Some(ort::tensor::TensorElementType::Uint32),
|
||||
13 => Some(ort::tensor::TensorElementType::Uint64),
|
||||
14 => None, // COMPLEX64
|
||||
15 => None, // COMPLEX128
|
||||
16 => Some(ort::TensorElementType::Bfloat16),
|
||||
16 => Some(ort::tensor::TensorElementType::Bfloat16),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
|
@ -499,7 +502,7 @@ impl OrtEngine {
|
|||
value_info: &[onnx::ValueInfoProto],
|
||||
) -> Result<OrtTensorAttr> {
|
||||
let mut dimss: Vec<Vec<usize>> = Vec::new();
|
||||
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
|
||||
let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
|
||||
let mut names: Vec<String> = Vec::new();
|
||||
for v in value_info.iter() {
|
||||
if initializer_names.contains(v.name.as_str()) {
|
||||
|
|
@ -569,7 +572,7 @@ impl OrtEngine {
|
|||
&self.outputs_attrs.names
|
||||
}
|
||||
|
||||
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
|
||||
pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||
&self.outputs_attrs.dtypes
|
||||
}
|
||||
|
||||
|
|
@ -585,7 +588,7 @@ impl OrtEngine {
|
|||
&self.inputs_attrs.names
|
||||
}
|
||||
|
||||
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
|
||||
pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
|
||||
&self.inputs_attrs.dtypes
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
|
||||
use anyhow::Result;
|
||||
use image::DynamicImage;
|
||||
use ndarray::Axis;
|
||||
|
||||
/// Monocular depth estimation model (reads `predicted_depth` / `focallength_px`
/// outputs), backed by an ONNXRuntime session.
#[derive(Debug)]
pub struct DepthPro {
    // ONNXRuntime inference engine holding the loaded model.
    engine: OrtEngine,
    // Model input height (min/opt/max, for dynamic axes).
    height: MinOptMax,
    // Model input width (min/opt/max, for dynamic axes).
    width: MinOptMax,
    // Batch size (min/opt/max, for dynamic axes).
    batch: MinOptMax,
}
|
||||
|
||||
impl DepthPro {
|
||||
pub fn new(options: Options) -> Result<Self> {
|
||||
let mut engine = OrtEngine::new(&options)?;
|
||||
let (batch, height, width) = (
|
||||
engine.batch().clone(),
|
||||
engine.height().clone(),
|
||||
engine.width().clone(),
|
||||
);
|
||||
engine.dry_run()?;
|
||||
|
||||
Ok(Self {
|
||||
engine,
|
||||
height,
|
||||
width,
|
||||
batch,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let xs_ = X::apply(&[
|
||||
Ops::Resize(
|
||||
xs,
|
||||
self.height.opt() as u32,
|
||||
self.width.opt() as u32,
|
||||
"Bilinear",
|
||||
),
|
||||
Ops::Normalize(0., 255.),
|
||||
Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3),
|
||||
Ops::Nhwc2nchw,
|
||||
])?;
|
||||
let ys = self.engine.run(Xs::from(xs_))?;
|
||||
|
||||
self.postprocess(ys, xs)
|
||||
}
|
||||
|
||||
pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
|
||||
let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
|
||||
let predicted_depth = predicted_depth.mapv(|x| 1. / x);
|
||||
|
||||
let mut ys: Vec<Y> = Vec::new();
|
||||
for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
|
||||
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
|
||||
let v = luma.into_owned().into_raw_vec_and_offset().0;
|
||||
let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
|
||||
let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
|
||||
let v = v
|
||||
.iter()
|
||||
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let luma = Ops::resize_luma8_u8(
|
||||
&v,
|
||||
self.width.opt() as _,
|
||||
self.height.opt() as _,
|
||||
w1 as _,
|
||||
h1 as _,
|
||||
false,
|
||||
"Bilinear",
|
||||
)?;
|
||||
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
|
||||
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
|
||||
None => continue,
|
||||
Some(x) => x,
|
||||
};
|
||||
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
|
||||
}
|
||||
Ok(ys)
|
||||
}
|
||||
|
||||
pub fn batch(&self) -> isize {
|
||||
self.batch.opt() as _
|
||||
}
|
||||
}
|
||||
|
|
@ -4,6 +4,7 @@ mod blip;
|
|||
mod clip;
|
||||
mod db;
|
||||
mod depth_anything;
|
||||
mod depth_pro;
|
||||
mod dinov2;
|
||||
mod florence2;
|
||||
mod grounding_dino;
|
||||
|
|
@ -20,6 +21,7 @@ pub use blip::Blip;
|
|||
pub use clip::Clip;
|
||||
pub use db::DB;
|
||||
pub use depth_anything::DepthAnything;
|
||||
pub use depth_pro::DepthPro;
|
||||
pub use dinov2::Dinov2;
|
||||
pub use florence2::Florence2;
|
||||
pub use grounding_dino::GroundingDINO;
|
||||
|
|
|
|||
|
|
@ -20,12 +20,14 @@ pub struct YOLO {
|
|||
confs: DynConf,
|
||||
kconfs: DynConf,
|
||||
iou: f32,
|
||||
names: Option<Vec<String>>,
|
||||
names_kpt: Option<Vec<String>>,
|
||||
names: Vec<String>,
|
||||
names_kpt: Vec<String>,
|
||||
task: YOLOTask,
|
||||
layout: YOLOPreds,
|
||||
find_contours: bool,
|
||||
version: Option<YOLOVersion>,
|
||||
classes_excluded: Vec<isize>,
|
||||
classes_retained: Vec<isize>,
|
||||
}
|
||||
|
||||
impl Vision for YOLO {
|
||||
|
|
@ -64,27 +66,26 @@ impl Vision for YOLO {
|
|||
Some(task) => match task {
|
||||
YOLOTask::Classify => match ver {
|
||||
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
|
||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()),
|
||||
x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
YOLOTask::Detect => match ver {
|
||||
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()),
|
||||
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
|
||||
YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
|
||||
YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
|
||||
YOLOVersion::RTDETR => (Some(ver),YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
|
||||
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()),
|
||||
YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()),
|
||||
YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
|
||||
YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
|
||||
}
|
||||
YOLOTask::Pose => match ver {
|
||||
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()),
|
||||
x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
YOLOTask::Segment => match ver {
|
||||
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
|
||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
|
||||
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
YOLOTask::Obb => match ver {
|
||||
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
|
||||
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
|
||||
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
|
||||
}
|
||||
}
|
||||
|
|
@ -97,49 +98,75 @@ impl Vision for YOLO {
|
|||
|
||||
let task = task.unwrap_or(layout.task());
|
||||
|
||||
// The number of classes & Class names
|
||||
let mut names = options.names.or(Self::fetch_names(&engine));
|
||||
let nc = match options.nc {
|
||||
Some(nc) => {
|
||||
match &names {
|
||||
None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()),
|
||||
Some(names) => {
|
||||
assert_eq!(
|
||||
nc,
|
||||
// Class names: user-defined.or(parsed)
|
||||
let names_parsed = Self::fetch_names(&engine);
|
||||
let names = match names_parsed {
|
||||
Some(names_parsed) => match options.names {
|
||||
Some(names) => {
|
||||
if names.len() == names_parsed.len() {
|
||||
Some(names)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"The lengths of parsed class names: {} and user-defined class names: {} do not match.",
|
||||
names_parsed.len(),
|
||||
names.len(),
|
||||
"The length of `nc` and `class names` is not equal."
|
||||
);
|
||||
}
|
||||
}
|
||||
nc
|
||||
}
|
||||
None => match &names {
|
||||
Some(names) => names.len(),
|
||||
None => panic!(
|
||||
"Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`"
|
||||
None => Some(names_parsed),
|
||||
},
|
||||
None => options.names,
|
||||
};
|
||||
|
||||
// nc: names.len().or(options.nc)
|
||||
let nc = match &names {
|
||||
Some(names) => names.len(),
|
||||
None => match options.nc {
|
||||
Some(nc) => nc,
|
||||
None => anyhow::bail!(
|
||||
"Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
|
||||
),
|
||||
}
|
||||
};
|
||||
|
||||
// Class names
|
||||
let names = match names {
|
||||
None => Self::n2s(nc),
|
||||
Some(names) => names,
|
||||
};
|
||||
|
||||
// Keypoint names & nk
|
||||
let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
|
||||
None => (0, vec![]),
|
||||
Some(nk) => match options.names2 {
|
||||
Some(names) => {
|
||||
if names.len() == nk {
|
||||
(nk, names)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"The lengths of user-defined keypoint names: {} and nk: {} do not match.",
|
||||
names.len(),
|
||||
nk,
|
||||
);
|
||||
}
|
||||
}
|
||||
None => (nk, Self::n2s(nk)),
|
||||
},
|
||||
};
|
||||
|
||||
// Keypoints names
|
||||
let names_kpt = options.names2;
|
||||
|
||||
// The number of keypoints
|
||||
let nk = engine
|
||||
.try_fetch("kpt_shape")
|
||||
.map(|kpt_string| {
|
||||
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
||||
let caps = re.captures(&kpt_string).unwrap();
|
||||
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
|
||||
})
|
||||
.unwrap_or(0_usize);
|
||||
// Confs & Iou
|
||||
let confs = DynConf::new(&options.confs, nc);
|
||||
let kconfs = DynConf::new(&options.kconfs, nk);
|
||||
let iou = options.iou.unwrap_or(0.45);
|
||||
|
||||
// Classes excluded and retained
|
||||
let classes_excluded = options.classes_excluded;
|
||||
let classes_retained = options.classes_retained;
|
||||
|
||||
// Summary
|
||||
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
|
||||
|
||||
// dry run
|
||||
engine.dry_run()?;
|
||||
|
||||
Ok(Self {
|
||||
|
|
@ -158,6 +185,8 @@ impl Vision for YOLO {
|
|||
layout,
|
||||
version,
|
||||
find_contours: options.find_contours,
|
||||
classes_excluded,
|
||||
classes_retained,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
@ -219,10 +248,8 @@ impl Vision for YOLO {
|
|||
slice_clss.into_owned()
|
||||
};
|
||||
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
|
||||
if let Some(names) = &self.names {
|
||||
probs =
|
||||
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
|
||||
}
|
||||
probs = probs
|
||||
.with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
|
||||
|
||||
return Some(y.with_probs(&probs));
|
||||
}
|
||||
|
|
@ -257,7 +284,19 @@ impl Vision for YOLO {
|
|||
}
|
||||
};
|
||||
|
||||
// filtering
|
||||
// filtering by class id
|
||||
if !self.classes_excluded.is_empty()
|
||||
&& self.classes_excluded.contains(&(class_id as isize))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
if !self.classes_retained.is_empty()
|
||||
&& !self.classes_retained.contains(&(class_id as isize))
|
||||
{
|
||||
return None;
|
||||
}
|
||||
|
||||
// filtering by conf
|
||||
if confidence < self.confs[class_id] {
|
||||
return None;
|
||||
}
|
||||
|
|
@ -325,9 +364,7 @@ impl Vision for YOLO {
|
|||
)
|
||||
.with_confidence(confidence)
|
||||
.with_id(class_id as isize);
|
||||
if let Some(names) = &self.names {
|
||||
mbr = mbr.with_name(&names[class_id]);
|
||||
}
|
||||
mbr = mbr.with_name(&self.names[class_id]);
|
||||
|
||||
(None, Some(mbr))
|
||||
}
|
||||
|
|
@ -337,9 +374,7 @@ impl Vision for YOLO {
|
|||
.with_confidence(confidence)
|
||||
.with_id(class_id as isize)
|
||||
.with_id_born(i as isize);
|
||||
if let Some(names) = &self.names {
|
||||
bbox = bbox.with_name(&names[class_id]);
|
||||
}
|
||||
bbox = bbox.with_name(&self.names[class_id]);
|
||||
|
||||
(Some(bbox), None)
|
||||
}
|
||||
|
|
@ -394,9 +429,7 @@ impl Vision for YOLO {
|
|||
ky.max(0.0f32).min(image_height),
|
||||
);
|
||||
|
||||
if let Some(names) = &self.names_kpt {
|
||||
kpt = kpt.with_name(&names[i]);
|
||||
}
|
||||
kpt = kpt.with_name(&self.names_kpt[i]);
|
||||
kpt
|
||||
}
|
||||
})
|
||||
|
|
@ -505,16 +538,16 @@ impl Vision for YOLO {
|
|||
}
|
||||
|
||||
impl YOLO {
|
||||
pub fn batch(&self) -> isize {
|
||||
self.batch.opt() as _
|
||||
pub fn batch(&self) -> usize {
|
||||
self.batch.opt()
|
||||
}
|
||||
|
||||
pub fn width(&self) -> isize {
|
||||
self.width.opt() as _
|
||||
pub fn width(&self) -> usize {
|
||||
self.width.opt()
|
||||
}
|
||||
|
||||
pub fn height(&self) -> isize {
|
||||
self.height.opt() as _
|
||||
pub fn height(&self) -> usize {
|
||||
self.height.opt()
|
||||
}
|
||||
|
||||
pub fn version(&self) -> Option<&YOLOVersion> {
|
||||
|
|
@ -541,4 +574,16 @@ impl YOLO {
|
|||
names_
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
|
||||
engine.try_fetch("kpt_shape").map(|s| {
|
||||
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
|
||||
let caps = re.captures(&s).unwrap();
|
||||
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
fn n2s(n: usize) -> Vec<String> {
|
||||
(0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>()
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,28 @@ pub enum YOLOTask {
|
|||
Obb,
|
||||
}
|
||||
|
||||
impl YOLOTask {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::Classify => "cls".to_string(),
|
||||
Self::Detect => "det".to_string(),
|
||||
Self::Pose => "pose".to_string(),
|
||||
Self::Segment => "seg".to_string(),
|
||||
Self::Obb => "obb".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name_detailed(&self) -> String {
|
||||
match self {
|
||||
Self::Classify => "image-classification".to_string(),
|
||||
Self::Detect => "object-detection".to_string(),
|
||||
Self::Pose => "pose-estimation".to_string(),
|
||||
Self::Segment => "instance-segment".to_string(),
|
||||
Self::Obb => "oriented-object-detection".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
|
||||
pub enum YOLOVersion {
|
||||
V5,
|
||||
|
|
@ -17,9 +39,54 @@ pub enum YOLOVersion {
|
|||
V8,
|
||||
V9,
|
||||
V10,
|
||||
V11,
|
||||
RTDETR,
|
||||
}
|
||||
|
||||
impl YOLOVersion {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::V5 => "v5".to_string(),
|
||||
Self::V6 => "v6".to_string(),
|
||||
Self::V7 => "v7".to_string(),
|
||||
Self::V8 => "v8".to_string(),
|
||||
Self::V9 => "v9".to_string(),
|
||||
Self::V10 => "v10".to_string(),
|
||||
Self::V11 => "v11".to_string(),
|
||||
Self::RTDETR => "rtdetr".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Model scale selector; each variant maps to its lowercase single-letter tag
/// via `name()`. Presumably the standard YOLO size suffixes (n=nano, s=small,
/// m=medium, l=large, x=extra-large, etc.) — TODO confirm against model naming.
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
pub enum YOLOScale {
    N,
    T,
    B,
    S,
    M,
    L,
    C,
    E,
    X,
}
|
||||
|
||||
impl YOLOScale {
|
||||
pub fn name(&self) -> String {
|
||||
match self {
|
||||
Self::N => "n".to_string(),
|
||||
Self::T => "t".to_string(),
|
||||
Self::S => "s".to_string(),
|
||||
Self::B => "b".to_string(),
|
||||
Self::M => "m".to_string(),
|
||||
Self::L => "l".to_string(),
|
||||
Self::C => "c".to_string(),
|
||||
Self::E => "e".to_string(),
|
||||
Self::X => "x".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum BoxType {
|
||||
/// 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue