diff --git a/video_chat2/dataset/base_dataset.py b/video_chat2/dataset/base_dataset.py
index c80241e..63be21c 100644
--- a/video_chat2/dataset/base_dataset.py
+++ b/video_chat2/dataset/base_dataset.py
@@ -62,8 +62,27 @@ def load_and_transform_media_data(self, index, data_path):
         else:
             return self.load_and_transform_media_data_video(index, data_path)
 
-    def load_and_transform_media_data_image(self, index, data_path):
+    def load_and_transform_media_data_image(self, index, data_path, dynamic_config=None):
+        """Load one image, optionally HD-transform it, then apply ``self.transform``.
+
+        Args:
+            index: Sample index; returned unchanged alongside the image.
+            data_path: Location passed to ``load_image_from_path``.
+            dynamic_config: Optional mapping with the keys ``"local_size"``,
+                ``"hd_num"`` and ``"padding"``. When falsy, the HD step is skipped.
+
+        Returns:
+            A ``(image, index)`` tuple.
+        """
         image = load_image_from_path(data_path, client=self.client)
+        if dynamic_config:
+            # All three keys are read before either HD transform runs.
+            size = dynamic_config["local_size"]
+            num_hd = dynamic_config["hd_num"]
+            use_padding = dynamic_config["padding"]
+            # A truthy "padding" selects the padding variant of the HD transform.
+            hd_transform = HD_transform_padding if use_padding else HD_transform_no_padding
+            image = hd_transform(image.float(), image_size=size, hd_num=num_hd)
         image = self.transform(image)
         return image, index
 