diff --git a/Cargo.toml b/Cargo.toml
index ca76358..3722b18 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "usls"
-version = "0.0.11"
+version = "0.0.12"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
diff --git a/README.md b/README.md
index b5dbca3..8a2c9aa 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
-- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet)
+- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World)
@@ -70,6 +70,7 @@
| [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ |
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ |
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | |
+| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
diff --git a/assets/paul-george.jpg b/assets/paul-george.jpg
new file mode 100644
index 0000000..9b92635
Binary files /dev/null and b/assets/paul-george.jpg differ
diff --git a/benches/yolo.rs b/benches/yolo.rs
index 4868ba6..b128338 100644
--- a/benches/yolo.rs
+++ b/benches/yolo.rs
@@ -1,7 +1,7 @@
use anyhow::Result;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
-use usls::{coco, models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion};
+use usls::{models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17};
enum Stage {
Pre,
@@ -60,7 +60,7 @@ pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> {
.with_i02((320, h, 1280).into())
.with_i03((320, w, 1280).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
- .with_names2(&coco::KEYPOINTS_NAMES_17);
+ .with_names2(&COCO_KEYPOINTS_17);
let mut model = YOLO::new(options)?;
let xs = vec![DataLoader::try_read("./assets/bus.jpg")?];
diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs
index aa84691..9a4cd38 100644
--- a/examples/depth-anything/main.rs
+++ b/examples/depth-anything/main.rs
@@ -11,7 +11,7 @@ fn main() -> Result<(), Box> {
let mut model = DepthAnything::new(options)?;
// load
- let x = vec![DataLoader::try_read("./assets/2.jpg")?];
+ let x = [DataLoader::try_read("./assets/2.jpg")?];
// run
let y = model.run(&x)?;
diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs
index bf6f696..b685c1f 100644
--- a/examples/rtmo/main.rs
+++ b/examples/rtmo/main.rs
@@ -1,4 +1,4 @@
-use usls::{coco, models::RTMO, Annotator, DataLoader, Options};
+use usls::{models::RTMO, Annotator, DataLoader, Options, COCO_SKELETONS_16};
fn main() -> Result<(), Box> {
// build model
@@ -19,7 +19,7 @@ fn main() -> Result<(), Box> {
// annotate
let annotator = Annotator::default()
.with_saveout("RTMO")
- .with_skeletons(&coco::SKELETONS_16);
+ .with_skeletons(&COCO_SKELETONS_16);
annotator.annotate(&x, &y);
Ok(())
diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs
new file mode 100644
index 0000000..63c3b12
--- /dev/null
+++ b/examples/sapiens/main.rs
@@ -0,0 +1,30 @@
+use usls::{
+ models::{Sapiens, SapiensTask},
+ Annotator, DataLoader, Options, BODY_PARTS_28,
+};
+
+/// Example: runs the Sapiens segmentation task on a sample image and hands the
+/// result to the annotator.
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // build: segmentation task, with `BODY_PARTS_28` as the class names
+    let options = Options::default()
+        .with_model("sapiens-seg-0.3b-dyn.onnx")?
+        .with_sapiens_task(SapiensTask::Seg)
+        .with_names(&BODY_PARTS_28)
+        .with_profile(false)
+        // NOTE(review): presumably the (min, opt, max) range of the batch dimension — confirm
+        .with_i00((1, 1, 8).into());
+    let mut model = Sapiens::new(options)?;
+
+    // load
+    let x = [DataLoader::try_read("./assets/paul-george.jpg")?];
+
+    // run
+    let y = model.run(&x)?;
+
+    // annotate: mask plotting is disabled; polygon names are presumably hidden as well
+    let annotator = Annotator::default()
+        .without_masks(true)
+        .with_polygons_name(false)
+        .with_saveout("Sapiens");
+    annotator.annotate(&x, &y);
+
+    Ok(())
+}
diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs
index 97af956..3587265 100644
--- a/examples/yolo/main.rs
+++ b/examples/yolo/main.rs
@@ -1,7 +1,10 @@
use anyhow::Result;
use clap::Parser;
-use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion};
+use usls::{
+ models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17,
+ COCO_SKELETONS_16,
+};
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
@@ -174,8 +177,8 @@ fn main() -> Result<()> {
.with_i02((args.height_min, args.height, args.height_max).into())
.with_i03((args.width_min, args.width, args.width_max).into())
.with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15
- // .with_names(&coco::NAMES_80)
- .with_names2(&coco::KEYPOINTS_NAMES_17)
+ // .with_names(&COCO_CLASS_NAMES_80)
+ .with_names2(&COCO_KEYPOINTS_17)
.with_find_contours(!args.no_contours) // find contours or not
.with_profile(args.profile);
let mut model = YOLO::new(options)?;
@@ -187,7 +190,7 @@ fn main() -> Result<()> {
// build annotator
let annotator = Annotator::default()
- .with_skeletons(&coco::SKELETONS_16)
+ .with_skeletons(&COCO_SKELETONS_16)
.with_bboxes_thickness(4)
.without_masks(true) // No masks plotting when doing segment task.
.with_saveout(saveout);
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 0577c6a..2586cf6 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -9,6 +9,7 @@ pub mod onnx;
pub mod ops;
mod options;
mod ort_engine;
+mod task;
mod tokenizer_stream;
mod ts;
mod vision;
@@ -25,6 +26,7 @@ pub use min_opt_max::MinOptMax;
pub use ops::Ops;
pub use options::Options;
pub use ort_engine::OrtEngine;
+pub use task::Task;
pub use tokenizer_stream::TokenizerStream;
pub use ts::Ts;
pub use vision::Vision;
diff --git a/src/core/ops.rs b/src/core/ops.rs
index 6b5b94f..10d3a74 100644
--- a/src/core/ops.rs
+++ b/src/core/ops.rs
@@ -7,7 +7,7 @@ use fast_image_resize::{
FilterType, ResizeAlg, ResizeOptions, Resizer,
};
use image::{DynamicImage, GenericImageView};
-use ndarray::{s, Array, Axis, IntoDimension, IxDyn};
+use ndarray::{s, Array, Array3, Axis, IntoDimension, IxDyn};
use rayon::prelude::*;
pub enum Ops<'a> {
@@ -159,7 +159,41 @@ impl Ops<'_> {
mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle)
}
- pub fn resize_lumaf32_vec(
+    /// Resizes each 2-D slice along axis 0 of a 3-D array to `th` x `tw`.
+    ///
+    /// `filter` is forwarded to [`Ops::resize_lumaf32_f32`]. The result has shape
+    /// `(n, th, tw)`, where `n` is the length of axis 0 of `xs`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `xs` is not 3-dimensional or if resizing a slice fails.
+    pub fn interpolate_3d(
+        xs: Array<f32, IxDyn>,
+        tw: f32,
+        th: f32,
+        filter: &str,
+    ) -> Result<Array<f32, IxDyn>> {
+        let d_max = xs.ndim();
+        if d_max != 3 {
+            anyhow::bail!("`interpolate_3d`: The input's ndim: {} is not 3.", d_max);
+        }
+        let (n, h, w) = (xs.shape()[0], xs.shape()[1], xs.shape()[2]);
+        let (th_, tw_) = (th as usize, tw as usize);
+        let mut ys = Array3::zeros((n, th_, tw_));
+        for (i, luma) in xs.axis_iter(Axis(0)).enumerate() {
+            // Collect in logical (row-major) order: unlike taking the raw backing
+            // vector, this stays correct when `xs` is not in standard layout.
+            let pixels: Vec<f32> = luma.iter().copied().collect();
+            let v = Ops::resize_lumaf32_f32(&pixels, w as _, h as _, tw, th, false, filter)?;
+            let y_ = Array::from_shape_vec((th_, tw_), v)?;
+            ys.slice_mut(s![i, .., ..]).assign(&y_);
+        }
+
+        Ok(ys.into_dyn())
+    }
+
+ pub fn resize_lumaf32_u8(
v: &[f32],
w0: f32,
h0: f32,
@@ -168,6 +202,20 @@ impl Ops<'_> {
crop_src: bool,
filter: &str,
) -> Result> {
+ let mask_f32 = Self::resize_lumaf32_f32(v, w0, h0, w1, h1, crop_src, filter)?;
+ let v: Vec = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect();
+ Ok(v)
+ }
+
+ pub fn resize_lumaf32_f32(
+ v: &[f32],
+ w0: f32,
+ h0: f32,
+ w1: f32,
+ h1: f32,
+ crop_src: bool,
+ filter: &str,
+ ) -> Result> {
let src = Image::from_vec_u8(
w0 as _,
h0 as _,
@@ -189,12 +237,10 @@ impl Ops<'_> {
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
- // f32 -> u8
- let v: Vec = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect();
- Ok(v)
+ Ok(mask_f32)
}
- pub fn resize_luma8_vec(
+ pub fn resize_luma8_u8(
v: &[u8],
w0: f32,
h0: f32,
diff --git a/src/core/options.rs b/src/core/options.rs
index 0b11113..14f621d 100644
--- a/src/core/options.rs
+++ b/src/core/options.rs
@@ -4,7 +4,7 @@ use anyhow::Result;
use crate::{
auto_load,
- models::{SamKind, YOLOPreds, YOLOTask, YOLOVersion},
+ models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion},
Device, MinOptMax,
};
@@ -92,6 +92,7 @@ pub struct Options {
pub find_contours: bool,
pub sam_kind: Option,
pub use_low_res_mask: Option,
+ pub sapiens_task: Option,
}
impl Default for Options {
@@ -175,6 +176,7 @@ impl Default for Options {
find_contours: false,
sam_kind: None,
use_low_res_mask: None,
+ sapiens_task: None,
}
}
}
@@ -220,6 +222,11 @@ impl Options {
self
}
+    /// Sets the Sapiens task; `Sapiens::new` needs it to be set.
+    pub fn with_sapiens_task(self, x: SapiensTask) -> Self {
+        Self {
+            sapiens_task: Some(x),
+            ..self
+        }
+    }
+
pub fn with_yolo_version(mut self, x: YOLOVersion) -> Self {
self.yolo_version = Some(x);
self
diff --git a/src/core/task.rs b/src/core/task.rs
new file mode 100644
index 0000000..b85a00d
--- /dev/null
+++ b/src/core/task.rs
@@ -0,0 +1,27 @@
+/// Task kinds, split by the inline comments into a `vision` group and a `vl` group
+/// (presumably vision-language).
+///
+/// The enum is fieldless, so it also derives `Copy`, equality and hashing.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Task {
+    // vision
+    ImageClassification,
+    ObjectDetection,
+    KeypointsDetection,
+    // NOTE(review): probably meant `RegionProposal`; the spelling is kept because
+    // renaming a public variant would break callers.
+    RegisonProposal,
+    PoseEstimation,
+    SemanticSegmentation,
+    InstanceSegmentation,
+    DepthEstimation,
+    SurfaceNormalPrediction,
+    Image2ImageGeneration,
+    Inpainting,
+    SuperResolution,
+    Denoising,
+
+    // vl
+    Tagging,
+    Captioning,
+    DetailedCaptioning,
+    MoreDetailedCaptioning,
+    PhraseGrounding,
+    Vqa,
+    Ocr,
+    Text2ImageGeneration,
+}
diff --git a/src/lib.rs b/src/lib.rs
index 77dd69c..1713725 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -24,6 +24,7 @@
//! - [YOLOPv2](https://arxiv.org/abs/2208.11434): Panoptic Driving Perception
//! - [Depth-Anything (v1, v2)](https://github.com/LiheYoung/Depth-Anything): Monocular Depth Estimation
//! - [MODNet](https://github.com/ZHKKKe/MODNet): Image Matting
+//! - [Sapiens](https://arxiv.org/abs/2408.12569): Human-centric Vision Tasks
//!
//! # Examples
//!
@@ -35,7 +36,7 @@
//! Using provided [`models`] with [`Options`]
//!
//! ```rust, no_run
-//! use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision};
+//! use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, COCO_CLASS_NAMES_80};
//!
//! let options = Options::default()
//! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR
@@ -74,7 +75,7 @@
//!
//! ```rust, no_run
//! let options = Options::default()
-//! .with_names(&coco::NAMES_80);
+//! .with_names(&COCO_CLASS_NAMES_80);
//! ```
//!
//! More options can be found in the [`Options`] documentation.
diff --git a/src/models/db.rs b/src/models/db.rs
index 12d3e89..89bed64 100644
--- a/src/models/db.rs
+++ b/src/models/db.rs
@@ -93,7 +93,7 @@ impl DB {
})
.collect::>();
- let luma = Ops::resize_luma8_vec(
+ let luma = Ops::resize_luma8_u8(
&v,
self.width() as _,
self.height() as _,
diff --git a/src/models/depth_anything.rs b/src/models/depth_anything.rs
index 2a5c86b..1b97ad7 100644
--- a/src/models/depth_anything.rs
+++ b/src/models/depth_anything.rs
@@ -57,7 +57,7 @@ impl DepthAnything {
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
.collect::>();
- let luma = Ops::resize_luma8_vec(
+ let luma = Ops::resize_luma8_u8(
&v,
self.width() as _,
self.height() as _,
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 237fe8f..0d13182 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -9,6 +9,7 @@ mod grounding_dino;
mod modnet;
mod rtmo;
mod sam;
+mod sapiens;
mod svtr;
mod yolo;
mod yolo_;
@@ -23,6 +24,7 @@ pub use grounding_dino::GroundingDINO;
pub use modnet::MODNet;
pub use rtmo::RTMO;
pub use sam::{SamKind, SamPrompt, SAM};
+pub use sapiens::{Sapiens, SapiensTask};
pub use svtr::SVTR;
pub use yolo::YOLO;
pub use yolo_::*;
diff --git a/src/models/modnet.rs b/src/models/modnet.rs
index 57e647c..f7dfd01 100644
--- a/src/models/modnet.rs
+++ b/src/models/modnet.rs
@@ -51,7 +51,7 @@ impl MODNet {
for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() {
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
let luma = luma.mapv(|x| (x * 255.0) as u8);
- let luma = Ops::resize_luma8_vec(
+ let luma = Ops::resize_luma8_u8(
&luma.into_raw_vec_and_offset().0,
self.width() as _,
self.height() as _,
diff --git a/src/models/sam.rs b/src/models/sam.rs
index 6a03283..d6f37d4 100644
--- a/src/models/sam.rs
+++ b/src/models/sam.rs
@@ -264,7 +264,7 @@ impl SAM {
let (h, w) = mask.dim();
let luma = if self.use_low_res_mask {
- Ops::resize_lumaf32_vec(
+ Ops::resize_lumaf32_u8(
&mask.into_owned().into_raw_vec_and_offset().0,
w as _,
h as _,
diff --git a/src/models/sapiens.rs b/src/models/sapiens.rs
new file mode 100644
index 0000000..448149e
--- /dev/null
+++ b/src/models/sapiens.rs
@@ -0,0 +1,158 @@
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::{s, Array2, Axis};
+
+use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y};
+
+/// Task selector for [`Sapiens`].
+///
+/// Only `Seg` is implemented by `Sapiens::run`; the other variants are placeholders.
+#[derive(Debug, Clone, clap::ValueEnum)]
+pub enum SapiensTask {
+    // Segmentation, post-processed by `Sapiens::postprocess_seg`.
+    Seg,
+    // Not implemented yet.
+    Depth,
+    // Not implemented yet.
+    Normal,
+    // Not implemented yet.
+    Pose,
+}
+
+/// Sapiens model state: the inference engine, the batch/height/width ranges read
+/// from it, the selected task and optional class names.
+#[derive(Debug)]
+pub struct Sapiens {
+    engine_seg: OrtEngine,
+    height: MinOptMax,
+    width: MinOptMax,
+    batch: MinOptMax,
+    task: SapiensTask,
+    // Class names looked up by class id; taken from `Options::names` in `new`.
+    names_body: Option<Vec<String>>,
+}
+
+impl Sapiens {
+    /// Creates the engine from `options_seg`, reads its batch/height/width ranges and
+    /// performs a dry run.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the engine cannot be built, if no task was set through
+    /// `Options::with_sapiens_task`, or if the dry run fails.
+    pub fn new(options_seg: Options) -> Result<Self> {
+        let mut engine_seg = OrtEngine::new(&options_seg)?;
+        let (batch, height, width) = (
+            engine_seg.batch().to_owned(),
+            engine_seg.height().to_owned(),
+            engine_seg.width().to_owned(),
+        );
+        // A missing task is a caller mistake, so surface it as an error rather than a panic.
+        let task = options_seg
+            .sapiens_task
+            .ok_or_else(|| anyhow::anyhow!("Error: No sapiens task specified."))?;
+        let names_body = options_seg.names;
+        engine_seg.dry_run()?;
+
+        Ok(Self {
+            engine_seg,
+            height,
+            width,
+            batch,
+            task,
+            names_body,
+        })
+    }
+
+    /// Preprocesses `xs` (resize to `height()` x `width()`, standardize, NHWC -> NCHW),
+    /// runs the engine and post-processes the output for the configured task.
+    ///
+    /// # Errors
+    ///
+    /// Fails if preprocessing, inference or post-processing fails, or if the configured
+    /// task is anything other than `SapiensTask::Seg` (not implemented yet).
+    pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
+        let xs_ = X::apply(&[
+            Ops::Resize(xs, self.height() as u32, self.width() as u32, "Bilinear"),
+            Ops::Standardize(&[123.5, 116.5, 103.5], &[58.5, 57.0, 57.5], 3),
+            Ops::Nhwc2nchw,
+        ])?;
+
+        match self.task {
+            SapiensTask::Seg => {
+                let ys = self.engine_seg.run(Xs::from(xs_))?;
+                self.postprocess_seg(ys, xs)
+            }
+            // Listed explicitly (no `_`) so that a new variant forces a decision here.
+            SapiensTask::Depth | SapiensTask::Normal | SapiensTask::Pose => {
+                anyhow::bail!("Sapiens task {:?} is not implemented yet.", self.task)
+            }
+        }
+    }
+
+    /// Turns the raw model output into per-class masks and polygons, one `Y` per image.
+    ///
+    /// For each batch item of `xs[0]` the per-class score maps are resized to the size of
+    /// the matching image in `xs0`, then every pixel is labelled with the class that has
+    /// the highest score. Pixels whose best score is `<= 0` or whose best class is `0`
+    /// are left at `0` in the label map. For every class that occurs, a binary (0/255)
+    /// image is built and the contour with the largest area becomes that class' polygon;
+    /// classes for which no contour is found are skipped entirely.
+    ///
+    /// # Errors
+    ///
+    /// Fails if resizing the score maps fails.
+    pub fn postprocess_seg(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+        let mut ys: Vec<Y> = Vec::new();
+        for (idx, b) in xs[0].axis_iter(Axis(0)).enumerate() {
+            let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
+
+            // rescale
+            let masks = Ops::interpolate_3d(b.to_owned(), w1 as _, h1 as _, "Bilinear")?;
+
+            // generate the label map: per-pixel argmax over the class axis
+            let mut mask = Array2::<usize>::zeros((h1 as _, w1 as _));
+            let mut ids = Vec::new();
+            for hh in 0..h1 {
+                for ww in 0..w1 {
+                    let pt_slice = masks.slice(s![.., hh as usize, ww as usize]);
+                    let (i, c) = match pt_slice
+                        .into_iter()
+                        .enumerate()
+                        .max_by(|a, b| a.1.total_cmp(b.1))
+                    {
+                        Some((i, c)) => (i, c),
+                        None => continue,
+                    };
+
+                    if *c <= 0. || i == 0 {
+                        continue;
+                    }
+                    mask[[hh as _, ww as _]] = i;
+
+                    if !ids.contains(&i) {
+                        ids.push(i);
+                    }
+                }
+            }
+
+            // generate masks and polygons
+            let mut y_masks: Vec<Mask> = Vec::new();
+            let mut y_polygons: Vec<Polygon> = Vec::new();
+            for i in ids.iter() {
+                // `get` rather than indexing: a name list shorter than the class ids
+                // must not panic, the name is simply left unset.
+                let name = self.names_body.as_ref().and_then(|names| names.get(*i));
+
+                let luma = mask.mapv(|x| if x == *i { 255 } else { 0 });
+                let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
+                    match image::ImageBuffer::from_raw(
+                        w1 as _,
+                        h1 as _,
+                        luma.into_raw_vec_and_offset().0,
+                    ) {
+                        None => continue,
+                        Some(x) => x,
+                    };
+
+                // contours: keep only the one enclosing the largest area
+                let contours: Vec<imageproc::contours::Contour<i32>> =
+                    imageproc::contours::find_contours_with_threshold(&luma, 0);
+                let polygon = match contours
+                    .into_iter()
+                    .map(|x| {
+                        let mut polygon = Polygon::default()
+                            .with_id(*i as _)
+                            .with_points_imageproc(&x.points);
+                        if let Some(name) = name {
+                            polygon = polygon.with_name(name);
+                        }
+                        polygon
+                    })
+                    .max_by(|x, y| x.area().total_cmp(&y.area()))
+                {
+                    Some(p) => p,
+                    None => continue,
+                };
+
+                y_polygons.push(polygon);
+
+                let mut y_mask = Mask::default().with_mask(luma).with_id(*i as _);
+                if let Some(name) = name {
+                    y_mask = y_mask.with_name(name);
+                }
+                y_masks.push(y_mask);
+            }
+            ys.push(Y::default().with_masks(&y_masks).with_polygons(&y_polygons));
+        }
+        Ok(ys)
+    }
+
+    /// Returns the `opt` value of the engine's batch range.
+    pub fn batch(&self) -> isize {
+        self.batch.opt
+    }
+
+    /// Returns the `opt` value of the engine's width range.
+    pub fn width(&self) -> isize {
+        self.width.opt
+    }
+
+    /// Returns the `opt` value of the engine's height range.
+    pub fn height(&self) -> isize {
+        self.height.opt
+    }
+}
diff --git a/src/models/yolo.rs b/src/models/yolo.rs
index bf5a9d1..f79b2ca 100644
--- a/src/models/yolo.rs
+++ b/src/models/yolo.rs
@@ -421,7 +421,7 @@ impl Vision for YOLO {
let mask = coefs.dot(&proto); // (mh, mw, n)
// Mask rescale
- let mask = Ops::resize_lumaf32_vec(
+ let mask = Ops::resize_lumaf32_u8(
&mask.into_raw_vec_and_offset().0,
mw as _,
mh as _,
diff --git a/src/models/yolop.rs b/src/models/yolop.rs
index 2aefcd3..65b1d51 100644
--- a/src/models/yolop.rs
+++ b/src/models/yolop.rs
@@ -191,7 +191,7 @@ impl YOLOPv2 {
h1: f32,
) -> Result>> {
let mask = mask.mapv(|x| if x < thresh { 0u8 } else { 255u8 });
- let mask = Ops::resize_luma8_vec(
+ let mask = Ops::resize_luma8_u8(
&mask.into_raw_vec_and_offset().0,
w0,
h0,
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index 34a3078..eb52f2d 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -4,10 +4,11 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
-pub mod coco;
pub mod colormap256;
+pub mod names;
pub use colormap256::*;
+pub use names::*;
pub(crate) const GITHUB_ASSETS: &str =
"https://github.com/jamjamjon/assets/releases/download/v0.0.1";
diff --git a/src/utils/coco.rs b/src/utils/names.rs
similarity index 67%
rename from src/utils/coco.rs
rename to src/utils/names.rs
index 11344cb..ea6b648 100644
--- a/src/utils/coco.rs
+++ b/src/utils/names.rs
@@ -1,6 +1,6 @@
-//! Some constants releated with COCO dataset: [`SKELETONS_16`], [`KEYPOINTS_NAMES_17`], [`NAMES_80`]
+//! Some constants related to the COCO dataset: [`COCO_SKELETONS_16`], [`COCO_KEYPOINTS_17`], [`COCO_CLASS_NAMES_80`], plus the body-part names [`BODY_PARTS_28`]
-pub const SKELETONS_16: [(usize, usize); 16] = [
+pub const COCO_SKELETONS_16: [(usize, usize); 16] = [
(0, 1),
(0, 2),
(1, 3),
@@ -19,7 +19,7 @@ pub const SKELETONS_16: [(usize, usize); 16] = [
(14, 16),
];
-pub const KEYPOINTS_NAMES_17: [&str; 17] = [
+pub const COCO_KEYPOINTS_17: [&str; 17] = [
"nose",
"left_eye",
"right_eye",
@@ -39,7 +39,7 @@ pub const KEYPOINTS_NAMES_17: [&str; 17] = [
"right_ankle",
];
-pub const NAMES_80: [&str; 80] = [
+pub const COCO_CLASS_NAMES_80: [&str; 80] = [
"person",
"bicycle",
"car",
@@ -121,3 +121,34 @@ pub const NAMES_80: [&str; 80] = [
"hair drier",
"toothbrush",
];
+
+/// Names of 28 body-part classes, indexed by class id; used as the class names in
+/// the Sapiens segmentation example. Index 0 is the background entry.
+pub const BODY_PARTS_28: [&str; 28] = [
+    "Background",
+    "Apparel",
+    "Face Neck",
+    "Hair",
+    "Left Foot",
+    "Left Hand",
+    "Left Lower Arm",
+    "Left Lower Leg",
+    "Left Shoe",
+    "Left Sock",
+    "Left Upper Arm",
+    "Left Upper Leg",
+    "Lower Clothing",
+    "Right Foot",
+    "Right Hand",
+    "Right Lower Arm",
+    "Right Lower Leg",
+    "Right Shoe",
+    "Right Sock",
+    "Right Upper Arm",
+    "Right Upper Leg",
+    "Torso",
+    "Upper Clothing",
+    "Lower Lip",
+    "Upper Lip",
+    "Lower Teeth",
+    "Upper Teeth",
+    "Tongue",
+];