Add YOLOv11

* Add YOLOv11
This commit is contained in:
Jamjamjon 2024-09-30 22:43:34 +08:00 committed by GitHub
parent 2cb9e57fc4
commit 0609dd1f1d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 300 additions and 200 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "usls" name = "usls"
version = "0.0.16" version = "0.0.17"
edition = "2021" edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls" repository = "https://github.com/jamjamjon/usls"

View File

@ -24,45 +24,41 @@
## Quick Start ## Quick Start
```Shell ```Shell
# customized
cargo run -r --example yolo -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8
# Classify # Classify
cargo run -r --example yolo -- --task classify --ver v5 # YOLOv5 cargo run -r --example yolo -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5
cargo run -r --example yolo -- --task classify --ver v8 # YOLOv8 cargo run -r --example yolo -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8
cargo run -r --example yolo -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLOv11
# Detect # Detect
cargo run -r --example yolo -- --task detect --ver v5 # YOLOv5 cargo run -r --example yolo -- --task detect --ver v5 --scale n # YOLOv5
cargo run -r --example yolo -- --task detect --ver v6 # YOLOv6 cargo run -r --example yolo -- --task detect --ver v6 --scale n # YOLOv6
cargo run -r --example yolo -- --task detect --ver v7 # YOLOv7 cargo run -r --example yolo -- --task detect --ver v7 --scale t # YOLOv7
cargo run -r --example yolo -- --task detect --ver v8 # YOLOv8 cargo run -r --example yolo -- --task detect --ver v8 --scale n # YOLOv8
cargo run -r --example yolo -- --task detect --ver v9 # YOLOv9 cargo run -r --example yolo -- --task detect --ver v9 --scale t # YOLOv9
cargo run -r --example yolo -- --task detect --ver v10 # YOLOv10 cargo run -r --example yolo -- --task detect --ver v10 --scale n # YOLOv10
cargo run -r --example yolo -- --task detect --ver rtdetr # YOLOv8-RTDETR cargo run -r --example yolo -- --task detect --ver v11 --scale n # YOLOv11
cargo run -r --example yolo -- --task detect --ver v8 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world cargo run -r --example yolo -- --task detect --ver rtdetr --scale l # RTDETR
cargo run -r --example yolo -- --task detect --ver v8 --nc 1 --model yolov8s-world-v2-shoes.onnx # YOLOv8-world <local file>
# Pose # Pose
cargo run -r --example yolo -- --task pose --ver v8 # YOLOv8-Pose cargo run -r --example yolo -- --task pose --ver v8 --scale n # YOLOv8-Pose
cargo run -r --example yolo -- --task pose --ver v11 --scale n # YOLOv11-Pose
# Segment # Segment
cargo run -r --example yolo -- --task segment --ver v5 # YOLOv5-Segment cargo run -r --example yolo -- --task segment --ver v5 --scale n # YOLOv5-Segment
cargo run -r --example yolo -- --task segment --ver v8 # YOLOv8-Segment cargo run -r --example yolo -- --task segment --ver v8 --scale n # YOLOv8-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM cargo run -r --example yolo -- --task segment --ver v11 --scale n # YOLOv11-Segment
cargo run -r --example yolo -- --task segment --ver v8 --model FastSAM-s-dyn-f16.onnx # FastSAM <local file>
# Obb # Obb
cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb cargo run -r --example yolo -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb
cargo run -r --example yolo -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv11-Obb
``` ```
<details close> **`cargo run -r --example yolo -- --help` for more options**
<summary>other options</summary>
`--source` to specify the input images
`--model` to specify the ONNX model
`--width --height` to specify the input resolution
`--nc` to specify the number of model's classes
`--plot` to annotate with inference results
`--profile` to profile
`--cuda --trt --coreml --device_id` to select device
`--half` to use float16 when using TensorRT EP
</details>
## YOLOs configs with `Options` ## YOLOs configs with `Options`
@ -74,6 +70,8 @@ cargo run -r --example yolo -- --task obb --ver v8 # YOLOv8-Obb
let options = Options::default() let options = Options::default()
.with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR .with_yolo_version(YOLOVersion::V5) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
.with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb .with_yolo_task(YOLOTask::Classify) // YOLOTask: Classify, Detect, Pose, Segment, Obb
// .with_nc(80)
// .with_names(&COCO_CLASS_NAMES_80)
.with_model("xxxx.onnx")?; .with_model("xxxx.onnx")?;
``` ```
@ -140,7 +138,7 @@ let options = Options::default()
</details> </details>
<details close> <details close>
<summary>YOLOv8</summary> <summary>YOLOv8, YOLOv11</summary>
```Shell ```Shell
pip install -U ultralytics pip install -U ultralytics

View File

@ -2,188 +2,160 @@ use anyhow::Result;
use clap::Parser; use clap::Parser;
use usls::{ use usls::{
models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17, models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask,
COCO_SKELETONS_16, YOLOVersion, COCO_SKELETONS_16,
}; };
#[derive(Parser, Clone)] #[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
pub struct Args { pub struct Args {
/// Path to the model
#[arg(long)] #[arg(long)]
pub model: Option<String>, pub model: Option<String>,
/// Input source path
#[arg(long, default_value_t = String::from("./assets/bus.jpg"))] #[arg(long, default_value_t = String::from("./assets/bus.jpg"))]
pub source: String, pub source: String,
/// YOLO Task
#[arg(long, value_enum, default_value_t = YOLOTask::Detect)] #[arg(long, value_enum, default_value_t = YOLOTask::Detect)]
pub task: YOLOTask, pub task: YOLOTask,
/// YOLO Version
#[arg(long, value_enum, default_value_t = YOLOVersion::V8)] #[arg(long, value_enum, default_value_t = YOLOVersion::V8)]
pub ver: YOLOVersion, pub ver: YOLOVersion,
/// YOLO Scale
#[arg(long, value_enum, default_value_t = YOLOScale::N)]
pub scale: YOLOScale,
/// Batch size
#[arg(long, default_value_t = 1)] #[arg(long, default_value_t = 1)]
pub batch_size: usize, pub batch_size: usize,
/// Minimum input width
#[arg(long, default_value_t = 224)] #[arg(long, default_value_t = 224)]
pub width_min: isize, pub width_min: isize,
/// Input width
#[arg(long, default_value_t = 640)] #[arg(long, default_value_t = 640)]
pub width: isize, pub width: isize,
#[arg(long, default_value_t = 800)] /// Maximum input width
#[arg(long, default_value_t = 1024)]
pub width_max: isize, pub width_max: isize,
/// Minimum input height
#[arg(long, default_value_t = 224)] #[arg(long, default_value_t = 224)]
pub height_min: isize, pub height_min: isize,
/// Input height
#[arg(long, default_value_t = 640)] #[arg(long, default_value_t = 640)]
pub height: isize, pub height: isize,
#[arg(long, default_value_t = 800)] /// Maximum input height
#[arg(long, default_value_t = 1024)]
pub height_max: isize, pub height_max: isize,
/// Number of classes
#[arg(long, default_value_t = 80)] #[arg(long, default_value_t = 80)]
pub nc: usize, pub nc: usize,
/// Class confidence
#[arg(long)]
pub confs: Vec<f32>,
/// Enable TensorRT support
#[arg(long)] #[arg(long)]
pub trt: bool, pub trt: bool,
/// Enable CUDA support
#[arg(long)] #[arg(long)]
pub cuda: bool, pub cuda: bool,
#[arg(long)] /// Enable CoreML support
pub half: bool,
#[arg(long)] #[arg(long)]
pub coreml: bool, pub coreml: bool,
/// Use TensorRT half precision
#[arg(long)]
pub half: bool,
/// Device ID to use
#[arg(long, default_value_t = 0)] #[arg(long, default_value_t = 0)]
pub device_id: usize, pub device_id: usize,
/// Enable performance profiling
#[arg(long)] #[arg(long)]
pub profile: bool, pub profile: bool,
#[arg(long)] /// Disable contour drawing
pub no_plot: bool,
#[arg(long)] #[arg(long)]
pub no_contours: bool, pub no_contours: bool,
/// Show result
#[arg(long)]
pub view: bool,
/// Do not save output
#[arg(long)]
pub nosave: bool,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
let args = Args::parse(); let args = Args::parse();
// build options // path
let options = Options::default(); let path = args.model.unwrap_or({
format!(
"yolo/{}-{}-{}.onnx",
args.ver.name(),
args.scale.name(),
args.task.name()
)
});
// version & task // saveout
let (options, saveout) = match args.ver { let saveout = format!(
YOLOVersion::V5 => match args.task { "{}-{}-{}",
YOLOTask::Classify => ( args.ver.name(),
options.with_model(&args.model.unwrap_or("yolo/v5-n-cls-dyn.onnx".to_string()))?, args.scale.name(),
"YOLOv5-Classify", args.task.name()
), );
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v5-n-dyn.onnx".to_string()))?,
"YOLOv5-Detect",
),
YOLOTask::Segment => (
options.with_model(&args.model.unwrap_or("yolo/v5-n-seg-dyn.onnx".to_string()))?,
"YOLOv5-Segment",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V6 => match args.task {
YOLOTask::Detect => (
options
.with_model(&args.model.unwrap_or("yolo/v6-n-dyn.onnx".to_string()))?
.with_nc(args.nc),
"YOLOv6-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V7 => match args.task {
YOLOTask::Detect => (
options
.with_model(&args.model.unwrap_or("yolo/v7-tiny-dyn.onnx".to_string()))?
.with_nc(args.nc),
"YOLOv7-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V8 => match args.task {
YOLOTask::Classify => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-cls-dyn.onnx".to_string()))?,
"YOLOv8-Classify",
),
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-dyn.onnx".to_string()))?,
"YOLOv8-Detect",
),
YOLOTask::Segment => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-seg-dyn.onnx".to_string()))?,
"YOLOv8-Segment",
),
YOLOTask::Pose => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-pose-dyn.onnx".to_string()))?,
"YOLOv8-Pose",
),
YOLOTask::Obb => (
options.with_model(&args.model.unwrap_or("yolo/v8-m-obb-dyn.onnx".to_string()))?,
"YOLOv8-Obb",
),
},
YOLOVersion::V9 => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v9-c-dyn-f16.onnx".to_string()))?,
"YOLOv9-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::V10 => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/v10-n.onnx".to_string()))?,
"YOLOv10-Detect",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
YOLOVersion::RTDETR => match args.task {
YOLOTask::Detect => (
options.with_model(&args.model.unwrap_or("yolo/rtdetr-l-f16.onnx".to_string()))?,
"RTDETR",
),
t => anyhow::bail!("Task: {t:?} is unsupported for {:?}", args.ver),
},
};
let options = options
.with_yolo_version(args.ver)
.with_yolo_task(args.task);
// device // device
let options = if args.cuda { let device = if args.cuda {
options.with_cuda(args.device_id) Device::Cuda(args.device_id)
} else if args.trt { } else if args.trt {
let options = options.with_trt(args.device_id); Device::Trt(args.device_id)
if args.half {
options.with_trt_fp16(true)
} else {
options
}
} else if args.coreml { } else if args.coreml {
options.with_coreml(args.device_id) Device::CoreML(args.device_id)
} else { } else {
options.with_cpu() Device::Cpu(args.device_id)
}; };
let options = options
// build options
let options = Options::new()
.with_model(&path)?
.with_yolo_version(args.ver)
.with_yolo_task(args.task)
.with_device(device)
.with_trt_fp16(args.half)
.with_ixx(0, 0, (1, args.batch_size as _, 4).into()) .with_ixx(0, 0, (1, args.batch_size as _, 4).into())
.with_ixx(0, 2, (args.height_min, args.height, args.height_max).into()) .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into())
.with_ixx(0, 3, (args.width_min, args.width, args.width_max).into()) .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15 .with_confs(if args.confs.is_empty() {
&[0.2, 0.15]
} else {
&args.confs
})
.with_nc(args.nc)
// .with_names(&COCO_CLASS_NAMES_80) // .with_names(&COCO_CLASS_NAMES_80)
.with_names2(&COCO_KEYPOINTS_17) // .with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not .with_find_contours(!args.no_contours) // find contours or not
.with_profile(args.profile); .with_profile(args.profile);
// build model
let mut model = YOLO::new(options)?; let mut model = YOLO::new(options)?;
// build dataloader // build dataloader
@ -194,16 +166,54 @@ fn main() -> Result<()> {
// build annotator // build annotator
let annotator = Annotator::default() let annotator = Annotator::default()
.with_skeletons(&COCO_SKELETONS_16) .with_skeletons(&COCO_SKELETONS_16)
.with_bboxes_thickness(4)
.without_masks(true) // No masks plotting when doing segment task. .without_masks(true) // No masks plotting when doing segment task.
.with_saveout(saveout); .with_bboxes_thickness(3)
.with_keypoints_name(false) // Enable keypoints names
.with_saveout_subs(&["YOLO"])
.with_saveout(&saveout);
// build viewer
let mut viewer = if args.view {
Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true))
} else {
None
};
// run & annotate // run & annotate
for (xs, _paths) in dl { for (xs, _paths) in dl {
// let ys = model.run(&xs)?; // way one // let ys = model.run(&xs)?; // way one
let ys = model.forward(&xs, args.profile)?; // way two let ys = model.forward(&xs, args.profile)?; // way two
if !args.no_plot { let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?;
annotator.annotate(&xs, &ys);
// show image
match &mut viewer {
Some(viewer) => viewer.imshow(&images_plotted)?,
None => continue,
}
// check out window and key event
match &mut viewer {
Some(viewer) => {
if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) {
break;
}
}
None => continue,
}
// write video
if !args.nosave {
match &mut viewer {
Some(viewer) => viewer.write_batch(&images_plotted)?,
None => continue,
}
}
}
// finish video write
if !args.nosave {
if let Some(viewer) = &mut viewer {
viewer.finish_write()?;
} }
} }

View File

@ -20,8 +20,8 @@ pub struct YOLO {
confs: DynConf, confs: DynConf,
kconfs: DynConf, kconfs: DynConf,
iou: f32, iou: f32,
names: Option<Vec<String>>, names: Vec<String>,
names_kpt: Option<Vec<String>>, names_kpt: Vec<String>,
task: YOLOTask, task: YOLOTask,
layout: YOLOPreds, layout: YOLOPreds,
find_contours: bool, find_contours: bool,
@ -64,27 +64,26 @@ impl Vision for YOLO {
Some(task) => match task { Some(task) => match task {
YOLOTask::Classify => match ver { YOLOTask::Classify => match ver {
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)), YOLOVersion::V5 => (Some(ver), YOLOPreds::n_clss().apply_softmax(true)),
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_clss()), YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_clss()),
x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") x => anyhow::bail!("YOLOTask::Classify is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
} }
YOLOTask::Detect => match ver { YOLOTask::Detect => match ver {
YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver),YOLOPreds::n_a_cxcywh_confclss()), YOLOVersion::V5 | YOLOVersion::V6 | YOLOVersion::V7 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss()),
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()), YOLOVersion::V8 | YOLOVersion::V9 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_a()),
YOLOVersion::V9 => (Some(ver),YOLOPreds::n_cxcywh_clss_a()), YOLOVersion::V10 => (Some(ver), YOLOPreds::n_a_xyxy_confcls().apply_nms(false)),
YOLOVersion::V10 => (Some(ver),YOLOPreds::n_a_xyxy_confcls().apply_nms(false)), YOLOVersion::RTDETR => (Some(ver), YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
YOLOVersion::RTDETR => (Some(ver),YOLOPreds::n_a_cxcywh_clss_n().apply_nms(false)),
} }
YOLOTask::Pose => match ver { YOLOTask::Pose => match ver {
YOLOVersion::V8 => (Some(ver),YOLOPreds::n_cxcywh_clss_xycs_a()), YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_xycs_a()),
x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") x => anyhow::bail!("YOLOTask::Pose is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
} }
YOLOTask::Segment => match ver { YOLOTask::Segment => match ver {
YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()), YOLOVersion::V5 => (Some(ver), YOLOPreds::n_a_cxcywh_confclss_coefs()),
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()), YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_coefs_a()),
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
} }
YOLOTask::Obb => match ver { YOLOTask::Obb => match ver {
YOLOVersion::V8 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()), YOLOVersion::V8 | YOLOVersion::V11 => (Some(ver), YOLOPreds::n_cxcywh_clss_r_a()),
x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.") x => anyhow::bail!("YOLOTask::Segment is unsupported for {x:?}. Try using `.with_yolo_preds()` for customization.")
} }
} }
@ -97,42 +96,63 @@ impl Vision for YOLO {
let task = task.unwrap_or(layout.task()); let task = task.unwrap_or(layout.task());
// The number of classes & Class names // Class names: user-defined.or(parsed)
let mut names = options.names.or(Self::fetch_names(&engine)); let names_parsed = Self::fetch_names(&engine);
let nc = match options.nc { let names = match names_parsed {
Some(nc) => { Some(names_parsed) => match options.names {
match &names { Some(names) => {
None => names = Some((0..nc).map(|x| x.to_string()).collect::<Vec<String>>()), if names.len() == names_parsed.len() {
Some(names) => { Some(names)
assert_eq!( } else {
nc, anyhow::bail!(
"The lengths of parsed class names: {} and user-defined class names: {} do not match.",
names_parsed.len(),
names.len(), names.len(),
"The length of `nc` and `class names` is not equal."
); );
} }
} }
nc None => Some(names_parsed),
} },
None => match &names { None => options.names,
Some(names) => names.len(), };
None => panic!(
"Can not parse model without `nc` and `class names`. Try to make it explicit with `options.with_nc(80)`" // nc: names.len().or(options.nc)
let nc = match &names {
Some(names) => names.len(),
None => match options.nc {
Some(nc) => nc,
None => anyhow::bail!(
"Unable to obtain the number of classes. Please specify them explicitly using `options.with_nc(usize)` or `options.with_names(&[&str])`."
), ),
}
};
// Class names
let names = match names {
None => Self::n2s(nc),
Some(names) => names,
};
// Keypoint names & nk
let (nk, names_kpt) = match Self::fetch_kpts(&engine) {
None => (0, vec![]),
Some(nk) => match options.names2 {
Some(names) => {
if names.len() == nk {
(nk, names)
} else {
anyhow::bail!(
"The lengths of user-defined keypoint names: {} and nk: {} do not match.",
names.len(),
nk,
);
}
}
None => (nk, Self::n2s(nk)),
}, },
}; };
// Keypoints names // Confs & Iou
let names_kpt = options.names2;
// The number of keypoints
let nk = engine
.try_fetch("kpt_shape")
.map(|kpt_string| {
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
let caps = re.captures(&kpt_string).unwrap();
caps.get(1).unwrap().as_str().parse::<usize>().unwrap()
})
.unwrap_or(0_usize);
let confs = DynConf::new(&options.confs, nc); let confs = DynConf::new(&options.confs, nc);
let kconfs = DynConf::new(&options.kconfs, nk); let kconfs = DynConf::new(&options.kconfs, nk);
let iou = options.iou.unwrap_or(0.45); let iou = options.iou.unwrap_or(0.45);
@ -140,6 +160,7 @@ impl Vision for YOLO {
// Summary // Summary
tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version); tracing::info!("YOLO Task: {:?}, Version: {:?}", task, version);
// dry run
engine.dry_run()?; engine.dry_run()?;
Ok(Self { Ok(Self {
@ -219,10 +240,8 @@ impl Vision for YOLO {
slice_clss.into_owned() slice_clss.into_owned()
}; };
let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0); let mut probs = Prob::default().with_probs(&x.into_raw_vec_and_offset().0);
if let Some(names) = &self.names { probs = probs
probs = .with_names(&self.names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
probs.with_names(&names.iter().map(|x| x.as_str()).collect::<Vec<_>>());
}
return Some(y.with_probs(&probs)); return Some(y.with_probs(&probs));
} }
@ -325,9 +344,7 @@ impl Vision for YOLO {
) )
.with_confidence(confidence) .with_confidence(confidence)
.with_id(class_id as isize); .with_id(class_id as isize);
if let Some(names) = &self.names { mbr = mbr.with_name(&self.names[class_id]);
mbr = mbr.with_name(&names[class_id]);
}
(None, Some(mbr)) (None, Some(mbr))
} }
@ -337,9 +354,7 @@ impl Vision for YOLO {
.with_confidence(confidence) .with_confidence(confidence)
.with_id(class_id as isize) .with_id(class_id as isize)
.with_id_born(i as isize); .with_id_born(i as isize);
if let Some(names) = &self.names { bbox = bbox.with_name(&self.names[class_id]);
bbox = bbox.with_name(&names[class_id]);
}
(Some(bbox), None) (Some(bbox), None)
} }
@ -394,9 +409,7 @@ impl Vision for YOLO {
ky.max(0.0f32).min(image_height), ky.max(0.0f32).min(image_height),
); );
if let Some(names) = &self.names_kpt { kpt = kpt.with_name(&self.names_kpt[i]);
kpt = kpt.with_name(&names[i]);
}
kpt kpt
} }
}) })
@ -505,16 +518,16 @@ impl Vision for YOLO {
} }
impl YOLO { impl YOLO {
pub fn batch(&self) -> isize { pub fn batch(&self) -> usize {
self.batch.opt() as _ self.batch.opt()
} }
pub fn width(&self) -> isize { pub fn width(&self) -> usize {
self.width.opt() as _ self.width.opt()
} }
pub fn height(&self) -> isize { pub fn height(&self) -> usize {
self.height.opt() as _ self.height.opt()
} }
pub fn version(&self) -> Option<&YOLOVersion> { pub fn version(&self) -> Option<&YOLOVersion> {
@ -541,4 +554,16 @@ impl YOLO {
names_ names_
}) })
} }
/// Looks up the `kpt_shape` entry on the engine and extracts the number of keypoints.
///
/// The entry is expected to contain two integers separated by `", "`; the first
/// one is returned. Returns `None` when the engine has no `kpt_shape` entry.
///
/// # Panics
///
/// Panics if the entry exists but does not contain the `<int>, <int>` pattern,
/// or if the first integer does not fit in a `usize`.
fn fetch_kpts(engine: &OrtEngine) -> Option<usize> {
    // No entry at all simply means "no keypoints information".
    let kpt_shape = engine.try_fetch("kpt_shape")?;

    let pattern = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
    let captures = pattern.captures(&kpt_shape).unwrap();
    let first = captures.get(1).unwrap().as_str();

    Some(first.parse::<usize>().unwrap())
}
/// Builds `n` placeholder names of the form `"# 0"`, `"# 1"`, ..., `"# {n-1}"`.
///
/// Returns an empty vector when `n` is zero.
fn n2s(n: usize) -> Vec<String> {
    let mut labels = Vec::with_capacity(n);
    for idx in 0..n {
        labels.push(format!("# {idx}"));
    }
    labels
}
} }

View File

@ -9,6 +9,28 @@ pub enum YOLOTask {
Obb, Obb,
} }
impl YOLOTask {
    /// Returns the short lowercase tag for this task (`"cls"`, `"det"`, `"pose"`,
    /// `"seg"` or `"obb"`).
    pub fn name(&self) -> String {
        let tag = match self {
            Self::Classify => "cls",
            Self::Detect => "det",
            Self::Pose => "pose",
            Self::Segment => "seg",
            Self::Obb => "obb",
        };
        String::from(tag)
    }

    /// Returns the long, hyphenated description of this task
    /// (for example `"object-detection"`).
    pub fn name_detailed(&self) -> String {
        let label = match self {
            Self::Classify => "image-classification",
            Self::Detect => "object-detection",
            Self::Pose => "pose-estimation",
            Self::Segment => "instance-segment",
            Self::Obb => "oriented-object-detection",
        };
        String::from(label)
    }
}
#[derive(Debug, Copy, Clone, clap::ValueEnum)] #[derive(Debug, Copy, Clone, clap::ValueEnum)]
pub enum YOLOVersion { pub enum YOLOVersion {
V5, V5,
@ -17,9 +39,54 @@ pub enum YOLOVersion {
V8, V8,
V9, V9,
V10, V10,
V11,
RTDETR, RTDETR,
} }
impl YOLOVersion {
    /// Returns the lowercase tag for this version, e.g. `"v8"`, `"v11"` or `"rtdetr"`.
    pub fn name(&self) -> String {
        let tag = match self {
            Self::V5 => "v5",
            Self::V6 => "v6",
            Self::V7 => "v7",
            Self::V8 => "v8",
            Self::V9 => "v9",
            Self::V10 => "v10",
            Self::V11 => "v11",
            Self::RTDETR => "rtdetr",
        };
        String::from(tag)
    }
}
/// Scale selector for YOLO models.
///
/// Derives `clap::ValueEnum`, so it can be used directly as a command-line value.
/// NOTE(review): the single-letter variants presumably mirror the upstream model
/// size suffixes (nano, tiny, small, ...) — confirm before documenting each one.
#[derive(Debug, Copy, Clone, clap::ValueEnum)]
pub enum YOLOScale {
N,
T,
B,
S,
M,
L,
C,
E,
X,
}
impl YOLOScale {
    /// Returns the lowercase single-letter tag for this scale, e.g. `"n"` for `N`.
    pub fn name(&self) -> String {
        let tag = match self {
            Self::N => "n",
            Self::T => "t",
            Self::S => "s",
            Self::B => "b",
            Self::M => "m",
            Self::L => "l",
            Self::C => "c",
            Self::E => "e",
            Self::X => "x",
        };
        String::from(tag)
    }
}
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub enum BoxType { pub enum BoxType {
/// 1 /// 1