From 46a4456a38fb11296db982d48f596fa61002c2c1 Mon Sep 17 00:00:00 2001 From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com> Date: Thu, 1 Aug 2024 17:26:06 +0800 Subject: [PATCH] Add SAM2 and ONNX (#28) --- examples/sam/main.rs | 14 ++++++++++++++ src/models/sam.rs | 38 +++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/examples/sam/main.rs b/examples/sam/main.rs index 0e6ef4d..f884fae 100644 --- a/examples/sam/main.rs +++ b/examples/sam/main.rs @@ -38,6 +38,20 @@ fn main() -> Result<(), Box> { .with_model("sam-vit-b-decoder-u8.onnx")?; (options_encoder, options_decoder, "SAM") } + SamKind::Sam2 => { + let options_encoder = Options::default() + // .with_model("sam2-hiera-tiny-encoder.onnx")?; + // .with_model("sam2-hiera-small-encoder.onnx")?; + .with_model("sam2-hiera-base-plus-encoder.onnx")?; + let options_decoder = Options::default() + .with_i31((1, 1, 1).into()) + .with_i41((1, 1, 1).into()) + .with_sam_kind(SamKind::Sam2) + // .with_model("sam2-hiera-tiny-decoder.onnx")?; + // .with_model("sam2-hiera-small-decoder.onnx")?; + .with_model("sam2-hiera-base-plus-decoder.onnx")?; + (options_encoder, options_decoder, "SAM2") + } SamKind::MobileSam => { let options_encoder = Options::default().with_model("mobile-sam-vit-t-encoder.onnx")?; diff --git a/src/models/sam.rs b/src/models/sam.rs index d7fe83f..c95fb26 100644 --- a/src/models/sam.rs +++ b/src/models/sam.rs @@ -8,6 +8,7 @@ use crate::{DynConf, Mask, MinOptMax, Ops, Options, OrtEngine, Polygon, X, Y}; #[derive(Debug, Clone, clap::ValueEnum)] pub enum SamKind { Sam, + Sam2, MobileSam, SamHq, EdgeSam, @@ -94,7 +95,7 @@ impl SAM { SamKind::Sam | SamKind::MobileSam | SamKind::SamHq => { options_decoder.use_low_res_mask.unwrap_or(false) } - SamKind::EdgeSam => true, + SamKind::EdgeSam | SamKind::Sam2 => true, }; encoder.dry_run()?; @@ -142,9 +143,13 @@ impl SAM { xs0: &[DynamicImage], prompts: &[SamPrompt], ) -> Result> { - let mut ys: Vec = Vec::new(); + let (image_embeddings, high_res_features_0, high_res_features_1) = match self.kind { + SamKind::Sam2 => (&xs[0], Some(&xs[1]), Some(&xs[2])), + _ => (&xs[0], None, None), + }; - for (idx, image_embedding) in xs[0].axis_iter(Axis(0)).enumerate() { + let mut ys: Vec = Vec::new(); + for (idx, image_embedding) in image_embeddings.axis_iter(Axis(0)).enumerate() { let image_width = xs0[idx].width() as f32; let image_height = xs0[idx].height() as f32; let ratio = @@ -180,6 +185,32 @@ impl SAM { prompts[idx].point_labels()?, ] } + SamKind::Sam2 => { + vec![ + X::from(image_embedding.into_dyn().into_owned()).insert_axis(0)?, + X::from( + high_res_features_0 + .unwrap() + .slice(s![idx, .., .., ..]) + .into_dyn() + .into_owned(), + ) + .insert_axis(0)?, + X::from( + high_res_features_1 + .unwrap() + .slice(s![idx, .., .., ..]) + .into_dyn() + .into_owned(), + ) + .insert_axis(0)?, + prompts[idx].point_coords(ratio)?, + prompts[idx].point_labels()?, + X::zeros(&[1, 1, self.height_low_res() as _, self.width_low_res() as _]), // mask_input + X::zeros(&[1]), // has_mask_input + X::from(vec![image_height, image_width]), // orig_im_size + ] + } }; let ys_ = self.decoder.run(args)?; @@ -196,6 +227,7 @@ impl SAM { (&ys_[2], &ys_[1]) } } + SamKind::Sam2 => (&ys_[0], &ys_[1]), SamKind::EdgeSam => match (ys_[0].ndim(), ys_[1].ndim()) { (2, 4) => (&ys_[1], &ys_[0]), (4, 2) => (&ys_[0], &ys_[1]),