diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index 36e5e1774aa0d..a777e5c3f22a7 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -67,9 +67,16 @@ def sampler_output(
                 execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
+            seq_len = seq_data.get_len()
+            # When seq_len is less than 3072 (3K), we use the CPU to perform
+            # the ngram match. Otherwise, we use the device specified in
+            # the model config (normally GPU). 3072 is a rough threshold
+            # from profiling on an H100, and it can be adjusted to match
+            # the actual performance on different hardware.
+            cur_device = "cpu" if seq_len < 3072 else self.device
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                         dtype=torch.long,
-                                        device=self.device)
+                                        device=cur_device)
             input_length = seq_data.get_len()
 
             for ngram_size in range(
@@ -91,17 +98,15 @@ def sampler_output(
                 # first_match includes "values" (bool), indicating whether
                 # the match is found, and "indices", indicating the index
                 # of the first match.
-                # Note that "first_match.values.item()" triggers GPU-CPU
-                # sync so it is a bit inefficient, but we have not found
-                # a better way to do this.
                 first_match = matches.max(dim=-1)
                 if first_match.values.item():
                     proposal_start_idx = first_match.indices.add_(ngram_size)
                     spec_indices = (
                         proposal_start_idx).repeat(sample_len) + torch.arange(
-                            sample_len, device=self.device)
+                            sample_len, device=cur_device)
                     spec_indices.clamp_(max=input_ids.shape[-1] - 1)
-                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    res = input_ids.gather(dim=-1,
+                                           index=spec_indices).to(self.device)
                     token_id_list.append(res)
                     token_prob_list.append(
                         torch.nn.functional.one_hot(
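
The sketch below is a minimal, self-contained illustration of the idea in this patch, not vLLM's actual implementation: run the ngram prompt-lookup match on the CPU for short sequences (where kernel-launch and device-sync overhead dominates) and on the accelerator otherwise, then move the small proposal tensor back to the target device at the end. The helper name `propose_ngram_tokens`, the `device`/`threshold` parameters, and the fixed single ngram size are illustrative assumptions, not part of the patch.

```python
# Hedged sketch: CPU-vs-GPU device selection for ngram prompt lookup.
# Names and the fixed ngram size are illustrative, not from the patch.
from typing import List, Optional

import torch


def propose_ngram_tokens(token_ids: List[int],
                         ngram_size: int,
                         sample_len: int,
                         device: torch.device,
                         threshold: int = 3072) -> Optional[torch.Tensor]:
    """Propose sample_len tokens by matching the trailing ngram in the prompt."""
    seq_len = len(token_ids)
    # Short sequences: the whole match is cheaper on CPU than paying for
    # GPU kernel launches plus the .item() device sync.
    cur_device = "cpu" if seq_len < threshold else device

    if seq_len <= ngram_size:
        return None
    input_ids = torch.as_tensor(token_ids, dtype=torch.long, device=cur_device)

    # Compare every length-ngram_size window against the trailing ngram,
    # excluding the trailing window itself (it trivially matches).
    ngram = input_ids[-ngram_size:]
    windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
    matches = (windows[:-1] == ngram).all(dim=-1)

    first_match = matches.max(dim=-1)
    if not first_match.values.item():
        return None

    # Propose the tokens that followed the matched ngram earlier in the prompt.
    proposal_start_idx = first_match.indices + ngram_size
    spec_indices = proposal_start_idx.repeat(sample_len) + torch.arange(
        sample_len, device=cur_device)
    spec_indices.clamp_(max=input_ids.shape[-1] - 1)
    # Gather on whichever device did the match, then move the small result
    # to the target device, as the patch does with .to(self.device).
    return input_ids.gather(dim=-1, index=spec_indices).to(device)
```

The design point is that only the final, small proposal tensor crosses the device boundary; everything before it (the window comparison, the `.item()` check, the index arithmetic) stays on whichever device `cur_device` selected.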