Compare commits


10 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| jamjamjon | 4e932c4910 | Bump the version to 0.0.20 | 2024-12-03 19:37:34 +08:00 |
| Collide | 2785b090c6 | upgrade ort to v2.0.0-rc.9 (#52) | 2024-12-03 19:16:23 +08:00 |
| Jamjamjon | 57db14ce5d | Update README.md | 2024-10-10 00:30:52 +08:00 |
| Jamjamjon | 447889028e | Add Apple ml-depth-pro model | 2024-10-10 00:26:26 +08:00 |
| Jamjamjon | 1d596383de | Add support for restricting detection classes (#45) | 2024-10-05 17:49:08 +08:00 |
| Jamjamjon | 0102c15687 | Minor fixes | 2024-10-01 09:37:46 +08:00 |
| Jamjamjon | 64dc804a13 | Update README.md | 2024-09-30 22:48:07 +08:00 |
| Jamjamjon | 0609dd1f1d | Add YOLOv11 | 2024-09-30 22:43:34 +08:00 |
| Jamjamjon | 2cb9e57fc4 | Update README.md | 2024-09-28 10:49:06 +08:00 |
| Jamjamjon | f2c4593672 | Update README.md | 2024-09-28 10:10:05 +08:00 |
11 changed files with 612 additions and 322 deletions

Cargo.toml

@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.16"
version = "0.0.20"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
@ -12,7 +12,7 @@ exclude = ["assets/*", "examples/*", "scripts/*", "runs/*"]
[dependencies]
clap = { version = "4.2.4", features = ["derive"] }
ndarray = { version = "0.16.1", features = ["rayon"] }
ort = { version = "2.0.0-rc.5", default-features = false}
ort = { version = "2.0.0-rc.9", default-features = false }
anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
@ -30,7 +30,7 @@ imageproc = { version = "0.24" }
ab_glyph = "0.2.23"
geo = "0.28.0"
prost = "0.12.4"
fast_image_resize = { version = "4.2.1", features = ["image"]}
fast_image_resize = { version = "4.2.1", features = ["image"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tempfile = "3.12.0"
@ -50,7 +50,6 @@ default = [
"ort/cuda",
"ort/tensorrt",
"ort/coreml",
"ort/operator-libraries"
]
auto = ["ort/download-binaries"]

README.md

@ -3,7 +3,7 @@
</p>
<p align="center">
| <a href="https://docs.rs/usls"><strong>Documentation</strong></a> |
<a href="https://docs.rs/usls"><strong>Documentation</strong></a>
<br>
<br>
<a href='https://github.com/microsoft/onnxruntime/releases'>
@ -34,9 +34,9 @@
**`usls`** is a Rust library integrated with **ONNXRuntime** that provides a collection of state-of-the-art models for **Computer Vision** and **Vision-Language** tasks, including:
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10), [YOLOv11](https://github.com/ultralytics/ultralytics)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569), [DepthPro](https://github.com/apple/ml-depth-pro)
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
<details>
@ -51,7 +51,8 @@
| [YOLOv7](https://github.com/WongKinYiu/yolov7) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv8](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv9](https://github.com/WongKinYiu/yolov9) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv10](https://github.com/THU-MIG/yolov10) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [YOLOv11](https://github.com/ultralytics/ultralytics) | Object Detection<br>Instance Segmentation<br>Classification<br>Oriented Object Detection<br>Keypoint Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [RTDETR](https://arxiv.org/abs/2304.08069) | Object Detection | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) | Instance Segmentation | [demo](examples/yolo) | ✅ | ✅ | ✅ | ✅ |
| [SAM](https://github.com/facebookresearch/segment-anything) | Segment Anything | [demo](examples/sam) | ✅ | ✅ | | |
@ -67,11 +68,12 @@
| [SVTR](https://arxiv.org/abs/2205.00159) | Text Recognition | [demo](examples/svtr) | ✅ | ✅ | ✅ | ✅ |
| [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo) | Keypoint Detection | [demo](examples/rtmo) | ✅ | ✅ | ❌ | ❌ |
| [YOLOPv2](https://arxiv.org/abs/2208.11434) | Panoptic Driving Perception | [demo](examples/yolop) | ✅ | ✅ | ✅ | ✅ |
| [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
| [Depth-Anything v1 & v2](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
| [Florence2](https://arxiv.org/abs/2311.06242) | a Variety of Vision Tasks | [demo](examples/florence2) | ✅ | ✅ | | |
| [DepthPro](https://github.com/apple/ml-depth-pro) | Monocular Depth Estimation | [demo](examples/depth-pro) | ✅ | ✅ | | |
@ -80,7 +82,8 @@
## ⛳️ ONNXRuntime Linking
You have two options to link the ONNXRuntime library
<details>
<summary>You have two options to link the ONNXRuntime library</summary>
- ### Option 1: Manual Linking
@ -99,6 +102,7 @@ You have two options to link the ONNXRuntime library
cargo run -r --example yolo --features auto
```
</details>
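
Option 1 (manual linking) is only partially shown in this hunk; presumably it amounts to downloading an ONNXRuntime release and pointing ort at it. A minimal sketch, assuming the standard ort dynamic-loading setup via the `ORT_DYLIB_PATH` environment variable; the library path below is a placeholder:

```Shell
# Download ONNXRuntime from https://github.com/microsoft/onnxruntime/releases, then:
export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so
cargo run -r --example yolo
```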
## 🎈 Demo
@@ -123,70 +127,95 @@ cargo run -r --example yolo # blip, clip, yolop, svtr, db, ...

- Build model with the provided `models` and `Options`
- Load images, video and stream with `DataLoader`
- Do inference
- Retrieve inference results from `Vec<Y>`
- Annotate inference results with `Annotator`
- Display images and write them to video with `Viewer`

The inline example changes accordingly: the old input setters `.with_i00((1, 2, 4).into())`, `.with_i02((0, 640, 640).into())` and `.with_i03((0, 640, 640).into())` are replaced by the indexed `.with_ixx(0, 0, ...)`, `.with_ixx(0, 2, ...)` and `.with_ixx(0, 3, ...)` calls, the plain `annotator.annotate(&xs, &ys)` call gives way to `annotator.plot()` plus a `Viewer` for display and video writing, and the snippet is now wrapped in a collapsible block:

<br/>
<details>
<summary>example code</summary>

```rust
use usls::{models::YOLO, Annotator, DataLoader, Nms, Options, Vision, YOLOTask, YOLOVersion};

fn main() -> anyhow::Result<()> {
    // Build model with Options
    let options = Options::new()
        .with_trt(0)
        .with_model("yolo/v8-m-dyn.onnx")?
        .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
        .with_yolo_task(YOLOTask::Detect) // YOLOTask: Classify, Detect, Pose, Segment, Obb
        .with_ixx(0, 0, (1, 2, 4).into())
        .with_ixx(0, 2, (0, 640, 640).into())
        .with_ixx(0, 3, (0, 640, 640).into())
        .with_confs(&[0.2]);
    let mut model = YOLO::new(options)?;

    // Build DataLoader to load image(s), video, stream
    let dl = DataLoader::new(
        // "./assets/bus.jpg", // local image
        // "images/bus.jpg", // remote image
        // "../images-folder", // local images (from folder)
        // "../demo.mp4", // local video
        // "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/BigBuckBunny.mp4", // online video
        "rtsp://admin:kkasd1234@192.168.2.217:554/h264/ch1/", // stream
    )?
    .with_batch(2) // iterate with batch_size = 2
    .build()?;

    // Build annotator
    let annotator = Annotator::new()
        .with_bboxes_thickness(4)
        .with_saveout("YOLO-DataLoader");

    // Build viewer
    let mut viewer = Viewer::new().with_delay(10).with_scale(1.).resizable(true);

    // Run and annotate results
    for (xs, _) in dl {
        let ys = model.forward(&xs, false)?;
        // annotator.annotate(&xs, &ys);
        let images_plotted = annotator.plot(&xs, &ys, false)?;

        // show image
        viewer.imshow(&images_plotted)?;

        // check out window and key event
        if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
            break;
        }

        // write video
        viewer.write_batch(&images_plotted)?;

        // Retrieve inference results
        for y in ys {
            // bboxes
            if let Some(bboxes) = y.bboxes() {
                for bbox in bboxes {
                    println!(
                        "Bbox: {}, {}, {}, {}, {}, {}",
                        bbox.xmin(),
                        bbox.ymin(),
                        bbox.xmax(),
                        bbox.ymax(),
                        bbox.confidence(),
                        bbox.id(),
                    );
                }
            }
        }
    }

    // finish video write
    viewer.finish_write()?;

    Ok(())
}
```

</details>
</br>
## 📌 License
This project is licensed under [LICENSE](LICENSE).

examples/depth-pro/main.rs

@ -0,0 +1,26 @@
use usls::{models::DepthPro, Annotator, DataLoader, Options};
fn main() -> Result<(), Box<dyn std::error::Error>> {
// options
let options = Options::default()
.with_model("depth-pro/q4f16.onnx")? // bnb4, f16
.with_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
.with_ixx(0, 1, 3.into()) // channel
.with_ixx(0, 2, 1536.into()) // height
.with_ixx(0, 3, 1536.into()); // width
let mut model = DepthPro::new(options)?;
// load
let x = [DataLoader::try_read("images/street.jpg")?];
// run
let y = model.run(&x)?;
// annotate
let annotator = Annotator::default()
.with_colormap("Turbo")
.with_saveout("Depth-Pro");
annotator.annotate(&x, &y);
Ok(())
}
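
Assuming the new example is wired up like the other demos (the README links it as `examples/depth-pro`), it would be run with:

```Shell
cargo run -r --example depth-pro
```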

examples/yolo/README.md

@ -24,45 +24,41 @@
## Quick Start
```Shell
# customized
cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8
# Classify
cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5
cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8
cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11
# Detect
cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5
cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6
cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7
cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8
cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9
cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10
cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR
cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world
cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
cargo run -r --example yolo -- --task detect --ver v8 --model yolo/v8-s-world-v2-shoes.onnx # YOLOv8-world
# Pose
cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose
# Segment
cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM
cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv11-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM
# Obb
cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
```
<details close>
<summary>other options</summary>
`--source` to specify the input images
`--model` to specify the ONNX model
`--width --height` to specify the input resolution
`--nc` to specify the number of model's classes
`--plot` to annotate with inference results
`--profile` to profile
`--cuda --trt --coreml --device_id` to select device
`--half` to use float16 when using TensorRT EP
</details>
**`cargo run -r --example yolo -- --help` for more options**
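
A combined invocation sketch that puts several of the options above together; the values and source path are placeholders:

```Shell
cargo run -r --example yolo -- \
    --task detect --ver v11 --scale n \
    --width 640 --height 640 --nc 80 \
    --source ./assets/bus.jpg --profile
```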
## YOLOs configs with `Options`
@ -96,6 +92,8 @@ let options = Options::default()
..Default::default()
}
)
// .with_nc(80)
// .with_names(&COCO_CLASS_NAMES_80)
.with_model("xxxx.onnx")?;
```
</details>
@ -140,7 +138,7 @@ let options = Options::default()
</details>
<details close>
<summary>YOLOv8</summary>
<summary>YOLOv8, YOLOv11</summary>
```Shell
pip install -U ultralytics

examples/yolo/main.rs

@ -2,188 +2,169 @@ use anyhow::Result;
use clap::Parser;
use usls::{
models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
COCO_SKELETONS_16,
models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
YOLOVersion, COCO_SKELETONS_16,
};
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// Path to the model
#[arg(long)]
pub model: Option<String>,
/// Input source path
#[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
pub source: String,
/// YOLO Task
#[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
pub task: YOLOTask,
/// YOLO Version
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
pub ver: YOLOVersion,
/// YOLO Scale
#[arg(long, value_enum, default_value_t = YOLOScale::N)]
pub scale: YOLOScale,
/// Batch size
#[arg(long, default_value_t = 1)]
pub batch_size: usize,
/// Minimum input width
#[arg(long, default_value_t = 224)]
pub width_min: isize,
/// Input width
#[arg(long, default_value_t = 640)]
pub width: isize,
#[arg(long, default_value_t = 800)]
/// Maximum input width
#[arg(long, default_value_t = 1024)]
pub width_max: isize,
/// Minimum input height
#[arg(long, default_value_t = 224)]
pub height_min: isize,
/// Input height
#[arg(long, default_value_t = 640)]
pub height: isize,
#[arg(long, default_value_t = 800)]
/// Maximum input height
#[arg(long, default_value_t = 1024)]
pub height_max: isize,
/// Number of classes
#[arg(long, default_value_t = 80)]
pub nc: usize,
/// Class confidence
#[arg(long)]
pub confs: Vec<f32>,
/// Enable TensorRT support
#[arg(long)]
pub trt: bool,
/// Enable CUDA support
#[arg(long)]
pub cuda: bool,
#[arg(long)]
pub half: bool,
/// Enable CoreML support
#[arg(long)]
pub coreml: bool,
/// Use TensorRT half precision
#[arg(long)]
pub half: bool,
/// Device ID to use
#[arg(long, default_value_t = 0)]
pub device_id: usize,
/// Enable performance profiling
#[arg(long)]
pub profile: bool,
#[arg(long)]
pub no_plot: bool,
/// Disable contour drawing
#[arg(long)]
pub no_contours: bool,
/// Show result
#[arg(long)]
pub view: bool,
/// Do not save output
#[arg(long)]
pub nosave: bool,
}
fn main() -> Result<()> {
let args = Args::parse();
// build options
let options = Options::default();
// version & task
let (options, saveout) = match args.ver {
YOLOVersion::V5 => match args.task {
YOLOTask::Classify => (
options.with_model(&args.model.unwrap_or("yolo/v5-n-cls-dyn.onnx".to_string()))?,
"YOLOv5-Classify",
),
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v5-n-dyn.onnx".to_string()))?,
"YOLOv5-Detect",
),
YOLOTask::Segment => (
options.with_model(&args.model.unwrap_or("yolo/v5-n-seg-dyn.onnx".to_string()))?,
"YOLOv5-Segment",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V6 => match args.task {
YOLOTask::Detect => (
options
.with_model(&args.model.unwrap_or("yolo/v6-n-dyn.onnx".to_string()))?
.with_nc(args.nc),
"YOLOv6-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V7 => match args.task {
YOLOTask::Detect => (
options
.with_model(&args.model.unwrap_or("yolo/v7-tiny-dyn.onnx".to_string()))?
.with_nc(args.nc),
"YOLOv7-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V8 => match args.task {
YOLOTask::Classify => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-cls-dyn.onnx".to_string()))?,
"YOLOv8-Classify",
),
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-dyn.onnx".to_string()))?,
"YOLOv8-Detect",
),
YOLOTask::Segment => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-seg-dyn.onnx".to_string()))?,
"YOLOv8-Segment",
),
YOLOTask::Pose => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-pose-dyn.onnx".to_string()))?,
"YOLOv8-Pose",
),
YOLOTask::Obb => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-obb-dyn.onnx".to_string()))?,
"YOLOv8-Obb",
),
},
YOLOVersion::V9 => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v9-c-dyn-f16.onnx".to_string()))?,
"YOLOv9-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V10 => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v10-n.onnx".to_string()))?,
"YOLOv10-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::RTDETR => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/rtdetr-l-f16.onnx".to_string()))?,
"RTDETR",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
// model path
let path = match &args.model {
None => format!(
"yolo/{}-{}-{}.onnx",
args.ver.name(),
args.scale.name(),
args.task.name()
),
Some(x) => x.to_string(),
};
let options = options
.with_yolo_version(args.ver)
.with_yolo_task(args.task);
// saveout
let saveout = match &args.model {
None => format!(
"{}-{}-{}",
args.ver.name(),
args.scale.name(),
args.task.name()
),
Some(x) => {
let p = std::path::PathBuf::from(&x);
p.file_stem().unwrap().to_str().unwrap().to_string()
}
};
// device
let options = if args.cuda {
options.with_cuda(args.device_id)
let device = if args.cuda {
Device::Cuda(args.device_id)
} else if args.trt {
let options = options.with_trt(args.device_id);
if args.half {
options.with_trt_fp16(true)
} else {
options
}
Device::Trt(args.device_id)
} else if args.coreml {
options.with_coreml(args.device_id)
Device::CoreML(args.device_id)
} else {
options.with_cpu()
Device::Cpu(args.device_id)
};
let options = options
// build options
let options = Options::new()
.with_model(&path)?
.with_yolo_version(args.ver)
.with_yolo_task(args.task)
.with_device(device)
.with_trt_fp16(args.half)
.with_ixx(0, 0, (1, args.batch_size as _, 4).into())
.with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
.with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
.with_confs(if args.confs.is_empty() {
&[0.2, 0.15]
} else {
&args.confs
})
.with_nc(args.nc)
// .with_names(&COCO_CLASS_NAMES_80)
.with_names2(&COCO_KEYPOINTS_17)
// .with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not
.exclude_classes(&[0])
// .retain_classes(&[0, 5])
.with_profile(args.profile);
// build model
let mut model = YOLO::new(options)?;
// build dataloader
@ -194,16 +175,54 @@ fn main() -> Result<()> {
// build annotator
let annotator = Annotator::default()
.with_skeletons(&COCO_SKELETONS_16)
.with_bboxes_thickness(4)
.without_masks(true) // No masks plotting when doing segment task.
.with_saveout(saveout);
.with_bboxes_thickness(3)
.with_keypoints_name(false) // Enable keypoints names
.with_saveout_subs(&["YOLO"])
.with_saveout(&saveout);
// build viewer
let mut viewer = if args.view {
Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
} else {
None
};
// run & annotate
for (xs, _paths) in dl {
// let ys = model.run(&xs)?; // way one
let ys = model.forward(&xs, args.profile)?; // way two
if !args.no_plot {
annotator.annotate(&xs, &ys);
let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
// show image
match &mut viewer {
Some(viewer) => viewer.imshow(&images_plotted)?,
None => continue,
}
// check out window and key event
match &mut viewer {
Some(viewer) => {
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
break;
}
}
None => continue,
}
// write video
if !args.nosave {
match &mut viewer {
Some(viewer) => viewer.write_batch(&images_plotted)?,
None => continue,
}
}
}
// finish video write
if !args.nosave {
if let Some(viewer) = &mut viewer {
viewer.finish_write()?;
}
}
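
With the rewritten example, the long per-version `match` is gone and the default model path and save name are composed from `--ver`, `--scale` and `--task`. A hedged invocation sketch, using only flags visible in the new `Args` struct and the example README:

```Shell
# Resolves the default model path "yolo/v11-n-det.onnx" and opens a live viewer window
cargo run -r --example yolo -- --task detect --ver v11 --scale n --view
```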

src/core/options.rs

@ -48,6 +48,8 @@ pub struct Options {
pub sam_kind: Option<SamKind>,
pub use_low_res_mask: Option<bool>,
pub sapiens_task: Option<SapiensTask>,
pub classes_excluded: Vec<isize>,
pub classes_retained: Vec<isize>,
}
impl Default for Options {
@ -88,6 +90,8 @@ impl Default for Options {
use_low_res_mask: None,
sapiens_task: None,
task: Task::Untitled,
classes_excluded: vec![],
classes_retained: vec![],
}
}
}
@ -276,4 +280,16 @@ impl Options {
self.iiixs.push(Iiix::from((i, ii, x)));
self
}
pub fn exclude_classes(mut self, xs: &[isize]) -> Self {
self.classes_retained.clear();
self.classes_excluded.extend_from_slice(xs);
self
}
pub fn retain_classes(mut self, xs: &[isize]) -> Self {
self.classes_excluded.clear();
self.classes_retained.extend_from_slice(xs);
self
}
}
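
A brief usage sketch (not from the repo) of the new class filters, mirroring the calls made in the updated YOLO example; the model path and confidence value are placeholders:

```rust
use usls::{models::YOLO, Options, Vision};

fn main() -> anyhow::Result<()> {
    let options = Options::default()
        .with_model("yolo/v8-n-det.onnx")? // placeholder model
        .with_confs(&[0.2])
        // Keep only classes 0 and 5, drop everything else ...
        .retain_classes(&[0, 5]);
    // ... or, alternatively, drop class 0 and keep the rest:
    // let options = options.exclude_classes(&[0]);
    let _model = YOLO::new(options)?;
    Ok(())
}
```

Note the two calls are mutually exclusive by design: `retain_classes` clears any previous exclusions and `exclude_classes` clears any previous retentions.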

src/core/ort_engine.rs

@ -2,7 +2,9 @@ use anyhow::Result;
use half::f16;
use ndarray::{Array, IxDyn};
use ort::{
ExecutionProvider, Session, SessionBuilder, TensorElementType, TensorRTExecutionProvider,
execution_providers::{ExecutionProvider, TensorRTExecutionProvider},
session::{builder::SessionBuilder, Session},
tensor::TensorElementType,
};
use prost::Message;
use std::collections::HashSet;
@ -88,14 +90,14 @@ impl OrtEngine {
// build
ort::init().commit()?;
let builder = Session::builder()?;
let mut builder = Session::builder()?;
let mut device = config.device.to_owned();
match device {
Device::Trt(device_id) => {
Self::build_trt(
&inputs_attrs.names,
&inputs_minoptmax,
&builder,
&mut builder,
device_id,
config.trt_int8_enable,
config.trt_fp16_enable,
@ -103,23 +105,23 @@ impl OrtEngine {
)?;
}
Device::Cuda(device_id) => {
Self::build_cuda(&builder, device_id).unwrap_or_else(|err| {
Self::build_cuda(&mut builder, device_id).unwrap_or_else(|err| {
tracing::warn!("{err}, Using cpu");
device = Device::Cpu(0);
})
}
Device::CoreML(_) => Self::build_coreml(&builder).unwrap_or_else(|err| {
Device::CoreML(_) => Self::build_coreml(&mut builder).unwrap_or_else(|err| {
tracing::warn!("{err}, Using cpu");
device = Device::Cpu(0);
}),
Device::Cpu(_) => {
Self::build_cpu(&builder)?;
Self::build_cpu(&mut builder)?;
}
_ => todo!(),
}
let session = builder
.with_optimization_level(ort::GraphOptimizationLevel::Level3)?
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)?
.commit_from_file(&config.onnx_path)?;
// summary
@ -149,7 +151,7 @@ impl OrtEngine {
fn build_trt(
names: &[String],
inputs_minoptmax: &[Vec<MinOptMax>],
builder: &SessionBuilder,
builder: &mut SessionBuilder,
device_id: usize,
int8_enable: bool,
fp16_enable: bool,
@ -205,8 +207,9 @@ impl OrtEngine {
}
}
fn build_cuda(builder: &SessionBuilder, device_id: usize) -> Result<()> {
let ep = ort::CUDAExecutionProvider::default().with_device_id(device_id as i32);
fn build_cuda(builder: &mut SessionBuilder, device_id: usize) -> Result<()> {
let ep = ort::execution_providers::CUDAExecutionProvider::default()
.with_device_id(device_id as i32);
if ep.is_available()? && ep.register(builder).is_ok() {
Ok(())
} else {
@ -214,8 +217,8 @@ impl OrtEngine {
}
}
fn build_coreml(builder: &SessionBuilder) -> Result<()> {
let ep = ort::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
fn build_coreml(builder: &mut SessionBuilder) -> Result<()> {
let ep = ort::execution_providers::CoreMLExecutionProvider::default().with_subgraphs(); //.with_ane_only();
if ep.is_available()? && ep.register(builder).is_ok() {
Ok(())
} else {
@ -223,8 +226,8 @@ impl OrtEngine {
}
}
fn build_cpu(builder: &SessionBuilder) -> Result<()> {
let ep = ort::CPUExecutionProvider::default();
fn build_cpu(builder: &mut SessionBuilder) -> Result<()> {
let ep = ort::execution_providers::CPUExecutionProvider::default();
if ep.is_available()? && ep.register(builder).is_ok() {
Ok(())
} else {
@ -292,28 +295,28 @@ impl OrtEngine {
let t_pre = std::time::Instant::now();
for (idtype, x) in self.inputs_attrs.dtypes.iter().zip(xs.into_iter()) {
let x_ = match &idtype {
TensorElementType::Float32 => ort::Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float32 => ort::value::Value::from_array(x.view())?.into_dyn(),
TensorElementType::Float16 => {
ort::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(f16::from_f32).view())?.into_dyn()
}
TensorElementType::Int32 => {
ort::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as i32).view())?.into_dyn()
}
TensorElementType::Int64 => {
ort::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as i64).view())?.into_dyn()
}
TensorElementType::Uint8 => {
ort::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as u8).view())?.into_dyn()
}
TensorElementType::Int8 => {
ort::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ as i8).view())?.into_dyn()
}
TensorElementType::Bool => {
ort::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
ort::value::Value::from_array(x.mapv(|x_| x_ != 0.).view())?.into_dyn()
}
_ => todo!(),
};
xs_.push(Into::<ort::SessionInputValue<'_>>::into(x_));
xs_.push(Into::<ort::session::SessionInputValue<'_>>::into(x_));
}
let t_pre = t_pre.elapsed();
self.ts.add_or_push(0, t_pre);
@ -451,45 +454,45 @@ impl OrtEngine {
}
#[allow(dead_code)]
fn nbytes_from_onnx_dtype(x: &ort::TensorElementType) -> usize {
fn nbytes_from_onnx_dtype(x: &ort::tensor::TensorElementType) -> usize {
match x {
ort::TensorElementType::Float64
| ort::TensorElementType::Uint64
| ort::TensorElementType::Int64 => 8, // i64, f64, u64
ort::TensorElementType::Float32
| ort::TensorElementType::Uint32
| ort::TensorElementType::Int32
| ort::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
ort::TensorElementType::Float16
| ort::TensorElementType::Bfloat16
| ort::TensorElementType::Int16
| ort::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
ort::TensorElementType::Uint8
| ort::TensorElementType::Int8
| ort::TensorElementType::Bool => 1, // u8, i8, bool
ort::tensor::TensorElementType::Float64
| ort::tensor::TensorElementType::Uint64
| ort::tensor::TensorElementType::Int64 => 8, // i64, f64, u64
ort::tensor::TensorElementType::Float32
| ort::tensor::TensorElementType::Uint32
| ort::tensor::TensorElementType::Int32
| ort::tensor::TensorElementType::String => 4, // f32, i32, u32, string(1~4)
ort::tensor::TensorElementType::Float16
| ort::tensor::TensorElementType::Bfloat16
| ort::tensor::TensorElementType::Int16
| ort::tensor::TensorElementType::Uint16 => 2, // f16, bf16, i16, u16
ort::tensor::TensorElementType::Uint8
| ort::tensor::TensorElementType::Int8
| ort::tensor::TensorElementType::Bool => 1, // u8, i8, bool
}
}
#[allow(dead_code)]
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::TensorElementType> {
fn ort_dtype_from_onnx_dtype_id(value: i32) -> Option<ort::tensor::TensorElementType> {
match value {
0 => None,
1 => Some(ort::TensorElementType::Float32),
2 => Some(ort::TensorElementType::Uint8),
3 => Some(ort::TensorElementType::Int8),
4 => Some(ort::TensorElementType::Uint16),
5 => Some(ort::TensorElementType::Int16),
6 => Some(ort::TensorElementType::Int32),
7 => Some(ort::TensorElementType::Int64),
8 => Some(ort::TensorElementType::String),
9 => Some(ort::TensorElementType::Bool),
10 => Some(ort::TensorElementType::Float16),
11 => Some(ort::TensorElementType::Float64),
12 => Some(ort::TensorElementType::Uint32),
13 => Some(ort::TensorElementType::Uint64),
1 => Some(ort::tensor::TensorElementType::Float32),
2 => Some(ort::tensor::TensorElementType::Uint8),
3 => Some(ort::tensor::TensorElementType::Int8),
4 => Some(ort::tensor::TensorElementType::Uint16),
5 => Some(ort::tensor::TensorElementType::Int16),
6 => Some(ort::tensor::TensorElementType::Int32),
7 => Some(ort::tensor::TensorElementType::Int64),
8 => Some(ort::tensor::TensorElementType::String),
9 => Some(ort::tensor::TensorElementType::Bool),
10 => Some(ort::tensor::TensorElementType::Float16),
11 => Some(ort::tensor::TensorElementType::Float64),
12 => Some(ort::tensor::TensorElementType::Uint32),
13 => Some(ort::tensor::TensorElementType::Uint64),
14 => None, // COMPLEX64
15 => None, // COMPLEX128
16 => Some(ort::TensorElementType::Bfloat16),
16 => Some(ort::tensor::TensorElementType::Bfloat16),
_ => None,
}
}
@ -499,7 +502,7 @@ impl OrtEngine {
value_info: &[onnx::ValueInfoProto],
) -> Result<OrtTensorAttr> {
let mut dimss: Vec<Vec<usize>> = Vec::new();
let mut dtypes: Vec<ort::TensorElementType> = Vec::new();
let mut dtypes: Vec<ort::tensor::TensorElementType> = Vec::new();
let mut names: Vec<String> = Vec::new();
for v in value_info.iter() {
if initializer_names.contains(v.name.as_str()) {
@ -569,7 +572,7 @@ impl OrtEngine {
&self.outputs_attrs.names
}
pub fn odtypes(&self) -> &Vec<ort::TensorElementType> {
pub fn odtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
&self.outputs_attrs.dtypes
}
@ -585,7 +588,7 @@ impl OrtEngine {
&self.inputs_attrs.names
}
pub fn idtypes(&self) -> &Vec<ort::TensorElementType> {
pub fn idtypes(&self) -> &Vec<ort::tensor::TensorElementType> {
&self.inputs_attrs.dtypes
}

src/models/depth_pro.rs (new file)

@ -0,0 +1,86 @@
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
use anyhow::Result;
use image::DynamicImage;
use ndarray::Axis;
#[derive(Debug)]
pub struct DepthPro {
engine: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
}
impl DepthPro {
pub fn new(options: Options) -> Result<Self> {
let mut engine = OrtEngine::new(&options)?;
let (batch, height, width) = (
engine.batch().clone(),
engine.height().clone(),
engine.width().clone(),
);
engine.dry_run()?;
Ok(Self {
engine,
height,
width,
batch,
})
}
pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt() as u32,
self.width.opt() as u32,
"Bilinear",
),
Ops::Normalize(0., 255.),
Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(Xs::from(xs_))?;
self.postprocess(ys, xs)
}
pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
let predicted_depth = predicted_depth.mapv(|x| 1. / x);
let mut ys: Vec<Y> = Vec::new();
for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
let v = luma.into_owned().into_raw_vec_and_offset().0;
let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
let v = v
.iter()
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
.collect::<Vec<_>>();
let luma = Ops::resize_luma8_u8(
&v,
self.width.opt() as _,
self.height.opt() as _,
w1 as _,
h1 as _,
false,
"Bilinear",
)?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
}
Ok(ys)
}
pub fn batch(&self) -> isize {
self.batch.opt() as _
}
}

src/models/mod.rs

@ -4,6 +4,7 @@ mod blip;
mod clip;
mod db;
mod depth_anything;
mod depth_pro;
mod dinov2;
mod florence2;
mod grounding_dino;
@ -20,6 +21,7 @@ pub use blip::Blip;
pub use clip::Clip;
pub use db::DB;
pub use depth_anything::DepthAnything;
pub use depth_pro::DepthPro;
pub use dinov2::Dinov2;
pub use florence2::Florence2;
pub use grounding_dino::GroundingDINO;

src/models/yolo.rs

@ -20,12 +20,14 @@ pub struct YOLO {
confs: DynConf,
kconfs: DynConf,
iou: f32,
names: Option<Vec<String>>,
names_kpt: Option<Vec<String>>,
names: Vec<String>,
names_kpt: Vec<String>,
task: YOLOTask,
layout: YOLOPreds,
find_contours: bool,
version: Option<YOLOVersion>,
classes_excluded: Vec<isize>,
classes_retained: Vec<isize>,
}
impl Vision for YOLO {
@ -64,27 +66,26 @@ impl Vision for YOLO {
Some(task) => match task {
YOLOTask::Classify => match ver {
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()),
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()),
x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
}
YOLOTask::Detect => match ver {
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()),
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()),
YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
YOLOVersion::RTDETR => (Some(ver),YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()),
YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()),
YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
}
YOLOTask::Pose => match ver {
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()),
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()),
x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
}
YOLOTask::Segment => match ver {
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
}
YOLOTask::Obb => match ver {
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
}
}
@ -97,49 +98,75 @@ impl Vision for YOLO {
let task = task.unwrap_or(layout.task());
// The number of classes & Class names
let mut names = options.names.or(Self::fetch_names(&engine));
let nc = match options.nc {
Some(nc) => {
match &names {
None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()),
Some(names) => {
assert_eq!(
nc,
// Class names: user-defined.or(parsed)
let names_parsed = Self::fetch_names(&engine);
let names = match names_parsed {
Some(names_parsed) => match options.names {
Some(names) => {
if names.len() == names_parsed.len() {
Some(names)
} else {
anyhow::bail!(
"The lengths of parsed class names: {} and user-defined class names: {} do not match.",
names_parsed.len(),
names.len(),
"The length of `nc` and `class names` is not equal."
);
}
}
nc
}
None => match &names {
Some(names) => names.len(),
None => panic!(
"Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`"
None => Some(names_parsed),
},
None => options.names,
};
// nc: names.len().or(options.nc)
let nc = match &names {
Some(names) => names.len(),
None => match options.nc {
Some(nc) => nc,
None => anyhow::bail!(
"Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
),
}
};
// Class names
let names = match names {
None => Self::n2s(nc),
Some(names) => names,
};
// Keypoint names & nk
let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
None => (0, vec![]),
Some(nk) => match options.names2 {
Some(names) => {
if names.len() == nk {
(nk, names)
} else {
anyhow::bail!(
"The lengths of user-defined keypoint names: {} and nk: {} do not match.",
names.len(),
nk,
);
}
}
None => (nk, Self::n2s(nk)),
},
};
// Keypoints names
let names_kpt = options.names2;
// The number of keypoints
let nk = engine
.try_fetch("kpt_shape")
.map(|kpt_string| {
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
let caps = re.captures(&kpt_string).unwrap();
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
})
.unwrap_or(0_usize);
// Confs & Iou
let confs = DynConf::new(&options.confs, nc);
let kconfs = DynConf::new(&options.kconfs, nk);
let iou = options.iou.unwrap_or(0.45);
// Classes excluded and retained
let classes_excluded = options.classes_excluded;
let classes_retained = options.classes_retained;
// Summary
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
// dry run
engine.dry_run()?;
Ok(Self {
@ -158,6 +185,8 @@ impl Vision for YOLO {
layout,
version,
find_contours: options.find_contours,
classes_excluded,
classes_retained,
})
}
@ -219,10 +248,8 @@ impl Vision for YOLO {
slice_clss.into_owned()
};
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
if let Some(names) = &self.names {
probs =
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
}
probs = probs
.with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
return Some(y.with_probs(&probs));
}
@ -257,7 +284,19 @@ impl Vision for YOLO {
}
};
// filtering
// filtering by class id
if !self.classes_excluded.is_empty()
&& self.classes_excluded.contains(&(class_id as isize))
{
return None;
}
if !self.classes_retained.is_empty()
&& !self.classes_retained.contains(&(class_id as isize))
{
return None;
}
// filtering by conf
if confidence < self.confs[class_id] {
return None;
}
@ -325,9 +364,7 @@ impl Vision for YOLO {
)
.with_confidence(confidence)
.with_id(class_id as isize);
if let Some(names) = &self.names {
mbr = mbr.with_name(&names[class_id]);
}
mbr = mbr.with_name(&self.names[class_id]);
(None, Some(mbr))
}
@ -337,9 +374,7 @@ impl Vision for YOLO {
.with_confidence(confidence)
.with_id(class_id as isize)
.with_id_born(i as isize);
if let Some(names) = &self.names {
bbox = bbox.with_name(&names[class_id]);
}
bbox = bbox.with_name(&self.names[class_id]);
(Some(bbox), None)
}
@ -394,9 +429,7 @@ impl Vision for YOLO {
ky.max(0.0f32).min(image_height),
);
if let Some(names) = &self.names_kpt {
kpt = kpt.with_name(&names[i]);
}
kpt = kpt.with_name(&self.names_kpt[i]);
kpt
}
})
@ -505,16 +538,16 @@ impl Vision for YOLO {
}
impl YOLO {
pub fn batch(&self) -> isize {
self.batch.opt() as _
pub fn batch(&self) -> usize {
self.batch.opt()
}
pub fn width(&self) -> isize {
self.width.opt() as _
pub fn width(&self) -> usize {
self.width.opt()
}
pub fn height(&self) -> isize {
self.height.opt() as _
pub fn height(&self) -> usize {
self.height.opt()
}
pub fn version(&self) -> Option<&YOLOVersion> {
@ -541,4 +574,16 @@ impl YOLO {
names_
})
}
fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
engine.try_fetch("kpt_shape").map(|s| {
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
let caps = re.captures(&s).unwrap();
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
})
}
fn n2s(n: usize) -> Vec<String> {
(0..n).map(|x| format!("# {}", x)).collect::<Vec<String>>()
}
}

src/models/yolo_.rs

@ -9,6 +9,28 @@ pub enum YOLOTask {
Obb,
}
impl YOLOTask {
pub fn name(&self) -> String {
match self {
Self::Classify => "cls".to_string(),
Self::Detect => "det".to_string(),
Self::Pose => "pose".to_string(),
Self::Segment => "seg".to_string(),
Self::Obb => "obb".to_string(),
}
}
pub fn name_detailed(&self) -> String {
match self {
Self::Classify => "image-classification".to_string(),
Self::Detect => "object-detection".to_string(),
Self::Pose => "pose-estimation".to_string(),
Self::Segment => "instance-segment".to_string(),
Self::Obb => "oriented-object-detection".to_string(),
}
}
}
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
pub enum YOLOVersion {
V5,
@ -17,9 +39,54 @@ pub enum YOLOVersion {
V8,
V9,
V10,
V11,
RTDETR,
}
impl YOLOVersion {
pub fn name(&self) -> String {
match self {
Self::V5 => "v5".to_string(),
Self::V6 => "v6".to_string(),
Self::V7 => "v7".to_string(),
Self::V8 => "v8".to_string(),
Self::V9 => "v9".to_string(),
Self::V10 => "v10".to_string(),
Self::V11 => "v11".to_string(),
Self::RTDETR => "rtdetr".to_string(),
}
}
}
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
pub enum YOLOScale {
N,
T,
B,
S,
M,
L,
C,
E,
X,
}
impl YOLOScale {
pub fn name(&self) -> String {
match self {
Self::N => "n".to_string(),
Self::T => "t".to_string(),
Self::S => "s".to_string(),
Self::B => "b".to_string(),
Self::M => "m".to_string(),
Self::L => "l".to_string(),
Self::C => "c".to_string(),
Self::E => "e".to_string(),
Self::X => "x".to_string(),
}
}
}
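
These `name()` helpers are what let the reworked example CLI assemble a default model path and save name from its flags. A small sketch of that convention, mirroring the `format!` in the updated YOLO example above:

```rust
use usls::{YOLOScale, YOLOTask, YOLOVersion};

fn main() {
    let (ver, scale, task) = (YOLOVersion::V11, YOLOScale::N, YOLOTask::Detect);
    // "yolo/{ver}-{scale}-{task}.onnx", as composed in the YOLO example
    let path = format!("yolo/{}-{}-{}.onnx", ver.name(), scale.name(), task.name());
    assert_eq!(path, "yolo/v11-n-det.onnx");
}
```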
#[derive(Debug, Clone, PartialEq)]
pub enum BoxType {
/// 1