diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index afdeab26fb..1c9f9b1c57 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -39,8 +39,6 @@ repos:
         exclude: thirdparty/|tests/run.py
       - id: requirements-txt-fixer
         exclude: thirdparty/|tests/run.py
-      - id: double-quote-string-fixer
-        exclude: thirdparty/|tests/run.py
       - id: check-merge-conflict
         exclude: thirdparty/|tests/run.py
       - id: fix-encoding-pragma
diff --git a/.pre-commit-config_local.yaml b/.pre-commit-config_local.yaml
index d9d36e7bab..d3349da4ff 100644
--- a/.pre-commit-config_local.yaml
+++ b/.pre-commit-config_local.yaml
@@ -37,8 +37,6 @@ repos:
         exclude: thirdparty/|tests/run.py
       - id: end-of-file-fixer
         exclude: thirdparty/
-      - id: requirements-txt-fixer
-        exclude: thirdparty/|tests/run.py
      - id: double-quote-string-fixer
        exclude: thirdparty/|tests/run.py
      - id: check-merge-conflict
diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py
index 143bd703e7..4021515741 100644
--- a/swift/plugin/orm.py
+++ b/swift/plugin/orm.py
@@ -383,5 +383,5 @@ def __call__(self, completions, **kwargs) -> List[float]:
     'format': Format,
     'react_format': ReActFormat,
     'cosine': CosineReward,
-    'repetition': RepetitionPenalty,
+    'repetition': RepetitionPenalty
 }
diff --git a/swift/plugin/tool_call.py b/swift/plugin/tool_call.py
new file mode 100644
index 0000000000..8f8500654f
--- /dev/null
+++ b/swift/plugin/tool_call.py
@@ -0,0 +1,16 @@
+from typing import Any, Optional, Tuple
+
+
+class ToolCall:
+    """Base class for tool-call plugins.
+
+    ``__call__`` parses a model completion and returns a tuple of
+    ``(tool_result, finished, format_reward)``.
+    """
+
+    def __call__(self, completion: str) -> Tuple[Any, bool, Optional[float]]:
+        raise NotImplementedError
+
+
+# Registry mapping ``args.tool_call`` names to ToolCall instances.
+tools = {}
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index 62417c19e7..14a8a77e25 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -104,6 +104,10 @@ class GRPOArgumentsMixin:
     offload_optimizer: bool = False
     offload_model: bool = False
     gc_collect_after_offload: bool = False
+    # Whether to additionally compute a separate format reward for each tool call.
+    is_reward_tool_call: bool = True
+    tool_call_weight: float = 1.0
+    tool_call: Optional[str] = None


 @dataclass
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 86d850fcaa..a683ed0a53 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -13,8 +13,8 @@
 from math import ceil
 from queue import Queue
 from types import MethodType
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.nn as nn
@@ -30,7 +30,8 @@
                      get_node_setting, is_lmdeploy_available, is_vllm_available, is_wandb_available)
 from ..mixin import SwiftMixin
 from .rlhf_mixin import RLHFTrainerMixin
+from swift.plugin.tool_call import tools

 try:
     from trl.extras.profiling import profiling_decorator
 except ImportError:
@@ -101,6 +102,19 @@ def __init__(self,
                  **kwargs):
         from swift.trainers.rlhf_arguments import GRPOConfig
         args: GRPOConfig = kwargs['args']
+        # Tool-call support: resolve the configured tool-call plugin (None when not configured).
+        self.tool_call = tools.get(args.tool_call)
+        self.is_reward_tool_call = args.is_reward_tool_call
+        # NOTE: assumes a single base reward function; TODO: read per-function weights from the config.
+        self.reward_weights = torch.ones(1, dtype=torch.float32)
+        if self.is_reward_tool_call and self.tool_call is not None:
+            # Append a weight column for the tool-call format reward.
+            if args.tool_call_weight is not None:
+                self.reward_weights = torch.cat(
+                    [self.reward_weights,
+                     torch.tensor([args.tool_call_weight], dtype=torch.float32)])
+            else:
+                self.reward_weights = torch.cat([self.reward_weights, torch.ones(1, dtype=torch.float32)])
         self.args = args
         self.queue = None
         self.train_queue = Queue()
@@ -693,70 +711,103 @@ def old_policy(self):

     def _generate_and_score_completions(
             self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
         device = self.accelerator.device
-        # Generate completions using either vLLM or regular generation
-        if self.args.use_vllm or self.args.use_lmdeploy:
-            inputs, outputs = self._fast_infer(inputs)
-            # Broadcast the completions from the main process to all processes, ensuring each process receives its
-            # corresponding slice.
-            # outputs = broadcast_object_list(outputs, from_process=0)
-        else:
-            # Regular generation path
-            is_multimodal = self.model.model_meta.is_multimodal
-            if is_multimodal:
-                models = self.template.remove_post_encode_hook()
-            with unwrap_model_for_generation(
-                    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation):
-                # same reference
-                outputs = self.engine.infer(inputs, self.request_config, use_tqdm=False)
-                self.model.train()
-            if is_multimodal:
-                self.template.register_post_encode_hook(models)
-
-        # Slice to keep only the local part of the data
+        total_tool_calls = 0
+        # Determine the data slice for the current process.
         process_slice = slice(
             self.accelerator.process_index * len(inputs),
             (self.accelerator.process_index + 1) * len(inputs),
         )
-        if self.args.use_vllm or self.args.use_lmdeploy:
-            outputs = outputs[process_slice]

-        for i, output in enumerate(outputs):
-            messages = inputs[i]['messages']
-            InferRequest.remove_response(messages)
-            messages.append({'role': 'assistant', 'content': output.choices[0].message.content})
+        # Run the multi-turn tool-calling rollout when a tool-call plugin is configured.
+        if self.tool_call is not None:
+            # Deep copy so the tool-calling rollout does not mutate the original prompts.
+            tool_inputs = deepcopy(inputs)
+            tool_inputs, total_tool_calls = self._process_tool_calls(tool_inputs)
+            # Replace the original inputs with the tool-processed inputs.
+            inputs = tool_inputs
+        else:
+            if self.args.use_vllm or self.args.use_lmdeploy:
+                inputs, outputs = self._fast_infer(inputs)
+            else:
+                # Regular generation path
+                is_multimodal = self.model.model_meta.is_multimodal
+                if is_multimodal:
+                    models = self.template.remove_post_encode_hook()
+                with unwrap_model_for_generation(
+                        self.model_wrapped, self.accelerator,
+                        gather_deepspeed3_params=self.args.ds3_gather_for_generation):
+                    outputs = self.engine.infer(inputs, self.request_config, use_tqdm=False)
+                    self.model.train()
+                if is_multimodal:
+                    self.template.register_post_encode_hook(models)
+
+            if self.args.use_vllm or self.args.use_lmdeploy:
+                outputs = outputs[process_slice]
+
+            # Add the final assistant responses to the message history.
+            for i, output in enumerate(outputs):
+                messages = inputs[i]['messages']
+                InferRequest.remove_response(messages)
+                messages.append({'role': 'assistant', 'content': output.choices[0].message.content})

+        # Encode all inputs, with their full message history, for training.
         from copy import copy
         template = copy(self.template)
         with self._template_context(template):
             batched_inputs = [template.encode(infer_request) for infer_request in inputs]
-            outputs = to_device(template.data_collator(batched_inputs), self.model.device)
-
-        # we only need to compute the logits for the completion tokens
-        labels = outputs.pop('labels')
-        logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item()
-        outputs['logits_to_keep'] = logits_to_keep
-        outputs['completion_mask'] = labels[:, -logits_to_keep:] != -100
+            encoded_outputs = to_device(template.data_collator(batched_inputs), self.model.device)
+
+        # Extract the labels and build a mask covering every assistant span.
+        labels = encoded_outputs.pop('labels')
+        # All positions where labels != -100 belong to assistant outputs.
+        completion_mask = (labels != -100)
+
+        # Ensure we do not exceed the max completion length.
+        if completion_mask.sum(1).max() > self.args.max_completion_length:
+            # Truncate to the max completion length if needed.
+            truncated_mask = torch.zeros_like(completion_mask)
+            for i in range(len(truncated_mask)):
+                # Indices of this sample's assistant tokens (labels != -100).
+                indices = torch.where(completion_mask[i])[0]
+                # Keep only the first max_completion_length of them.
+                indices = indices[:self.args.max_completion_length]
+                truncated_mask[i, indices] = True
+            completion_mask = truncated_mask
+
+        encoded_outputs['completion_mask'] = completion_mask
+        encoded_outputs['logits_to_keep'] = completion_mask.shape[1]  # use the full sequence length
+        encoded_outputs['labels'] = labels  # put the labels back for later processing
+
+        # Continue with the existing logic for computing logps and rewards.
         with torch.no_grad():
             if self.old_policy:
-                outputs['old_per_token_logps'] = self._get_per_token_logps(self.model, outputs)
+                encoded_outputs['old_per_token_logps'] = self._get_per_token_logps(self.model, encoded_outputs)
             else:
-                outputs['old_per_token_logps'] = None
+                encoded_outputs['old_per_token_logps'] = None

             if self.beta == 0.0:
                 ref_per_token_logps = None
             elif self.ref_model is not None:
-                ref_per_token_logps = self._get_per_token_logps(self.ref_model, outputs)
+                ref_per_token_logps = self._get_per_token_logps(self.ref_model, encoded_outputs)
             else:
                 with self.accelerator.unwrap_model(self.model).disable_adapter():
-                    ref_per_token_logps = self._get_per_token_logps(self.model, outputs)
+                    ref_per_token_logps = self._get_per_token_logps(self.model, encoded_outputs)

         rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device)
         completions = [example['messages'][-1]['content'] for example in inputs]

         for i, (reward_func, reward_template) in enumerate(zip(self.reward_funcs, self.reward_templates)):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 with self._template_context(reward_template):
                     batched_inputs = [reward_template.encode(infer_request) for infer_request in inputs]
                     reward_inputs = to_device(reward_template.data_collator(batched_inputs), reward_func.device)
@@ -764,11 +815,17 @@ def _generate_and_score_completions(
                 with torch.inference_mode():
                     rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]
             else:
                 # Repeat all input columns (but "messages" and "completion") to match the number of generations
                 reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
                 output_reward_func = reward_func(completions, **reward_kwargs)
                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

+        # Append the accumulated tool-call format rewards as an extra column when enabled.
+        if self.is_reward_tool_call and self.tool_call is not None:
+            tool_rewards = torch.tensor([inp.get('_tool_call_reward', 0.0) for inp in inputs],
+                                        dtype=torch.float32,
+                                        device=device).unsqueeze(1)
+            rewards_per_func = torch.cat([rewards_per_func, tool_rewards], dim=1)
+
         rewards_per_func = gather(rewards_per_func)
         # Apply weights to each reward function's output and sum
         rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
@@ -783,47 +840,66 @@ def _generate_and_score_completions(
         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
         advantages = advantages[process_slice]

         # Log the metrics
         mode = 'eval' if self.control.should_evaluate else 'train'

-        completion_length = self.accelerator.gather_for_metrics(outputs['completion_mask'].sum(1)).float().mean().item()
+        completion_length = self.accelerator.gather_for_metrics(
+            encoded_outputs['completion_mask'].sum(1)).float().mean().item()
         self._metrics[mode]['completion_length'].append(completion_length)
+
+        # Tool-call count metric.
+        if self.tool_call is not None:
+            avg_tool_calls = self.accelerator.gather_for_metrics(
+                torch.tensor([total_tool_calls / len(inputs)], device=device)).mean().item()
+            self._metrics[mode]['tool_call_nums'].append(avg_tool_calls)
+
         # clip ratio
-        response_clip_ratio = torch.gt(
-            self.accelerator.gather_for_metrics(outputs['completion_mask'].sum(1)),
-            self.args.max_completion_length).float().mean().item()
+        # The mask is truncated above, so a sample counts as clipped once it reaches the cap.
+        response_clip_ratio = torch.ge(
+            self.accelerator.gather_for_metrics(encoded_outputs['completion_mask'].sum(1)),
+            self.args.max_completion_length).float().mean().item()
         self._metrics[mode]['response_clip_ratio'].append(response_clip_ratio)
+
         reward_per_func = rewards_per_func.mean(0)
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                 reward_func_name = reward_func.config._name_or_path.split('/')[-1]
             else:
                 if inspect.isfunction(reward_func):
                     reward_func_name = reward_func.__name__  # function
                 else:
                     reward_func_name = reward_func.__class__.__name__  # method
             self._metrics[mode][f'rewards/{reward_func_name}'].append(reward_per_func[i].item())

+        # Tool-call reward metric.
+        if self.is_reward_tool_call and self.tool_call is not None:
+            # TODO: record a separate reward metric per tool-call plugin (e.g. keyed by its name).
+            self._metrics[mode]['rewards/tool_call'].append(reward_per_func[-1].item())
+
         self._metrics[mode]['reward'].append(rewards.mean().item())
         self._metrics[mode]['reward_std'].append(std_grouped_rewards.mean().item())

-        outputs.update({
+        encoded_outputs.update({
             'ref_per_token_logps': ref_per_token_logps,
             'advantages': advantages,
         })

         if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
             # For logging
             table = {
                 'step': [str(self.state.global_step)] * len(rewards),
                 'messages': [inputs['messages'][:-1] for inputs in gather_object(inputs)],
                 'completion': gather_object(completions),
                 'reward': rewards.tolist(),
             }
+
+            if self.tool_call is not None:
+                # Per-prompt tool-call counts, gathered across processes to align with `rewards`.
+                table['tool_call_nums'] = gather_object([inp.get('_tool_call_count', 0) for inp in inputs])
+
             self.jsonl_writer.append(table)
             if 'wandb' in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
                 import pandas as pd
                 df = pd.DataFrame(table)
                 wandb.log({'completions': wandb.Table(dataframe=df)})

-        return outputs
+        return encoded_outputs

     @profiling_decorator
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
@@ -867,24 +943,50 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
     @profiling_decorator
     def _get_per_token_logps(self, model, inputs):
         from trl.trainer.utils import selective_log_softmax
-        logits_to_keep = inputs['logits_to_keep']
         input_ids = inputs['input_ids']
-        unwrapped_model = self.accelerator.unwrap_model(model)
-        parameters = inspect.signature(unwrapped_model.forward).parameters
-        if not unwrapped_model.model_meta.is_multimodal and 'logits_to_keep' in parameters:
-            # save memory
-            return super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep)
-        inputs = {
+        completion_mask = inputs['completion_mask']
+
+        # NOTE: the parent's contiguous-suffix fast path (`logits_to_keep`) no longer applies once
+        # tool results are interleaved with assistant spans, so we always compute the full logits
+        # and mask them afterwards.
+        forward_inputs = {
             k: v
             for k, v in inputs.items() if k not in
-            ['logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps']
+            ['completion_mask', 'logits_to_keep', 'labels', 'ref_per_token_logps', 'advantages', 'old_per_token_logps']
         }
-        logits = model(**inputs).logits
-        # exclude the last logit: it corresponds to the next token pred
-        logits = logits[:, -(logits_to_keep + 1):-1, :]
+
+        logits = model(**forward_inputs).logits
         logits = logits / self.temperature
-        input_ids = input_ids[:, -logits_to_keep:]
-        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
+
+        # The positions flagged by completion_mask are the assistant tokens we need log-probs for.
+        batch_indices, seq_indices = torch.where(completion_mask)
+
+        # Logits at position t predict token t + 1, hence the -1 shift on the logit index.
+        selected_logits = logits[batch_indices, seq_indices - 1]
+        selected_tokens = input_ids[batch_indices, seq_indices]
+
+        # Log-probabilities of the selected tokens.
+        log_probs = selective_log_softmax(selected_logits, selected_tokens)
+
+        # Scatter the log-probs back into a (batch, seq_len) tensor of zeros.
+        result = torch.zeros_like(completion_mask, dtype=torch.float)
+        result[batch_indices, seq_indices] = log_probs
+        return result

     def evaluation_loop(self, dataloader, *args, **kwargs):
         self.queue = self.eval_queue
@@ -896,3 +998,225 @@ def evaluation_loop(self, dataloader, *args, **kwargs):
         output.metrics.update(metrics)
         self.queue = self.train_queue
         return output
+
+    def _process_tool_calls(self, inputs: List[Dict]) -> Tuple[List[Dict], int]:
+        """Run the multi-turn tool-calling rollout loop.
+
+        Args:
+            inputs: List of input dictionaries with message history.
+
+        Returns:
+            The updated inputs after the tool-calling process and the total number of tool calls.
+        """
+        from copy import copy  # local import, mirroring _generate_and_score_completions
+
+        # Track tool-call metrics for this batch.
+        tool_call_counts = [0] * len(inputs)
+        tool_call_rewards = [0.0] * len(inputs)
+        assistant_tokens_used = [0] * len(inputs)
+        total_tool_calls = 0
+
+        # Process inputs until every sample has finished.
+        active_indices = list(range(len(inputs)))
+
+        while active_indices:
+            active_inputs = [inputs[i] for i in active_indices]
+
+            # Give every active sample a request config whose max_tokens reflects its remaining budget.
+            adjusted_request_configs = []
+            for idx in active_indices:
+                tokens_used = assistant_tokens_used[idx]
+                available_tokens = max(1, self.args.max_completion_length - tokens_used)
+
+                config = copy(self.request_config)
+                config.max_tokens = available_tokens
+                adjusted_request_configs.append(config)
+
+            # Generate responses with the adjusted token limits.
+            if self.args.use_vllm or self.args.use_lmdeploy:
+                _, active_outputs = self._fast_infer_with_custom_configs(active_inputs, adjusted_request_configs)
+            else:
+                is_multimodal = self.model.model_meta.is_multimodal
+                if is_multimodal:
+                    models = self.template.remove_post_encode_hook()
+                with unwrap_model_for_generation(
+                        self.model_wrapped, self.accelerator,
+                        gather_deepspeed3_params=self.args.ds3_gather_for_generation):
+                    active_outputs = []
+                    for i, input_data in enumerate(active_inputs):
+                        # Use the per-input config.
+                        output = self.engine.infer([input_data], adjusted_request_configs[i], use_tqdm=False)
+                        active_outputs.extend(output)
+                    self.model.train()
+                if is_multimodal:
+                    self.template.register_post_encode_hook(models)
+
+            # Run the tool calls in parallel threads.
+            new_active_indices = []
+            with concurrent.futures.ThreadPoolExecutor(max_workers=len(active_indices)) as executor:
+                futures = []
+                for i, idx in enumerate(active_indices):
+                    completion = active_outputs[i].choices[0].message.content
+                    tokens_used = len(self.tokenizer(completion).input_ids)
+
+                    # Add the assistant response to the message history.
+                    messages = active_inputs[i]['messages']
+                    InferRequest.remove_response(messages)
+                    messages.append({'role': 'assistant', 'content': completion})
+
+                    # Track how many completion tokens this sample has consumed.
+                    assistant_tokens_used[idx] += tokens_used
+
+                    futures.append((executor.submit(self.tool_call, completion), idx))
+
+                # Collect the results and decide which samples stay active.
+                for future, idx in futures:
+                    tool_result, finish, tool_reward = future.result()
+
+                    if not finish:
+                        # Feed the tool result back as a user message and keep the sample active.
+                        inputs[idx]['messages'].append({'role': 'user', 'content': tool_result})
+                        new_active_indices.append(idx)
+
+                        # A tool call actually happened: update the tracking variables.
+                        tool_call_counts[idx] += 1
+                        total_tool_calls += 1
+                        if self.is_reward_tool_call:
+                            tool_call_rewards[idx] += tool_reward
+
+            active_indices = new_active_indices
+
+        # Store the tool-call metrics in the inputs for later use.
+        for i in range(len(inputs)):
+            inputs[i]['_tool_call_count'] = tool_call_counts[i]
+            inputs[i]['_tool_call_reward'] = tool_call_rewards[i]
+            inputs[i]['_assistant_tokens_used'] = assistant_tokens_used[i]
+
+        return inputs, total_tool_calls
+
+    def _fast_infer_with_custom_configs(self, inputs, request_configs):
+        """Similar to ``_fast_infer``, but uses a separate request config for each input.
+
+        Args:
+            inputs: List of input dicts.
+            request_configs: List of request configs (same length as ``inputs``).
+
+        Returns:
+            An ``(inputs, outputs)`` tuple.
+        """
+        if self.args.sleep_level > 0 and self.infer_rank >= 0:
+            if self.args.offload_model:
+                self.offload_model()
+            if self.args.offload_optimizer:
+                self.offload_optimizer()
+            if self.args.gc_collect_after_offload:
+                gc_collect()
+            self.engine.engine.wake_up()
+
+        # First, have the main process load the weights if needed.
+        if self.state.global_step != self._last_loaded_step:
+            self._move_model_to_vllm_lmdeploy()
+            self._last_loaded_step = self.state.global_step
+
+        # Gather all prompts and their configs.
+        all_inputs = gather_object(inputs)
+        all_configs = gather_object(request_configs)
+
+        # Distribute the inputs to the inference workers.
+        distributed_idx = self.round_robin(len(all_inputs), get_node_setting()[1] * self.args.num_infer_workers)
+
+        if self.infer_rank >= 0:
+            _input_slice = [all_inputs[i] for i in distributed_idx[self.infer_rank]]
+            _config_slice = [all_configs[i] for i in distributed_idx[self.infer_rank]]
+
+            if self.args.async_generate:
+                # Modified async_infer that threads the per-input configs through.
+                self.async_infer_with_configs(inputs, _input_slice, _config_slice, distributed_idx)
+                data_cache = self.queue.get()
+                inputs = data_cache.inputs
+                outputs = data_cache.outputs
+                distributed_idx = data_cache.distributed_idx
+            else:
+                with set_device_context(self.infer_device):
+                    outputs = []
+                    # Process each input with its own config.
+                    for i in range(len(_input_slice)):
+                        request_config = _config_slice[i]
+                        if self.args.tensor_parallel_size > 1:
+                            request_config.seed += self.state.global_step
+                        output = self.engine.infer([_input_slice[i]], request_config, use_tqdm=False)
+                        outputs.extend(output)
+
+                if self.args.tensor_parallel_size > 1:
+                    if self.infer_rank_tp_0 < 0:
+                        outputs = []
+                    else:
+                        _outputs = []
+                        for tp_idx in range(self.args.tensor_parallel_size):
+                            for prompt_idx in range(len(outputs)):
+                                output = deepcopy(outputs[prompt_idx])
+                                output.choices = [output.choices[tp_idx]]
+                                _outputs.append(output)
+                        outputs = _outputs
+        else:
+            if self.args.async_generate:
+                self.queue.put(DataCache(inputs, [], distributed_idx))
+                data_cache = self.queue.get()
+                inputs = data_cache.inputs
+                distributed_idx = data_cache.distributed_idx
+            outputs = []
+
+        outputs = gather_object(outputs)
+        outputs = self.reorder_outputs(outputs, distributed_idx)
+
+        if self.args.sleep_level > 0 and self.infer_rank >= 0:
+            self.engine.engine.sleep(level=self.args.sleep_level)
+            if self.args.gc_collect_after_offload:
+                gc_collect()
+            if self.args.offload_model:
+                self.load_model()
+            if self.args.offload_optimizer:
+                self.load_optimizer()
+
+        return inputs, outputs
+
+    def async_infer_with_configs(self, inputs, inputs_slice, configs_slice, distributed_idx):
+        """Asynchronous inference with a custom config for each input.
+
+        Args:
+            inputs: Original inputs.
+            inputs_slice: Slice of inputs for this worker.
+            configs_slice: Slice of configs corresponding to ``inputs_slice``.
+            distributed_idx: Distribution of indices among workers.
+        """

+        def infer_task():
+            with set_device_context(self.infer_device):
+                results = []
+                # Process each input with its own config.
+                for i in range(len(inputs_slice)):
+                    result = self.engine.infer(
+                        infer_requests=[inputs_slice[i]], request_config=configs_slice[i], use_tqdm=False)
+                    results.extend(result)
+                return results
+
+        future: Future = self.executor.submit(infer_task)
+
+        def done(_self):
+            self.queue.put(DataCache(inputs, _self.result(), distributed_idx))
+
+        future.add_done_callback(done)
+
+    @property
+    def tokenizer(self):
+        """Get the tokenizer from the template."""
+        return self.processing_class
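
Example tool plugin (illustrative sketch, not part of the diff): the trainer looks up `args.tool_call` in the `tools` registry from `swift/plugin/tool_call.py` and repeatedly invokes the plugin on each completion; the plugin returns `(tool_result, finished, format_reward)`, and `tool_result` is fed back to the model as a user message whenever `finished` is False. The class name `EchoToolCall`, the `<tool_call>...</tool_call>` completion format, and the registry key `'echo'` below are hypothetical; only the `ToolCall` interface and the `tools` registry come from this patch.

import re
from typing import Any, Optional, Tuple

from swift.plugin.tool_call import ToolCall, tools


class EchoToolCall(ToolCall):
    """Hypothetical plugin: echoes the content of the first <tool_call> span."""

    def __call__(self, completion: str) -> Tuple[Any, bool, Optional[float]]:
        match = re.search(r'<tool_call>(.*?)</tool_call>', completion, re.DOTALL)
        if match is None:
            # No tool invocation: the rollout is finished and earns no format reward.
            return None, True, None
        # Well-formed call: return the tool output, keep the rollout active,
        # and grant a format reward of 1.0 for this call.
        return f'echo: {match.group(1).strip()}', False, 1.0


# Register under the name that `--tool_call echo` selects.
tools['echo'] = EchoToolCall()

A run would then enable the plugin via the new GRPO arguments, e.g. `--tool_call echo --is_reward_tool_call true --tool_call_weight 0.5`.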