diff --git a/pdelfin/train/dataprep.py b/pdelfin/train/dataprep.py
index ef600da..259c398 100644
--- a/pdelfin/train/dataprep.py
+++ b/pdelfin/train/dataprep.py
@@ -93,7 +93,7 @@ def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim
     }
 
 
-def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: int, target_anchor_text_len: int):
+def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
     # Process each example in the batch using the helper function
     processed_examples = []
     for i in range(len(batch["response"])):
diff --git a/pdelfin/train/utils.py b/pdelfin/train/utils.py
index b59b07f..2fdca54 100644
--- a/pdelfin/train/utils.py
+++ b/pdelfin/train/utils.py
@@ -74,8 +74,8 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset
         partial(
             batch_prepare_data_for_qwen2_training,
             processor=processor,
-            target_longest_image_dim=target_longest_image_dim,
-            target_anchor_text_len=target_anchor_text_len,
+            target_longest_image_dim=list(target_longest_image_dim),
+            target_anchor_text_len=list(target_anchor_text_len),
         )
     )
@@ -86,8 +86,8 @@ def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset
             partial(
                 batch_prepare_data_for_qwen2_training,
                 processor=processor,
-                target_longest_image_dim=source.target_longest_image_dim,
-                target_anchor_text_len=source.target_anchor_text_len,
+                target_longest_image_dim=list(source.target_longest_image_dim),
+                target_anchor_text_len=list(source.target_anchor_text_len),
             )
         )
         for source in config.valid_data.sources
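
Note on the change: the patch widens the two parameter annotations from int to list[int] and has both call sites in make_dataset wrap the config-supplied values in list(...) before binding them with partial. A minimal sketch of how the updated signature would be bound is below; the checkpoint name, the example dimension/length values, and the assumption that the config may yield tuples or other sequence proxies are illustrative, not taken from the repository.

from functools import partial

from transformers import AutoProcessor

from pdelfin.train.dataprep import batch_prepare_data_for_qwen2_training

# Illustrative values only: a config loader may hand these back as tuples or
# other sequence proxies rather than plain lists, which is why the call sites
# now coerce them with list(...).
target_longest_image_dim = (1024,)
target_anchor_text_len = (0, 6000)

# The checkpoint name here is an assumption for this sketch, not the repo's config.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

# Mirrors the updated call sites in make_dataset: coerce to list before binding,
# so the bound arguments match the new list[int] annotations in dataprep.py.
map_fn = partial(
    batch_prepare_data_for_qwen2_training,
    processor=processor,
    target_longest_image_dim=list(target_longest_image_dim),
    target_anchor_text_len=list(target_anchor_text_len),
)

# map_fn is then handed to a batched dataset .map(...) call, as in utils.py.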