diff --git a/Cargo.toml b/Cargo.toml index ca76358..3722b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "usls" -version = "0.0.11" +version = "0.0.12" edition = "2021" description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models." repository = "https://github.com/jamjamjon/usls" diff --git a/README.md b/README.md index b5dbca3..8a2c9aa 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ - **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10) - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM) -- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet) +- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569) - **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World)
@@ -70,6 +70,7 @@ | [Depth-Anything](https://github.com/LiheYoung/Depth-Anything) | Monocular Depth Estimation | [demo](examples/depth-anything) | ✅ | ✅ | ❌ | ❌ | | [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) | ✅ | ✅ | ✅ | ✅ | | [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) | ✅ | ✅ | | | +| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) | ✅ | ✅ | | |
diff --git a/assets/paul-george.jpg b/assets/paul-george.jpg new file mode 100644 index 0000000..9b92635 Binary files /dev/null and b/assets/paul-george.jpg differ diff --git a/benches/yolo.rs b/benches/yolo.rs index 4868ba6..b128338 100644 --- a/benches/yolo.rs +++ b/benches/yolo.rs @@ -1,7 +1,7 @@ use anyhow::Result; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use usls::{coco, models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; +use usls::{models::YOLO, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17}; enum Stage { Pre, @@ -60,7 +60,7 @@ pub fn benchmark_cuda(c: &mut Criterion, h: isize, w: isize) -> Result<()> { .with_i02((320, h, 1280).into()) .with_i03((320, w, 1280).into()) .with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15 - .with_names2(&coco::KEYPOINTS_NAMES_17); + .with_names2(&COCO_KEYPOINTS_17); let mut model = YOLO::new(options)?; let xs = vec![DataLoader::try_read("./assets/bus.jpg")?]; diff --git a/examples/depth-anything/main.rs b/examples/depth-anything/main.rs index aa84691..9a4cd38 100644 --- a/examples/depth-anything/main.rs +++ b/examples/depth-anything/main.rs @@ -11,7 +11,7 @@ fn main() -> Result<(), Box> { let mut model = DepthAnything::new(options)?; // load - let x = vec![DataLoader::try_read("./assets/2.jpg")?]; + let x = [DataLoader::try_read("./assets/2.jpg")?]; // run let y = model.run(&x)?; diff --git a/examples/rtmo/main.rs b/examples/rtmo/main.rs index bf6f696..b685c1f 100644 --- a/examples/rtmo/main.rs +++ b/examples/rtmo/main.rs @@ -1,4 +1,4 @@ -use usls::{coco, models::RTMO, Annotator, DataLoader, Options}; +use usls::{models::RTMO, Annotator, DataLoader, Options, COCO_SKELETONS_16}; fn main() -> Result<(), Box> { // build model @@ -19,7 +19,7 @@ fn main() -> Result<(), Box> { // annotate let annotator = Annotator::default() .with_saveout("RTMO") - .with_skeletons(&coco::SKELETONS_16); + .with_skeletons(&COCO_SKELETONS_16); annotator.annotate(&x, &y); Ok(()) diff --git a/examples/sapiens/main.rs b/examples/sapiens/main.rs new file mode 100644 index 0000000..63c3b12 --- /dev/null +++ b/examples/sapiens/main.rs @@ -0,0 +1,30 @@ +use usls::{ + models::{Sapiens, SapiensTask}, + Annotator, DataLoader, Options, BODY_PARTS_28, +}; + +fn main() -> Result<(), Box> { + // build + let options = Options::default() + .with_model("sapiens-seg-0.3b-dyn.onnx")? + .with_sapiens_task(SapiensTask::Seg) + .with_names(&BODY_PARTS_28) + .with_profile(false) + .with_i00((1, 1, 8).into()); + let mut model = Sapiens::new(options)?; + + // load + let x = [DataLoader::try_read("./assets/paul-george.jpg")?]; + + // run + let y = model.run(&x)?; + + // annotate + let annotator = Annotator::default() + .without_masks(true) + .with_polygons_name(false) + .with_saveout("Sapiens"); + annotator.annotate(&x, &y); + + Ok(()) +} diff --git a/examples/yolo/main.rs b/examples/yolo/main.rs index 97af956..3587265 100644 --- a/examples/yolo/main.rs +++ b/examples/yolo/main.rs @@ -1,7 +1,10 @@ use anyhow::Result; use clap::Parser; -use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion}; +use usls::{ + models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion, COCO_KEYPOINTS_17, + COCO_SKELETONS_16, +}; #[derive(Parser, Clone)] #[command(author, version, about, long_about = None)] @@ -174,8 +177,8 @@ fn main() -> Result<()> { .with_i02((args.height_min, args.height, args.height_max).into()) .with_i03((args.width_min, args.width, args.width_max).into()) .with_confs(&[0.2, 0.15]) // class_0: 0.4, others: 0.15 - // .with_names(&coco::NAMES_80) - .with_names2(&coco::KEYPOINTS_NAMES_17) + // .with_names(&COCO_CLASS_NAMES_80) + .with_names2(&COCO_KEYPOINTS_17) .with_find_contours(!args.no_contours) // find contours or not .with_profile(args.profile); let mut model = YOLO::new(options)?; @@ -187,7 +190,7 @@ fn main() -> Result<()> { // build annotator let annotator = Annotator::default() - .with_skeletons(&coco::SKELETONS_16) + .with_skeletons(&COCO_SKELETONS_16) .with_bboxes_thickness(4) .without_masks(true) // No masks plotting when doing segment task. .with_saveout(saveout); diff --git a/src/core/mod.rs b/src/core/mod.rs index 0577c6a..2586cf6 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -9,6 +9,7 @@ pub mod onnx; pub mod ops; mod options; mod ort_engine; +mod task; mod tokenizer_stream; mod ts; mod vision; @@ -25,6 +26,7 @@ pub use min_opt_max::MinOptMax; pub use ops::Ops; pub use options::Options; pub use ort_engine::OrtEngine; +pub use task::Task; pub use tokenizer_stream::TokenizerStream; pub use ts::Ts; pub use vision::Vision; diff --git a/src/core/ops.rs b/src/core/ops.rs index 6b5b94f..10d3a74 100644 --- a/src/core/ops.rs +++ b/src/core/ops.rs @@ -7,7 +7,7 @@ use fast_image_resize::{ FilterType, ResizeAlg, ResizeOptions, Resizer, }; use image::{DynamicImage, GenericImageView}; -use ndarray::{s, Array, Axis, IntoDimension, IxDyn}; +use ndarray::{s, Array, Array3, Axis, IntoDimension, IxDyn}; use rayon::prelude::*; pub enum Ops<'a> { @@ -159,7 +159,41 @@ impl Ops<'_> { mask.resize_exact(w1 as u32, h1 as u32, image::imageops::FilterType::Triangle) } - pub fn resize_lumaf32_vec( + // pub fn argmax(xs: Array, d: usize, keep_dims: bool) -> Result> { + // let mask = Array::zeros(xs.raw_dim()); + // todo!(); + // } + + pub fn interpolate_3d( + xs: Array, + tw: f32, + th: f32, + filter: &str, + ) -> Result> { + let d_max = xs.ndim(); + if d_max != 3 { + anyhow::bail!("`interpolate_3d`: The input's ndim: {} is not 3.", d_max); + } + let (n, h, w) = (xs.shape()[0], xs.shape()[1], xs.shape()[2]); + let mut ys = Array3::zeros((n, th as usize, tw as usize)); + for (i, luma) in xs.axis_iter(Axis(0)).enumerate() { + let v = Ops::resize_lumaf32_f32( + &luma.to_owned().into_raw_vec_and_offset().0, + w as _, + h as _, + tw as _, + th as _, + false, + filter, + )?; + let y_ = Array::from_shape_vec((th as usize, tw as usize), v)?; + ys.slice_mut(s![i, .., ..]).assign(&y_); + } + + Ok(ys.into_dyn()) + } + + pub fn resize_lumaf32_u8( v: &[f32], w0: f32, h0: f32, @@ -168,6 +202,20 @@ impl Ops<'_> { crop_src: bool, filter: &str, ) -> Result> { + let mask_f32 = Self::resize_lumaf32_f32(v, w0, h0, w1, h1, crop_src, filter)?; + let v: Vec = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect(); + Ok(v) + } + + pub fn resize_lumaf32_f32( + v: &[f32], + w0: f32, + h0: f32, + w1: f32, + h1: f32, + crop_src: bool, + filter: &str, + ) -> Result> { let src = Image::from_vec_u8( w0 as _, h0 as _, @@ -189,12 +237,10 @@ impl Ops<'_> { .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) .collect(); - // f32 -> u8 - let v: Vec = mask_f32.par_iter().map(|&x| (x * 255.0) as u8).collect(); - Ok(v) + Ok(mask_f32) } - pub fn resize_luma8_vec( + pub fn resize_luma8_u8( v: &[u8], w0: f32, h0: f32, diff --git a/src/core/options.rs b/src/core/options.rs index 0b11113..14f621d 100644 --- a/src/core/options.rs +++ b/src/core/options.rs @@ -4,7 +4,7 @@ use anyhow::Result; use crate::{ auto_load, - models::{SamKind, YOLOPreds, YOLOTask, YOLOVersion}, + models::{SamKind, SapiensTask, YOLOPreds, YOLOTask, YOLOVersion}, Device, MinOptMax, }; @@ -92,6 +92,7 @@ pub struct Options { pub find_contours: bool, pub sam_kind: Option, pub use_low_res_mask: Option, + pub sapiens_task: Option, } impl Default for Options { @@ -175,6 +176,7 @@ impl Default for Options { find_contours: false, sam_kind: None, use_low_res_mask: None, + sapiens_task: None, } } } @@ -220,6 +222,11 @@ impl Options { self } + pub fn with_sapiens_task(mut self, x: SapiensTask) -> Self { + self.sapiens_task = Some(x); + self + } + pub fn with_yolo_version(mut self, x: YOLOVersion) -> Self { self.yolo_version = Some(x); self diff --git a/src/core/task.rs b/src/core/task.rs new file mode 100644 index 0000000..b85a00d --- /dev/null +++ b/src/core/task.rs @@ -0,0 +1,27 @@ +#[derive(Debug, Clone)] +pub enum Task { + // vision + ImageClassification, + ObjectDetection, + KeypointsDetection, + RegisonProposal, + PoseEstimation, + SemanticSegmentation, + InstanceSegmentation, + DepthEstimation, + SurfaceNormalPrediction, + Image2ImageGeneration, + Inpainting, + SuperResolution, + Denoising, + + // vl + Tagging, + Captioning, + DetailedCaptioning, + MoreDetailedCaptioning, + PhraseGrounding, + Vqa, + Ocr, + Text2ImageGeneration, +} diff --git a/src/lib.rs b/src/lib.rs index 77dd69c..1713725 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,7 @@ //! - [YOLOPv2](https://arxiv.org/abs/2208.11434): Panoptic Driving Perception //! - [Depth-Anything (v1, v2)](https://github.com/LiheYoung/Depth-Anything): Monocular Depth Estimation //! - [MODNet](https://github.com/ZHKKKe/MODNet): Image Matting +//! - [Sapiens](https://arxiv.org/abs/2408.12569): Human-centric Vision Tasks //! //! # Examples //! @@ -35,7 +36,7 @@ //! Using provided [`models`] with [`Options`] //! //! ```rust, no_run -//! use usls::{coco, models::YOLO, Annotator, DataLoader, Options, Vision}; +//! use usls::{ models::YOLO, Annotator, DataLoader, Options, Vision, COCO_CLASS_NAMES_80}; //! //! let options = Options::default() //! .with_yolo_version(YOLOVersion::V8) // YOLOVersion: V5, V6, V7, V8, V9, V10, RTDETR @@ -74,7 +75,7 @@ //! //! ```rust, no_run //! let options = Options::default() -//! .with_names(&coco::NAMES_80); +//! .with_names(&COCO_CLASS_NAMES_80); //! ``` //! //! More options can be found in the [`Options`] documentation. diff --git a/src/models/db.rs b/src/models/db.rs index 12d3e89..89bed64 100644 --- a/src/models/db.rs +++ b/src/models/db.rs @@ -93,7 +93,7 @@ impl DB { }) .collect::>(); - let luma = Ops::resize_luma8_vec( + let luma = Ops::resize_luma8_u8( &v, self.width() as _, self.height() as _, diff --git a/src/models/depth_anything.rs b/src/models/depth_anything.rs index 2a5c86b..1b97ad7 100644 --- a/src/models/depth_anything.rs +++ b/src/models/depth_anything.rs @@ -57,7 +57,7 @@ impl DepthAnything { .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8) .collect::>(); - let luma = Ops::resize_luma8_vec( + let luma = Ops::resize_luma8_u8( &v, self.width() as _, self.height() as _, diff --git a/src/models/mod.rs b/src/models/mod.rs index 237fe8f..0d13182 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -9,6 +9,7 @@ mod grounding_dino; mod modnet; mod rtmo; mod sam; +mod sapiens; mod svtr; mod yolo; mod yolo_; @@ -23,6 +24,7 @@ pub use grounding_dino::GroundingDINO; pub use modnet::MODNet; pub use rtmo::RTMO; pub use sam::{SamKind, SamPrompt, SAM}; +pub use sapiens::{Sapiens, SapiensTask}; pub use svtr::SVTR; pub use yolo::YOLO; pub use yolo_::*; diff --git a/src/models/modnet.rs b/src/models/modnet.rs index 57e647c..f7dfd01 100644 --- a/src/models/modnet.rs +++ b/src/models/modnet.rs @@ -51,7 +51,7 @@ impl MODNet { for (idx, luma) in xs[0].axis_iter(Axis(0)).enumerate() { let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); let luma = luma.mapv(|x| (x * 255.0) as u8); - let luma = Ops::resize_luma8_vec( + let luma = Ops::resize_luma8_u8( &luma.into_raw_vec_and_offset().0, self.width() as _, self.height() as _, diff --git a/src/models/sam.rs b/src/models/sam.rs index 6a03283..d6f37d4 100644 --- a/src/models/sam.rs +++ b/src/models/sam.rs @@ -264,7 +264,7 @@ impl SAM { let (h, w) = mask.dim(); let luma = if self.use_low_res_mask { - Ops::resize_lumaf32_vec( + Ops::resize_lumaf32_u8( &mask.into_owned().into_raw_vec_and_offset().0, w as _, h as _, diff --git a/src/models/sapiens.rs b/src/models/sapiens.rs new file mode 100644 index 0000000..448149e --- /dev/null +++ b/src/models/sapiens.rs @@ -0,0 +1,158 @@ +use anyhow::Result; +use image::DynamicImage; +use ndarray::{s, Array2, Axis}; + +use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, Xs, X, Y}; + +#[derive(Debug, Clone, clap::ValueEnum)] +pub enum SapiensTask { + Seg, + Depth, + Normal, + Pose, +} + +#[derive(Debug)] +pub struct Sapiens { + engine_seg: OrtEngine, + height: MinOptMax, + width: MinOptMax, + batch: MinOptMax, + task: SapiensTask, + names_body: Option>, +} + +impl Sapiens { + pub fn new(options_seg: Options) -> Result { + let mut engine_seg = OrtEngine::new(&options_seg)?; + let (batch, height, width) = ( + engine_seg.batch().to_owned(), + engine_seg.height().to_owned(), + engine_seg.width().to_owned(), + ); + let task = options_seg + .sapiens_task + .expect("Error: No sapiens task specified."); + let names_body = options_seg.names; + engine_seg.dry_run()?; + + Ok(Self { + engine_seg, + height, + width, + batch, + task, + names_body, + }) + } + + pub fn run(&mut self, xs: &[DynamicImage]) -> Result> { + let xs_ = X::apply(&[ + Ops::Resize(xs, self.height() as u32, self.width() as u32, "Bilinear"), + Ops::Standardize(&[123.5, 116.5, 103.5], &[58.5, 57.0, 57.5], 3), + Ops::Nhwc2nchw, + ])?; + + match self.task { + SapiensTask::Seg => { + let ys = self.engine_seg.run(Xs::from(xs_))?; + self.postprocess_seg(ys, xs) + } + _ => todo!(), + } + } + + pub fn postprocess_seg(&self, xs: Xs, xs0: &[DynamicImage]) -> Result> { + let mut ys: Vec = Vec::new(); + for (idx, b) in xs[0].axis_iter(Axis(0)).enumerate() { + let (w1, h1) = (xs0[idx].width(), xs0[idx].height()); + + // rescale + let masks = Ops::interpolate_3d(b.to_owned(), w1 as _, h1 as _, "Bilinear")?; + + // generate mask + let mut mask = Array2::::zeros((h1 as _, w1 as _)); + let mut ids = Vec::new(); + for hh in 0..h1 { + for ww in 0..w1 { + let pt_slice = masks.slice(s![.., hh as usize, ww as usize]); + let (i, c) = match pt_slice + .into_iter() + .enumerate() + .max_by(|a, b| a.1.total_cmp(b.1)) + { + Some((i, c)) => (i, c), + None => continue, + }; + + if *c <= 0. || i == 0 { + continue; + } + mask[[hh as _, ww as _]] = i; + + if !ids.contains(&i) { + ids.push(i); + } + } + } + + // generate masks and polygons + let mut y_masks: Vec = Vec::new(); + let mut y_polygons: Vec = Vec::new(); + for i in ids.iter() { + let luma = mask.mapv(|x| if x == *i { 255 } else { 0 }); + let luma: image::ImageBuffer, Vec<_>> = + match image::ImageBuffer::from_raw( + w1 as _, + h1 as _, + luma.into_raw_vec_and_offset().0, + ) { + None => continue, + Some(x) => x, + }; + + // contours + let contours: Vec> = + imageproc::contours::find_contours_with_threshold(&luma, 0); + let polygon = match contours + .into_iter() + .map(|x| { + let mut polygon = Polygon::default() + .with_id(*i as _) + .with_points_imageproc(&x.points); + if let Some(names_body) = &self.names_body { + polygon = polygon.with_name(&names_body[*i]); + } + polygon + }) + .max_by(|x, y| x.area().total_cmp(&y.area())) + { + Some(p) => p, + None => continue, + }; + + y_polygons.push(polygon); + + let mut mask = Mask::default().with_mask(luma).with_id(*i as _); + if let Some(names_body) = &self.names_body { + mask = mask.with_name(&names_body[*i]); + } + y_masks.push(mask); + } + ys.push(Y::default().with_masks(&y_masks).with_polygons(&y_polygons)); + } + Ok(ys) + } + + pub fn batch(&self) -> isize { + self.batch.opt + } + + pub fn width(&self) -> isize { + self.width.opt + } + + pub fn height(&self) -> isize { + self.height.opt + } +} diff --git a/src/models/yolo.rs b/src/models/yolo.rs index bf5a9d1..f79b2ca 100644 --- a/src/models/yolo.rs +++ b/src/models/yolo.rs @@ -421,7 +421,7 @@ impl Vision for YOLO { let mask = coefs.dot(&proto); // (mh, mw, n) // Mask rescale - let mask = Ops::resize_lumaf32_vec( + let mask = Ops::resize_lumaf32_u8( &mask.into_raw_vec_and_offset().0, mw as _, mh as _, diff --git a/src/models/yolop.rs b/src/models/yolop.rs index 2aefcd3..65b1d51 100644 --- a/src/models/yolop.rs +++ b/src/models/yolop.rs @@ -191,7 +191,7 @@ impl YOLOPv2 { h1: f32, ) -> Result>> { let mask = mask.mapv(|x| if x < thresh { 0u8 } else { 255u8 }); - let mask = Ops::resize_luma8_vec( + let mask = Ops::resize_luma8_u8( &mask.into_raw_vec_and_offset().0, w0, h0, diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 34a3078..eb52f2d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -4,10 +4,11 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng}; use std::io::{Read, Write}; use std::path::{Path, PathBuf}; -pub mod coco; pub mod colormap256; +pub mod names; pub use colormap256::*; +pub use names::*; pub(crate) const GITHUB_ASSETS: &str = "https://github.com/jamjamjon/assets/releases/download/v0.0.1"; diff --git a/src/utils/coco.rs b/src/utils/names.rs similarity index 67% rename from src/utils/coco.rs rename to src/utils/names.rs index 11344cb..ea6b648 100644 --- a/src/utils/coco.rs +++ b/src/utils/names.rs @@ -1,6 +1,6 @@ -//! Some constants releated with COCO dataset: [`SKELETONS_16`], [`KEYPOINTS_NAMES_17`], [`NAMES_80`] +//! Some constants releated with COCO dataset: [`COCO_SKELETONS_16`], [`COCO_KEYPOINTS_17`], [`COCO_CLASS_NAMES_80`] -pub const SKELETONS_16: [(usize, usize); 16] = [ +pub const COCO_SKELETONS_16: [(usize, usize); 16] = [ (0, 1), (0, 2), (1, 3), @@ -19,7 +19,7 @@ pub const SKELETONS_16: [(usize, usize); 16] = [ (14, 16), ]; -pub const KEYPOINTS_NAMES_17: [&str; 17] = [ +pub const COCO_KEYPOINTS_17: [&str; 17] = [ "nose", "left_eye", "right_eye", @@ -39,7 +39,7 @@ pub const KEYPOINTS_NAMES_17: [&str; 17] = [ "right_ankle", ]; -pub const NAMES_80: [&str; 80] = [ +pub const COCO_CLASS_NAMES_80: [&str; 80] = [ "person", "bicycle", "car", @@ -121,3 +121,34 @@ pub const NAMES_80: [&str; 80] = [ "hair drier", "toothbrush", ]; + +pub const BODY_PARTS_28: [&str; 28] = [ + "Background", + "Apparel", + "Face Neck", + "Hair", + "Left Foot", + "Left Hand", + "Left Lower Arm", + "Left Lower Leg", + "Left Shoe", + "Left Sock", + "Left Upper Arm", + "Left Upper Leg", + "Lower Clothing", + "Right Foot", + "Right Hand", + "Right Lower Arm", + "Right Lower Leg", + "Right Shoe", + "Right Sock", + "Right Upper Arm", + "Right Upper Leg", + "Torso", + "Upper Clothing", + "Lower Lip", + "Upper Lip", + "Lower Teeth", + "Upper Teeth", + "Tongue", +];