Skip to content

Commit

Permalink
Imagen cherry pick from develop (#980)
Browse files Browse the repository at this point in the history
* fix imagen training bug (#966)

* fix imagen training bug

* modify resolution

* modify resolution

Co-authored-by: Liujie0926 <44688141+Liujie0926@users.noreply.github.com>

* fix imagen training bug (#966)

* fix imagen training bug

* modify resolution

* modify resolution

Co-authored-by: Liujie0926 <44688141+Liujie0926@users.noreply.github.com>

Co-authored-by: Liujie0926 <44688141+Liujie0926@users.noreply.github.com>
  • Loading branch information
firestonelib and Liujie0926 authored Dec 1, 2022
1 parent 4779491 commit 63e631d
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 11 deletions.
4 changes: 2 additions & 2 deletions ppfleetx/configs/multimodal/imagen/imagen_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ Data:
Train:
dataset:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
input_path: ./projects/imagen/filelist/cc12m_base64.lst
shuffle: True
input_resolusion: 64
input_resolution: 64
max_seq_len: 128
loader:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Data:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
shuffle: True
input_resolusion: 1024
input_resolution: 1024
max_seq_len: 128
loader:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Data:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
shuffle: True
input_resolusion: 256
input_resolution: 256
max_seq_len: 128
loader:
num_workers: 8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Data:
name: ImagenDataset
input_path: ./data/cc12m_base64.lst
shuffle: True
input_resolusion: 512
input_resolution: 512
max_seq_len: 128
loader:
num_workers: 8
Expand Down
6 changes: 3 additions & 3 deletions ppfleetx/data/dataset/multimodal_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self,
input_path,
input_format='embed_base64_cc12m',
shuffle=False,
input_resolusion=64,
input_resolution=64,
second_size=256,
max_seq_len=128,
filter_image_resolution=128,
Expand All @@ -109,7 +109,7 @@ def __init__(self,
self.filename = get_files(
input_path, gpu_num=device_world_size, shuffle=shuffle)
self.filter_image_resolution = filter_image_resolution
self.input_resolusion = input_resolusion
self.input_resolution = input_resolution
self.max_seq_len = max_seq_len
self.split = split
if not isinstance(self.filename, list):
Expand Down Expand Up @@ -170,7 +170,7 @@ def __getitem__(self, index):
text_embed = self.load_file(data_dir, data[1])
attn_mask = self.load_file(data_dir, data[2])
image = self.base64_to_image(data[3])
image = data_augmentation_for_imagen(image, self.input_resolusion)
image = data_augmentation_for_imagen(image, self.input_resolution)

return image, paddle.to_tensor(
text_embed, dtype='float32'), paddle.to_tensor(
Expand Down
64 changes: 64 additions & 0 deletions projects/imagen/filelist/cc12m_base64.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
cc12m_base64/part-0
cc12m_base64/part-1
cc12m_base64/part-2
cc12m_base64/part-3
cc12m_base64/part-4
cc12m_base64/part-5
cc12m_base64/part-6
cc12m_base64/part-7
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# limitations under the License.

export CUDA_VISIBLE_DEVICES=0
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_1024.yaml -o Data.Train.loader.num_workers=0
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolution_1024.yaml -o Data.Train.loader.num_workers=0
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# limitations under the License.

export CUDA_VISIBLE_DEVICES=0
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolusion_512.yaml -o Data.Train.loader.num_workers=8
python3 tools/train.py -c ppfleetx/configs/multimodal/imagen/imagen_super_resolution_512.yaml -o Data.Train.loader.num_workers=8
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ pybind11==2.10.0
numpy==1.21.6
paddleslim>=2.4.0rc
opencv-python==4.2.0.32
Pillow==9.0.1
Pillow==9.3.0
blobfile==1.3.3
paddlenlp>=2.4.3

0 comments on commit 63e631d

Please # to comment.