diff --git a/examples/florence2/main.rs b/examples/florence2/main.rs
index 0b6d822..2362138 100644
--- a/examples/florence2/main.rs
+++ b/examples/florence2/main.rs
@@ -1,10 +1,10 @@
-use usls::{models::Florence2, DataLoader, Options, Task};
+use usls::{models::Florence2, Annotator, DataLoader, Options, Task};
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
     // vision encoder
     let options_vision_encoder = Options::default()
-        .with_model("florence2/base-ft-vision-encoder-dyn.onnx")?
-        .with_i00((1, 1, 4).into())
+        .with_model("florence2/base-vision-encoder.onnx")?
+        .with_i00((1, 2, 4).into())
         .with_i02((512, 768, 800).into())
         .with_i03((512, 768, 800).into())
         .with_profile(false)
@@ -12,118 +12,118 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     // text embed
     let options_text_embed = Options::default()
-        .with_model("florence2/base-ft-embed-tokens-dyn.onnx")?
-        .with_i00((1, 1, 4).into())
-        .with_i01((1, 1, 20).into()) // seq_length
+        .with_model("florence2/base-embed-tokens.onnx")?
+        .with_i00((1, 2, 4).into())
+        .with_i01((1, 2, 20).into()) // seq_length
         .with_tokenizer("florence2/tokenizer.json")?
         .with_profile(false);
 
     // transformer encoder
     let options_encoder = Options::default()
-        .with_model("florence2/base-ft-encoder.onnx")?
-        .with_i00((1, 1, 4).into())
-        .with_i01((1, 1, 300).into()) // encoder_sequence_length
-        .with_i10((1, 1, 4).into())
-        .with_i11((1, 1, 300).into()) // encoder_sequence_length
+        .with_model("florence2/base-encoder.onnx")?
+        .with_i00((1, 2, 4).into())
+        .with_i01((1, 2, 300).into()) // encoder_sequence_length
+        .with_i10((1, 2, 4).into())
+        .with_i11((1, 2, 300).into()) // encoder_sequence_length
         .with_profile(false);
 
     // transformer decoder
    let options_decoder = Options::default()
-        .with_model("florence2/base-ft-decoder-dyn.onnx")?
-        .with_i00((1, 1, 4).into())
-        .with_i01((1, 1, 300).into()) // encoder_sequence_length
-        .with_i10((1, 1, 4).into())
-        .with_i11((1, 1, 300).into()) // encoder_sequence_length
-        .with_i20((1, 1, 4).into())
-        .with_i21((1, 1, 300).into()) // encoder_sequence_length
+        .with_model("florence2/base-decoder.onnx")?
+        .with_i00((1, 2, 4).into())
+        .with_i01((1, 2, 300).into()) // encoder_sequence_length
+        .with_i10((1, 2, 4).into())
+        .with_i11((1, 2, 300).into()) // encoder_sequence_length
+        .with_i20((1, 2, 4).into())
+        .with_i21((1, 2, 300).into()) // encoder_sequence_length
         .with_profile(false);
 
     // transformer decoder merged
     let options_decoder_merged = Options::default()
-        .with_model("florence2/base-ft-decoder-merged-dyn.onnx")?
+        .with_model("florence2/base-decoder-merged.onnx")?
         // encoder_attention_mask
-        .with_i00((1, 1, 4).into())
-        .with_i01((1, 1, 300).into()) // encoder_sequence_length
+        .with_i00((1, 2, 4).into())
+        .with_i01((1, 2, 300).into()) // encoder_sequence_length
         // encoder_hidden_states
-        .with_i10((1, 1, 4).into())
-        .with_i11((1, 1, 300).into()) // encoder_sequence_length
+        .with_i10((1, 2, 4).into())
+        .with_i11((1, 2, 300).into()) // encoder_sequence_length
         // inputs_embeds
-        .with_i20((1, 1, 4).into())
-        .with_i21((1, 1, 300).into()) // encoder_sequence_length
+        .with_i20((1, 2, 4).into())
+        .with_i21((1, 2, 300).into()) // encoder_sequence_length
         // past_key_values.0.decoder.key
-        .with_i30((1, 1, 4).into())
-        .with_i32_((1, 1, 1).into())
+        .with_i30((1, 2, 4).into())
+        .with_i32_((1, 2, 1).into())
         // past_key_values.0.decoder.value
-        .with_i40((1, 1, 4).into())
-        .with_i42((1, 1, 1).into())
+        .with_i40((1, 2, 4).into())
+        .with_i42((1, 2, 1).into())
         // past_key_values.0.encoder.key
-        .with_i50((1, 1, 4).into())
-        .with_i52((1, 1, 1).into())
+        .with_i50((1, 2, 4).into())
+        .with_i52((1, 2, 1).into())
         // past_key_values.0.encoder.value
-        .with_i60((1, 1, 4).into())
-        .with_i62((1, 1, 1).into())
+        .with_i60((1, 2, 4).into())
+        .with_i62((1, 2, 1).into())
         // past_key_values.1.decoder.key
-        .with_i70((1, 1, 4).into())
-        .with_i72((1, 1, 1).into())
+        .with_i70((1, 2, 4).into())
+        .with_i72((1, 2, 1).into())
         // past_key_values.1.decoder.value
-        .with_i80((1, 1, 4).into())
-        .with_i82((1, 1, 1).into())
+        .with_i80((1, 2, 4).into())
+        .with_i82((1, 2, 1).into())
         // past_key_values.1.encoder.key
-        .with_i90((1, 1, 4).into())
-        .with_i92((1, 1, 1).into())
+        .with_i90((1, 2, 4).into())
+        .with_i92((1, 2, 1).into())
         // past_key_values.1.encoder.value
-        .with_i100((1, 1, 4).into())
-        .with_i102((1, 1, 1).into())
+        .with_i100((1, 2, 4).into())
+        .with_i102((1, 2, 1).into())
         // past_key_values.2.decoder.key
-        .with_i110((1, 1, 4).into())
-        .with_i112((1, 1, 1).into())
+        .with_i110((1, 2, 4).into())
+        .with_i112((1, 2, 1).into())
         // past_key_values.2.decoder.value
-        .with_i120((1, 1, 4).into())
-        .with_i122((1, 1, 1).into())
+        .with_i120((1, 2, 4).into())
+        .with_i122((1, 2, 1).into())
         // past_key_values.2.encoder.key
-        .with_i130((1, 1, 4).into())
-        .with_i132((1, 1, 1).into())
+        .with_i130((1, 2, 4).into())
+        .with_i132((1, 2, 1).into())
         // past_key_values.2.encoder.value
-        .with_i140((1, 1, 4).into())
-        .with_i142((1, 1, 1).into())
+        .with_i140((1, 2, 4).into())
+        .with_i142((1, 2, 1).into())
         // past_key_values.3.decoder.key
-        .with_i150((1, 1, 4).into())
-        .with_i152((1, 1, 1).into())
+        .with_i150((1, 2, 4).into())
+        .with_i152((1, 2, 1).into())
         // past_key_values.3.decoder.value
-        .with_i160((1, 1, 4).into())
-        .with_i162((1, 1, 1).into())
+        .with_i160((1, 2, 4).into())
+        .with_i162((1, 2, 1).into())
         // past_key_values.3.encoder.key
-        .with_i170((1, 1, 4).into())
-        .with_i172((1, 1, 1).into())
+        .with_i170((1, 2, 4).into())
+        .with_i172((1, 2, 1).into())
         // past_key_values.3.encoder.value
-        .with_i180((1, 1, 4).into())
-        .with_i182((1, 1, 1).into())
+        .with_i180((1, 2, 4).into())
+        .with_i182((1, 2, 1).into())
         // past_key_values.4.decoder.key
-        .with_i190((1, 1, 4).into())
-        .with_i192((1, 1, 1).into())
+        .with_i190((1, 2, 4).into())
+        .with_i192((1, 2, 1).into())
         // past_key_values.4.decoder.value
-        .with_i200((1, 1, 4).into())
-        .with_i202((1, 1, 1).into())
+        .with_i200((1, 2, 4).into())
+        .with_i202((1, 2, 1).into())
         // past_key_values.4.encoder.key
-        .with_i210((1, 1, 4).into())
-        .with_i212((1, 1, 1).into())
+        .with_i210((1, 2, 4).into())
+        .with_i212((1, 2, 1).into())
         // past_key_values.4.encoder.value
-        .with_i220((1, 1, 4).into())
-        .with_i222((1, 1, 1).into())
+        .with_i220((1, 2, 4).into())
+        .with_i222((1, 2, 1).into())
         // past_key_values.5.decoder.key
-        .with_i230((1, 1, 4).into())
-        .with_i232((1, 1, 1).into())
+        .with_i230((1, 2, 4).into())
+        .with_i232((1, 2, 1).into())
         // past_key_values.5.decoder.value
-        .with_i240((1, 1, 4).into())
-        .with_i242((1, 1, 1).into())
+        .with_i240((1, 2, 4).into())
+        .with_i242((1, 2, 1).into())
         // past_key_values.5.encoder.key
-        .with_i250((1, 1, 4).into())
-        .with_i252((1, 1, 1).into())
+        .with_i250((1, 2, 4).into())
+        .with_i252((1, 2, 1).into())
         // past_key_values.5.encoder.value
-        .with_i260((1, 1, 4).into())
-        .with_i262((1, 1, 1).into())
+        .with_i260((1, 2, 4).into())
+        .with_i262((1, 2, 1).into())
         // use_cache_branch
-        .with_i270((1, 1, 1).into())
+        .with_i270((1, 2, 1).into())
         .with_profile(false);
 
     // build model
@@ -133,17 +133,69 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         options_encoder,
         options_decoder,
         options_decoder_merged,
-    )?
-    .with_task(Task::Caption(2));
+    )?;
+    // .with_task(Task::Caption(2));
 
     // load images
-    let images = [DataLoader::try_read("florence2/car.jpg")?];
+    let xs = [
+        DataLoader::try_read("florence2/car.jpg")?,
+        DataLoader::try_read("assets/bus.jpg")?,
+    ];
 
-    // encode image
-    let image_embeddings = model.encode_images(&images)?;
+    // run with a batch of tasks
+    let ys = model.run(
+        &xs,
+        &[
+            // tasks that need no extra input
+            Task::Caption(0),
+            Task::Caption(1),
+            Task::Caption(2),
+            Task::Ocr,
+            Task::RegionProposal,
+            Task::ObjectDetection,
+            Task::DenseRegionCaption,
+            // // Task::OcrWithRegion, // TODO
+            // tasks that take a text or region input
+            // Task::OpenSetDetection("A green car".into()),
+            // Task::CaptionToPhraseGrounding("A green car".into()),
+            // Task::ReferringExpressionSegmentation("A green car".into()),
+            // Task::RegionToSegmentation(702, 575, 866, 772),
+            // Task::RegionToCategory(52, 332, 932, 774),
+            // Task::RegionToDescription(52, 332, 932, 774),
+            // Task::RegionToOcr(100, 100, 300, 300),
+        ],
+    )?;
 
-    // caption
-    let _ys = model.caption(&image_embeddings, true)?; // display results
+    // annotator
+    for (task, ys_) in ys.iter() {
+        match task {
+            Task::Caption(_) | Task::Ocr => println!("Task: {:?}\n{:?}\n", task, ys_),
+            Task::DenseRegionCaption => {
+                let annotator = Annotator::default()
+                    .without_bboxes_conf(true)
+                    .with_bboxes_thickness(4)
+                    .with_saveout("Florence2-DenseRegionCaption");
+                annotator.annotate(&xs, ys_);
+            }
+            Task::RegionProposal => {
+                let annotator = Annotator::default()
+                    .without_bboxes_conf(true)
+                    .without_bboxes_name(true)
+                    .with_bboxes_thickness(4)
+                    .with_saveout("Florence2-RegionProposal");
+                annotator.annotate(&xs, ys_);
+            }
+            Task::ObjectDetection => {
+                let annotator = Annotator::default()
+                    .without_bboxes_conf(true)
+                    .with_bboxes_thickness(4)
+                    .with_saveout("Florence2-ObjectDetection");
+                annotator.annotate(&xs, ys_);
+            }
+
+            _ => (),
+        }
+    }
 
     Ok(())
 }
diff --git a/src/core/min_opt_max.rs b/src/core/min_opt_max.rs
index 71c4970..803d4d2 100644
--- a/src/core/min_opt_max.rs
+++ b/src/core/min_opt_max.rs
@@ -50,4 +50,14 @@ impl MinOptMax {
             max: opt,
         }
     }
+
+    pub fn update(&mut self, opt: isize) {
+        self.opt = opt;
+        if self.min > opt {
+            self.min = opt;
+        }
+        if self.max < opt {
+            self.max = opt;
+        }
+    }
 }
diff --git a/src/core/task.rs b/src/core/task.rs
index 382dbdb..54a914c 100644
--- a/src/core/task.rs
+++ b/src/core/task.rs
@@ -1,4 +1,4 @@
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Ord, Eq, PartialOrd, PartialEq)]
 pub enum Task {
     Untitled,
 
@@ -18,15 +18,21 @@ pub enum Task {
     /// 0 for brief, 1 for detailed, 2 for more detailed
     Caption(u8),
 
+    /// Region proposal task, detecting all objects in the image.
+    /// Input: image
+    /// Output: bounding boxes (bboxes)
+    RegionProposal,
+
     /// Object detection task, detecting all objects in the image.
     /// Input: image
     /// Output: bounding boxes (bboxes), class labels, and optional scores for the detected objects
     ObjectDetection,
 
-    /// Region proposal task, detecting all objects in the image.
+    /// Open set detection task, detecting and classifying objects in an image, with the ability to handle unseen or unknown objects.
     /// Input: image
-    /// Output: bounding boxes (bboxes)
-    RegionProposal,
+    /// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores
+    /// Takes a String query describing the objects to locate.
+    OpenSetDetection(String),
 
     /// Task for generating brief descriptions of dense regions in the image.
     /// Input: image
@@ -87,37 +93,36 @@ pub enum Task {
     /// Phrase grounding task, finding the region in an image corresponding to a text description.
     /// Input: image and text
     /// Output: image region and the corresponding phrase
-    PhraseGrounding,
+    /// Caption-to-phrase grounding, with the caption given as a String.
+    CaptionToPhraseGrounding(String),
 
     /// Referring expression segmentation task, segmenting objects in the image based on a text description.
     /// Input: image and referring expression
     /// Output: a segmentation mask for the object referred to by the text
-    ReferringExpressionSegmentation,
+    ReferringExpressionSegmentation(String),
 
     /// Region-to-segmentation task, similar to combining object detection with segmentation (e.g., YOLO + SAM).
     /// Input: image and region proposals
     /// Output: segmentation masks for the regions
-    RegionToSegmentation,
-
-    /// Open set detection task, detecting and classifying objects in an image, with the ability to handle unseen or unknown objects.
-    /// Input: image
-    /// Output: bounding boxes, class labels (including an "unknown" category for unfamiliar objects), and detection scores
-    OpenSetDetection,
+    /// Region given as a bbox: top-left (x0, y0), bottom-right (x1, y1)
+    RegionToSegmentation(u32, u32, u32, u32),
 
     /// Region-to-category classification task, classifying the object in a given region of the image.
     /// Input: image and region
     /// Output: class label for the region
-    RegionToCategory,
+    /// Region given as a bbox: top-left (x0, y0), bottom-right (x1, y1)
+    RegionToCategory(u32, u32, u32, u32),
 
     /// Region-to-description task, generating a detailed description for a given region in the image.
     /// Input: image and region
     /// Output: a detailed textual description for the region
-    RegionToDescription,
+    /// Region given as a bbox: top-left (x0, y0), bottom-right (x1, y1)
+    RegionToDescription(u32, u32, u32, u32),
 
     /// Visual question answering (VQA) task, answering questions related to an image.
     /// Input: image and question text
     /// Output: the answer to the question
-    Vqa,
+    Vqa(String),
 
     /// Optical character recognition (OCR) task, recognizing text in an image.
     /// Input: image
@@ -128,5 +133,63 @@ pub enum Task {
     /// Input: image
     /// Output: recognized text and its bounding box in the image
     OcrWithRegion,
-    RegionToOcr,
+
+    /// Region given as a bbox: top-left (x0, y0), bottom-right (x1, y1)
+    RegionToOcr(u32, u32, u32, u32),
+}
+
+impl Task {
+    pub fn prompt_for_florence2(&self) -> anyhow::Result<String> {
+        let prompt = match self {
+            Self::Untitled => anyhow::bail!("No task specified."),
+            Self::Caption(0) => "What does the image describe?".to_string(),
+            Self::Caption(1) => "Describe in detail what is shown in the image.".to_string(),
+            Self::Caption(2) => "Describe with a paragraph what is shown in the image.".to_string(),
+            Self::Ocr => "What is the text in the image?".to_string(),
+            Self::OcrWithRegion => "What is the text in the image, with regions?".to_string(),
+            Self::ObjectDetection => {
+                "Locate the objects with category name in the image.".to_string()
+            }
+            Self::DenseRegionCaption => {
+                "Locate the objects in the image, with their descriptions.".to_string()
+            }
+            Self::RegionProposal => "Locate the region proposals in the image.".to_string(),
+            Self::OpenSetDetection(text) => {
+                format!("Locate {} in the image.", text)
+            }
+            Self::CaptionToPhraseGrounding(text) => {
+                format!("Locate the phrases in the caption: {}", text)
+            }
+            Self::ReferringExpressionSegmentation(text) => {
+                format!("Locate {} in the image with mask", text)
+            }
+            Self::RegionToSegmentation(x0, y0, x1, y1) => {
+                format!(
+                    "What is the polygon mask of region <loc_{}><loc_{}><loc_{}><loc_{}>",
+                    x0, y0, x1, y1
+                )
+            }
+            Self::RegionToCategory(x0, y0, x1, y1) => {
+                format!(
+                    "What is the region <loc_{}><loc_{}><loc_{}><loc_{}>?",
+                    x0, y0, x1, y1
+                )
+            }
+            Self::RegionToDescription(x0, y0, x1, y1) => {
+                format!(
+                    "What does the region <loc_{}><loc_{}><loc_{}><loc_{}> describe?",
+                    x0, y0, x1, y1
+                )
+            }
+            Self::RegionToOcr(x0, y0, x1, y1) => {
+                format!(
+                    "What text is in the region <loc_{}><loc_{}><loc_{}><loc_{}>?",
+                    x0, y0, x1, y1
+                )
+            }
+            _ => anyhow::bail!("Unsupported task."),
+        };
+
+        Ok(prompt)
+    }
 }
diff --git a/src/models/florence2.rs b/src/models/florence2.rs
index 04c7943..c9ee618 100644
--- a/src/models/florence2.rs
+++ b/src/models/florence2.rs
@@ -1,10 +1,10 @@
 use anyhow::Result;
 use image::DynamicImage;
-use ndarray::s;
-use std::io::Write;
+use ndarray::{s, Axis};
+use std::collections::BTreeMap;
 use tokenizers::Tokenizer;
 
-use crate::{LogitsSampler, MinOptMax, Ops, Options, OrtEngine, Task, TokenizerStream, Xs, X, Y};
+use crate::{Bbox, LogitsSampler, MinOptMax, Ops, Options, OrtEngine, Quantizer, Task, Xs, X, Y};
 
 #[derive(Debug)]
 pub struct Florence2 {
@@ -16,8 +16,9 @@ pub struct Florence2 {
     pub height: MinOptMax,
     pub width: MinOptMax,
     pub batch: MinOptMax,
-    tokenizer: TokenizerStream,
-    task: Task,
+    tokenizer: Tokenizer,
+    max_length: usize,
+    quantizer: Quantizer,
 }
 
 impl Florence2 {
@@ -38,7 +39,6 @@ impl Florence2 {
             vision_encoder.height().to_owned(),
             vision_encoder.width().to_owned(),
         );
-        let task = options_text_embed.task;
         let tokenizer = options_text_embed
             .tokenizer
             .ok_or(anyhow::anyhow!("No tokenizer file found"))?;
@@ -47,7 +47,8 @@ impl Florence2 {
             Ok(x) => x,
         };
 
-        let tokenizer = TokenizerStream::new(tokenizer);
+        let quantizer = Quantizer::default();
+        let max_length = 1024;
 
         // dry run
         vision_encoder.dry_run()?;
@@ -66,7 +67,8 @@ impl Florence2 {
             width,
             batch,
             tokenizer,
-            task,
+            max_length,
+            quantizer,
         })
     }
@@ -86,21 +88,165 @@ impl Florence2 {
         Ok(ys)
     }
 
-    pub fn caption(&mut self, image_embeddings: &X, display: bool) -> Result<Vec<Y>> {
-        let mut ys: Vec<Y> = Vec::new();
+    pub fn run(&mut self, xs: &[DynamicImage], tasks: &[Task]) -> Result<BTreeMap<Task, Vec<Y>>> {
+        let mut ys: BTreeMap<Task, Vec<Y>> = BTreeMap::new();
 
-        // encode prompt
-        let input_ids = self
-            .construct_prompt(None)?
-            .insert_axis(0)?
-            .repeat(0, self.batch())?;
-        let text_embedings = self.text_embed.run(Xs::from(input_ids))?[0]
-            .clone()
-            .repeat(0, self.batch())?;
+        // encode batch images
+        let image_embeddings = self.encode_images(xs)?;
 
-        // concate image_embeddings and prompt embeddings
-        let inputs_embeds = image_embeddings.clone().concatenate(&text_embedings, 1)?;
+        // note: the length of xs is not always equal to batch size
+        self.batch.update(xs.len() as isize);
+
+        // tasks loop
+        for task in tasks.iter() {
+            let mut ys_task: Vec<Y> = Vec::new();
+
+            // construct prompt and encode
+            let input_ids = self
+                .encode_prompt(task)?
+                .insert_axis(0)?
+                .repeat(0, self.batch())?;
+            let text_embeddings = self.text_embed.run(Xs::from(input_ids))?[0].clone();
+
+            // run
+            let texts = self.run_batch(&image_embeddings, &text_embeddings)?;
+
+            // postprocess
+            for batch in 0..self.batch() {
+                // image size
+                let image_width = xs[batch].width() as usize;
+                let image_height = xs[batch].height() as usize;
+
+                // texts clean up
+                let text = texts[batch]
+                    .as_str()
+                    .replace("</s>", "")
+                    .replace("<s>", "")
+                    .replace("<pad>", "");
+
+                // parse texts for each task
+                match task {
+                    Task::Caption(_) | Task::Ocr => {
+                        // pure text
+                        ys_task.push(Y::default().with_texts(&[text]));
+                    }
+                    Task::ObjectDetection => {
+                        let mut y_bboxes = Vec::new();
+
+                        // parse
+                        let elems = Self::loc_parse(&text)?;
+
+                        // de-quantize and save in one loop
+                        for (i, elem) in elems.iter().enumerate() {
+                            let name = &elem[0];
+                            elem[1..].chunks(4).for_each(|chunk| {
+                                let coord: Vec<i32> =
+                                    chunk.iter().map(|s| s.parse::<i32>().unwrap()).collect();
+
+                                // dequantize
+                                let dequantized_bbox = self
+                                    .quantizer
+                                    .dequantize(&coord, (image_width, image_height));
+
+                                // save
+                                y_bboxes.push(
+                                    Bbox::default()
+                                        .with_name(name)
+                                        .with_xyxy(
+                                            dequantized_bbox[0].max(0.0f32).min(image_width as f32),
+                                            dequantized_bbox[1]
+                                                .max(0.0f32)
+                                                .min(image_height as f32),
+                                            dequantized_bbox[2],
+                                            dequantized_bbox[3],
+                                        )
+                                        .with_id(i as _),
+                                );
+                            });
+                        }
+
+                        ys_task.push(Y::default().with_bboxes(&y_bboxes));
+                    }
+                    Task::RegionProposal => {
+                        let mut y_bboxes = Vec::new();
+
+                        // parse
+                        let elems = Self::loc_parse(&text)?;
+
+                        // de-quantize and save in one loop
+                        elems[0].chunks(4).enumerate().for_each(|(i, chunk)| {
+                            let coord: Vec<i32> =
+                                chunk.iter().map(|s| s.parse::<i32>().unwrap()).collect();
+
+                            // dequantize
+                            let dequantized_bbox = self
+                                .quantizer
+                                .dequantize(&coord, (image_width, image_height));
+
+                            // save
+                            y_bboxes.push(
+                                Bbox::default()
+                                    .with_xyxy(
+                                        dequantized_bbox[0].max(0.0f32).min(image_width as f32),
+                                        dequantized_bbox[1].max(0.0f32).min(image_height as f32),
+                                        dequantized_bbox[2],
+                                        dequantized_bbox[3],
+                                    )
+                                    .with_id(i as _),
+                            );
+                        });
+
+                        ys_task.push(Y::default().with_bboxes(&y_bboxes));
+                    }
+                    Task::DenseRegionCaption => {
+                        let mut y_bboxes = Vec::new();
+
+                        // parse
+                        let elems = Self::loc_parse(&text)?;
+
+                        // de-quantize and save in one loop
+                        for (i, elem) in elems.iter().enumerate() {
+                            let name = &elem[0];
+                            elem[1..].chunks(4).for_each(|chunk| {
+                                let coord: Vec<i32> =
+                                    chunk.iter().map(|s| s.parse::<i32>().unwrap()).collect();
+
+                                // dequantize
+                                let dequantized_bbox = self
+                                    .quantizer
+                                    .dequantize(&coord, (image_width, image_height));
+
+                                // save
+                                y_bboxes.push(
+                                    Bbox::default()
+                                        .with_name(name)
+                                        .with_xyxy(
+                                            dequantized_bbox[0].max(0.0f32).min(image_width as f32),
+                                            dequantized_bbox[1]
+                                                .max(0.0f32)
+                                                .min(image_height as f32),
+                                            dequantized_bbox[2],
+                                            dequantized_bbox[3],
+                                        )
+                                        .with_id(i as _),
+                                );
+                            });
+                        }
+
+                        ys_task.push(Y::default().with_bboxes(&y_bboxes));
+                    }
+                    _ => todo!(),
+                };
+            }
+
+            ys.insert(task.clone(), ys_task);
+        }
+
+        Ok(ys)
+    }
+
+    fn run_batch(&mut self, image_embeddings: &X, text_embeddings: &X) -> Result<Vec<String>> {
+        // concatenate image_embeddings and prompt embeddings
+        let inputs_embeds = image_embeddings.clone().concatenate(text_embeddings, 1)?;
         let attention_mask = X::ones(&[self.batch(), inputs_embeds.dims()[1]]);
 
         // encoder
@@ -137,12 +283,14 @@
         let encoder_k5 = decoder_outputs[23].clone();
         let encoder_v5 = decoder_outputs[24].clone();
 
-        let mut y_text = String::new();
-        let mut generated_tokens = Vec::new();
+        let mut generated_tokens: Vec<Vec<u32>> = vec![vec![]; self.batch()];
+        let mut finished = vec![false; self.batch()];
+
+        // save last batch tokens
+        let mut last_tokens: Vec<f32> = vec![0.; self.batch()];
 
-        // TODO: batch iter
         let mut logits_sampler = LogitsSampler::new();
-        loop {
+        for _ in 0..self.max_length {
             let logits = &decoder_outputs["logits"];
             let decoder_k0 = &decoder_outputs[1];
             let decoder_v0 = &decoder_outputs[2];
@@ -157,37 +305,37 @@ impl Florence2 {
             let decoder_k5 = &decoder_outputs[21];
             let decoder_v5 = &decoder_outputs[22];
 
-            let next_token_logits = logits
-                .slice(s![.., -1.., ..])
-                .to_owned()
-                .into_raw_vec_and_offset()
-                .0;
-
-            let token_id = logits_sampler.decode(&next_token_logits)?;
-            generated_tokens.push(token_id as f32);
-
-            // </s>
-            if token_id == 2 {
-                break;
+            // Decode each token for each batch
+            for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
+                if !finished[i] {
+                    let token_id = logits_sampler.decode(
+                        &logit
+                            .slice(s![-1, ..])
+                            .into_owned()
+                            .into_raw_vec_and_offset()
+                            .0,
+                    )?; // next token id
+                    generated_tokens[i].push(token_id);
+
+                    // update last_tokens
+                    last_tokens[i] = token_id as f32;
+
+                    if token_id == 2 {
+                        finished[i] = true;
+                    }
+                }
             }
 
-            // streaming generation
-            if let Some(t) = self.tokenizer.next_token(token_id)? {
-                y_text.push_str(&t);
-                if display {
-                    print!("{t}");
-                    std::thread::sleep(std::time::Duration::from_millis(2));
-                }
-                std::io::stdout().flush()?;
+            // all finished?
+            if finished.iter().all(|&x| x) {
+                break;
             }
 
             // next input text embedding
-            let next_token = X::from(vec![token_id as f32])
-                .insert_axis(0)?
-                .repeat(0, self.batch())?;
+            let next_tokens = X::from(last_tokens.clone()).insert_axis(1)?;
 
             // decode
-            let inputs_embeds = &self.text_embed.run(Xs::from(next_token))?[0].clone();
+            let inputs_embeds = &self.text_embed.run(Xs::from(next_tokens))?[0].clone();
             let use_cache = X::ones(&[1]);
             decoder_outputs = self.decoder_merged.run(Xs::from(vec![
                 attention_mask.clone(),
@@ -220,76 +368,56 @@ impl Florence2 {
                 use_cache,
             ]))?;
         }
-        if display {
-            println!();
-        }
-        self.tokenizer.clear();
 
-        ys.push(Y::default().with_texts(&[y_text]));
+        // batch decode
+        let texts = match self.tokenizer.decode_batch(
+            &generated_tokens
+                .iter()
+                .map(|tokens| tokens.as_slice())
+                .collect::<Vec<_>>(),
+            false,
+        ) {
+            Err(err) => anyhow::bail!("{:?}", err),
+            Ok(xs) => xs,
+        };
 
-        Ok(ys)
+        Ok(texts)
     }
 
-    pub fn construct_prompt(&self, text: Option<&str>) -> Result<X> {
-        let prompt = match self.task {
-            Task::Untitled => anyhow::bail!("No task specified."),
-            Task::Caption(0) => "What does the image describe?".to_string(),
-            Task::Caption(1) => "Describe in detail what is shown in the image.".to_string(),
-            Task::Caption(2) => "Describe with a paragraph what is shown in the image.".to_string(),
-            Task::Ocr => "What is the text in the image?".to_string(),
-            Task::OcrWithRegion => "What is the text in the image, with regions?".to_string(),
-            Task::ObjectDetection => {
-                "Locate the objects with category name in the image.".to_string()
-            }
-            Task::DenseRegionCaption => {
-                "Locate the objects in the image, with their descriptions.".to_string()
-            }
-            Task::RegionProposal => "Locate the region proposals in the image.".to_string(),
-            Task::PhraseGrounding => format!(
-                "Locate the phrases in the caption: {}",
-                text.unwrap_or_default()
-            ),
-            Task::ReferringExpressionSegmentation => {
-                format!("Locate {} in the image with mask", text.unwrap_or_default())
-            }
-            Task::RegionToSegmentation => {
-                format!(
-                    "What is the polygon mask of region {}",
-                    text.unwrap_or_default()
-                )
-            }
-            Task::OpenSetDetection => {
-                format!("Locate {} in the image.", text.unwrap_or_default())
-            }
-            Task::RegionToCategory => {
-                format!("What is the region {}?", text.unwrap_or_default())
-            }
-            Task::RegionToDescription => {
-                format!(
-                    "What does the region {} describe?",
-                    text.unwrap_or_default()
-                )
-            }
-            Task::RegionToOcr => {
-                format!("What text is in the region {}?", text.unwrap_or_default())
-            }
-
-            _ => todo!(),
-        };
-
-        let encodings = match self.tokenizer.tokenizer().encode(prompt, true) {
+    pub fn encode_prompt(&self, task: &Task) -> Result<X> {
+        let prompt = task.prompt_for_florence2()?;
+        let encodings = match self.tokenizer.encode(prompt, true) {
             Err(err) => anyhow::bail!("{}", err),
             Ok(x) => x,
         };
 
         let ids: Vec<f32> = encodings.get_ids().iter().map(|x| *x as f32).collect();
-        let ids = X::from(ids);
-        Ok(ids)
+        Ok(X::from(ids))
     }
 
-    pub fn with_task(mut self, task: Task) -> Self {
-        self.task = task;
-        self
+    fn loc_parse(hay: &str) -> Result<Vec<Vec<String>>> {
+        let pattern = r"(?i)(<loc_(?<coord>\d+)>)|(?<name>[^<]+)";
+        let re = regex::Regex::new(pattern)?;
+        let mut ys: Vec<Vec<String>> = Vec::new();
+        let mut y = Vec::new();
+
+        for cap in re.captures_iter(hay) {
+            if let Some(loc) = cap.name("coord") {
+                y.push(loc.as_str().to_string());
+            } else if let Some(text) = cap.name("name") {
+                if !text.as_str().is_empty() {
+                    if !y.is_empty() {
+                        ys.push(y);
+                        y = Vec::new();
+                    }
+                    y.push(text.as_str().to_string());
+                }
+            }
+        }
+        if !y.is_empty() {
+            ys.push(y);
+        }
+        Ok(ys)
     }
 
     pub fn batch(&self) -> usize {
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index dc4aef1..9c8710b 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -5,9 +5,11 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
 
 pub mod colormap256;
 pub mod names;
+mod quantizer;
 
 pub use colormap256::*;
 pub use names::*;
+pub use quantizer::Quantizer;
 
 pub(crate) const CHECK_MARK: &str = "✅";
 pub(crate) const CROSS_MARK: &str = "❌";
diff --git a/src/utils/quantizer.rs b/src/utils/quantizer.rs
new file mode 100644
index 0000000..554f5f2
--- /dev/null
+++ b/src/utils/quantizer.rs
@@ -0,0 +1,78 @@
+// TODO
+
+#[derive(Debug)]
+pub struct Quantizer {
+    bins: (usize, usize),
+}
+
+impl Default for Quantizer {
+    fn default() -> Self {
+        Self { bins: (1000, 1000) }
+    }
+}
+
+impl Quantizer {
+    pub fn new(bins: (usize, usize)) -> Self {
+        Quantizer { bins }
+    }
+
+    fn quantize_value(&self, val: f32, bin_size: f64, max_bin: i32) -> i32 {
+        ((val as f64 / bin_size).floor() as i32).clamp(0, max_bin - 1)
+    }
+
+    fn dequantize_value(&self, val: i32, bin_size: f64) -> f32 {
+        ((val as f64 + 0.5) * bin_size) as f32
+    }
+
+    fn quantize_internal(&self, input: &[f32], size: (usize, usize)) -> Vec<i32> {
+        let (bins_w, bins_h) = self.bins;
+        let (size_w, size_h) = size;
+
+        let size_per_bin_w = size_w as f64 / bins_w as f64;
+        let size_per_bin_h = size_h as f64 / bins_h as f64;
+
+        match input.len() {
+            4 => vec![
+                self.quantize_value(input[0], size_per_bin_w, bins_w as i32),
+                self.quantize_value(input[1], size_per_bin_h, bins_h as i32),
+                self.quantize_value(input[2], size_per_bin_w, bins_w as i32),
+                self.quantize_value(input[3], size_per_bin_h, bins_h as i32),
+            ],
+            2 => vec![
+                self.quantize_value(input[0], size_per_bin_w, bins_w as i32),
+                self.quantize_value(input[1], size_per_bin_h, bins_h as i32),
+            ],
+            _ => panic!("Unsupported input length"),
+        }
+    }
+
+    fn dequantize_internal(&self, input: &[i32], size: (usize, usize)) -> Vec<f32> {
+        let (bins_w, bins_h) = self.bins;
+        let (size_w, size_h) = size;
+
+        let size_per_bin_w = size_w as f64 / bins_w as f64;
+        let size_per_bin_h = size_h as f64 / bins_h as f64;
+
+        match input.len() {
+            4 => vec![
+                self.dequantize_value(input[0], size_per_bin_w),
+                self.dequantize_value(input[1], size_per_bin_h),
+                self.dequantize_value(input[2], size_per_bin_w),
+                self.dequantize_value(input[3], size_per_bin_h),
+            ],
+            2 => vec![
+                self.dequantize_value(input[0], size_per_bin_w),
+                self.dequantize_value(input[1], size_per_bin_h),
+            ],
+            _ => panic!("Unsupported input length"),
+        }
+    }
+
+    pub fn quantize(&self, input: &[f32], size: (usize, usize)) -> Vec<i32> {
+        self.quantize_internal(input, size)
+    }
+
+    pub fn dequantize(&self, input: &[i32], size: (usize, usize)) -> Vec<f32> {
+        self.dequantize_internal(input, size)
+    }
+}
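For reference, here is a minimal usage sketch of the new Quantizer, assuming it is re-exported at the crate root (as the `use crate::{..., Quantizer, ...}` import in src/models/florence2.rs implies); the image size and the location indices are illustrative values only, not taken from the patch:

use usls::Quantizer;

fn main() {
    // Default bins are (1000, 1000), matching Florence-2's <loc_0>..<loc_999> coordinate tokens.
    let quantizer = Quantizer::default();
    let image_size = (1280, 720); // (width, height); illustrative values

    // Bin indices as parsed from e.g. "<loc_53><loc_334><loc_933><loc_775>" by loc_parse.
    let locs: [i32; 4] = [53, 334, 933, 775];

    // Map bin indices back to pixel coordinates: [x0, y0, x1, y1].
    let bbox = quantizer.dequantize(&locs, image_size);
    println!("bbox: {:?}", bbox);

    // And the inverse direction: pixel coordinates back to bin indices.
    let bins = quantizer.quantize(&bbox, image_size);
    println!("bins: {:?}", bins);
}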