diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..1f24212
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,32 @@
+.ipynb_checkpoints/
+.idea/
+__pycache__/
+
+results
+results.*
+notebooks/
+outputs
+outputs/
+configs/temp/
+pytests/
+_cache/
+scripts/
+
+*.ckpt
+core.*
+
+__pycache__/
+**/__pycache__/
+*.py[cod]
+**/*.py[cod]
+**/*.pyc
+result*/
+results*/
+backup*/
+test.*/
+.nfs*
+
+.ipynb_*/
+
+# MacOSX
+**/*.DS_Store
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..b86ff54
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,82 @@
+Copyright (c) 2022 Kakao Brain
+
+CreativeML Open RAIL-M
+dated August 22, 2022
+
+Section I: PREAMBLE
+
+Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
+
+Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
+
+In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
+
+Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
+
+This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
+
+NOW THEREFORE, You and Licensor agree as follows:
+
+1. Definitions
+
+- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
+- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
+- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
+- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
+- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
+- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
+- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
+- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
+- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
+- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
+- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
+- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
+
+Section II: INTELLECTUAL PROPERTY RIGHTS
+
+Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
+
+2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
+3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
+
+Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
+
+4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
+Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
+You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
+You must cause any modified files to carry prominent notices stating that You changed the files;
+You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
+You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
+5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
+6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
+
+Section IV: OTHER PROVISIONS
+
+7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
+8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
+9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
+10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
+
+END OF TERMS AND CONDITIONS
+
+
+
+
+Attachment A
+
+Use Restrictions
+
+You agree not to use the Model or Derivatives of the Model:
+- In any way that violates any applicable national, federal, state, local or international law or regulation;
+- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
+- To generate or disseminate personal identifiable information that can be used to harm an individual;
+- To defame, disparage or otherwise harass others;
+- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
+- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
+- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
+- To provide medical advice and medical results interpretation;
+- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7be5de3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,153 @@
+# Karlo-v1.0.alpha on COYO-100M and CC15M
+
+Karlo is a text-conditional image generation model based on OpenAI's unCLIP architecture, with an improved super-resolution module that upscales images from 64px to 256px and recovers high-frequency details in only a small number of denoising steps.
+
+Example prompts (the corresponding sample images are included in `assets/`):
+
+* "a portrait of an old monk, highly detailed."
+* "Photo of a business woman, silver hair"
+* "A teddy bear on a skateboard, children drawing style."
+* "Goryeo celadon in the shape of bird"
+
+This alpha version of Karlo is trained on 115M image-text pairs, including the [COYO](https://github.com/kakaobrain/coyo-dataset)-100M high-quality subset, CC3M, and CC12M. If you are interested in a better version of Karlo trained on larger-scale, high-quality datasets, please visit the landing page of our application [B^DISCOVER](https://bdiscover.kakaobrain.com/).
+
+### Updates
+* [2022-12-01] Karlo-v1.0.alpha is released!
+
+## Model Architecture
+
+### Overview
+Karlo is a text-conditional diffusion model based on unCLIP, composed of prior, decoder, and super-resolution modules. This repository includes an improved version of the standard super-resolution module, which upscales images from 64px to 256px in only 7 reverse steps, as illustrated in the figure below:
+
+
+![Improved SR architecture](assets/improved_sr_arch.png)
+
+Specifically, the standard SR module, trained with the DDPM objective, upscales from 64px to 256px in the first 6 denoising steps using the respacing technique. An additional SR module, fine-tuned with a [VQ-GAN](https://compvis.github.io/taming-transformers/)-style loss, then performs the final reverse step to recover high-frequency details. We observe that this approach is very effective for upscaling low-resolution images in a small number of reverse steps.
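+
+To make the two-stage schedule concrete, here is a minimal, hedged sketch of the 7-step SR sampling described above. It is not the repository's actual API: `sr_base` and `sr_finetuned` stand in for the frozen DDPM-trained SR model and the fine-tuned final-step model, and the 7 respaced timesteps come from respacing the 1000-step schedule (`timestep_respacing: '7'` in [configs/improved_sr_64_256_1.4B.yaml](configs/improved_sr_64_256_1.4B.yaml)).
+
+```
+import torch
+
+def upscale_64_to_256(x_64, sr_base, sr_finetuned, respaced_timesteps):
+    """Sketch only: 6 reverse steps with the frozen SR model, then 1 step with the fine-tuned one."""
+    assert len(respaced_timesteps) == 7
+    x = torch.randn(x_64.shape[0], 3, 256, 256, device=x_64.device)  # start from noise at 256px
+    for t in respaced_timesteps[:-1]:  # first 6 denoising steps (standard DDPM-based SR)
+        x = sr_base(x, t, low_res=x_64)
+    return sr_finetuned(x, respaced_timesteps[-1], low_res=x_64)  # final step recovers high-frequency details
+```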
+
+### Details
+We train all components from scratch on 115M image-text pairs including COYO-100M, CC3M, and CC12M. For the prior and decoder, we use the ViT-L/14 model provided by OpenAI's [CLIP repository](https://github.com/openai/CLIP). Unlike the original unCLIP implementation, we replace the trainable transformer in the decoder with the text encoder of ViT-L/14 for efficiency. For the SR module, we first train the model with the DDPM objective for 1M steps, followed by an additional 234K steps to fine-tune the additional component. The table below summarizes the important statistics of our components:
+
+| | Prior | Decoder | SR |
+|:------|----:|----:|----:|
+| CLIP | ViT-L/14 | ViT-L/14 | - |
+| #param | 1B | 900M | 700M + 700M |
+| #optimization steps | 1M | 1M | 1M + 0.2M |
+| #sampling steps | 25 | 50 (default), 25 (fast) | 7 |
+|Checkpoint links| [ViT-L-14](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt), [ViT-L-14 stats](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th), [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt) | [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt) | [model](https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt) |
+
+In the checkpoint links, ViT-L-14 is identical to OpenAI's original version; we include it for convenience. Note that ViT-L-14_stats is required to normalize the outputs of the prior module: at sampling time, the predicted CLIP image embedding is de-normalized as `sample * clip_std + clip_mean` (see `karlo/models/prior_model.py`).
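+
+For reference, the Gradio demo loads these checkpoints roughly as follows (adapted from [demo/components.py](demo/components.py); the root directory path is a placeholder, and `CKPT_PATH` maps component names to the downloaded filenames):
+
+```
+from karlo.sampler.template import CKPT_PATH, BaseSampler
+
+base_sampler = BaseSampler(root_dir="/path/to/karlo_ckpt")  # directory containing the downloaded checkpoints
+base_sampler.load_clip(clip_path="ViT-L-14.pt")
+base_sampler.load_prior(CKPT_PATH["prior"], clip_stat_path="ViT-L-14_stats.th")
+base_sampler.load_decoder(CKPT_PATH["decoder"])
+base_sampler.load_sr_64_256(CKPT_PATH["sr_256"])
+```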
+
+### Evaluation
+We quantitatively measure the performance of Karlo-v1.0.alpha on the validation splits of CC3M and MS-COCO. The tables below present CLIP-score and FID. To measure FID, we resize each image so that its shorter side is 256px and then crop it at the center (a preprocessing sketch is given after the tables). We set the classifier-free guidance scales of the prior and decoder to 4 and 8 in all cases. We observe that our model achieves reasonable performance even with 25 decoder sampling steps.
+
+CC3M
+| Sampling step | CLIP-s (ViT-B/16) | FID (13k from val)|
+|:------|----:|----:|
+| Prior (25) + Decoder (25) + SR (7) | 0.3081 | 14.37 |
+| Prior (25) + Decoder (50) + SR (7) | 0.3086 | 13.95 |
+
+MS-COCO
+| Sampling step | CLIP-s (ViT-B/16) | FID (30k from val)|
+|:------|----:|----:|
+| Prior (25) + Decoder (25) + SR (7) | 0.3192 | 15.24 |
+| Prior (25) + Decoder (50) + SR (7) | 0.3192 | 14.43 |
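+
+As a minimal sketch of the FID preprocessing described above (resize the shorter side to 256px, then center-crop), assuming `torchvision` is available; the actual evaluation code is not included in this repository:
+
+```
+from PIL import Image
+import torchvision.transforms as T
+
+fid_preprocess = T.Compose([
+    T.Resize(256),      # resize so the shorter side becomes 256px, keeping aspect ratio
+    T.CenterCrop(256),  # crop the central 256x256 region
+])
+image = fid_preprocess(Image.open("sample.png").convert("RGB"))  # "sample.png" is a placeholder
+```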
+
+
+For more information, please refer to the upcoming technical report.
+
+
+## Environment Setup
+We use a single V100 with 32GB VRAM for sampling, under PyTorch >= 1.10 and CUDA >= 11. The following commands install the additional Python packages and download the pretrained model checkpoints into `$KARLO_ROOT_DIR` (set this variable to the directory where the checkpoints should be stored). Alternatively, you can install the packages and download the weights in one step via [setup.sh](setup.sh).
+- Additional python packages
+```
+pip install -r requirements.txt
+```
+- Model checkpoints
+```
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt -P $KARLO_ROOT_DIR # same as the official ViT-L/14 from OpenAI
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt -P $KARLO_ROOT_DIR
+```
+
+## Sampling
+
+### Gradio demo (T2I and Image variation)
+The following command launches a Gradio demo for text-to-image generation and image variation. Note that under PyTorch >= 1.12, the second run in the Gradio demo is unexpectedly slower than usual; we suspect this is because launching the CUDA kernels takes some time, usually up to 2 minutes.
+```
+python demo/product_demo.py --host 0.0.0.0 --port $PORT --root-dir $KARLO_ROOT_DIR
+```
+
+The samples below are non-cherry-picked T2I and image variation examples generated with random seed 0.
+In each case, the first row shows T2I samples and the second row shows image variations of the leftmost image in the first row.
+
+
+* [T2I + Image variation] "A man with a face of avocado, in the drawing style of Rene Magritte." (see `assets/`)
+* [T2I + Image variation] "a black porcelain in the shape of pikachu" (see `assets/`)
+
+### T2I command line example
+Here is a command-line example for T2I. For image variation, see [karlo/sampler/i2i.py](karlo/sampler/i2i.py) for how the prior output is replaced with the CLIP image feature.
+```
+python example.py --root-dir=$KARLO_ROOT_DIR \
+ --prompt="A man with a face of avocado, in the drawing style of Rene Magritte" \
+ --output-dir=$OUTPUT_DIR \
+ --max-bsz=2 \
+ --sampling-type=fast
+```
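+
+The sampler can also be driven from Python. The sketch below mirrors [example.py](example.py) in this repository; the checkpoint directory and the prompt are placeholders:
+
+```
+import torch
+from PIL import Image
+
+from karlo.sampler.t2i import T2ISampler
+
+model = T2ISampler.from_pretrained(
+    root_dir="/path/to/karlo_ckpt",   # directory holding the downloaded checkpoints
+    clip_model_path="ViT-L-14.pt",
+    clip_stat_path="ViT-L-14_stats.th",
+    sampling_type="fast",             # 25 decoder steps; "default" uses 50
+)
+
+# The sampler yields NCHW float tensors in [0, 1]; progressive_mode="final" yields only the final batch.
+images = next(iter(model(prompt="A teddy bear on a skateboard", bsz=1, progressive_mode="final")))
+images = torch.permute(images * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()  # NCHW -> NHWC uint8
+Image.fromarray(images[0]).save("teddy_bear.png")
+```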
+
+## License and Disclaimer
+This project, including the weights, is distributed under the [CreativeML Open RAIL-M license](LICENSE), the same license as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion/blob/main/LICENSE). You may use this model in commercial applications, but we highly recommend adopting a powerful safety checker as a post-processing step. We also note that we are not responsible for any use of the generated images.
+
+## BibTeX
+If you find this repository useful in your research, please cite:
+```
+@misc{kakaobrain2022karlo-v1-alpha,
+ title = {Karlo-v1.0.alpha on COYO-100M and CC15M},
+  author = {Donghoon Lee and Jiseob Kim and Jisu Choi and Jongmin Kim and Minwoo Byeon and Woonhyuk Baek and Saehoon Kim},
+ year = {2022},
+ howpublished = {\url{https://github.com/kakaobrain/karlo}},
+}
+```
+
+## Acknowledgement
+We deeply appreciate all the contributors to OpenAI’s [Guided-Diffusion](https://github.com/openai/guided-diffusion) project.
+
+## Contact
+If you would like to collaborate with us or share feedback, please e-mail us at contact@kakaobrain.com.
diff --git a/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png b/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png
new file mode 100644
index 0000000..581c45f
Binary files /dev/null and b/assets/A man with a face of avocado, in the drawing style of Rene Magritte..png differ
diff --git a/assets/A teddy bear on a skateboard, children drawing style..png b/assets/A teddy bear on a skateboard, children drawing style..png
new file mode 100644
index 0000000..3a09cf6
Binary files /dev/null and b/assets/A teddy bear on a skateboard, children drawing style..png differ
diff --git a/assets/Goryeo celadon in the shape of bird.png b/assets/Goryeo celadon in the shape of bird.png
new file mode 100644
index 0000000..ef91fb6
Binary files /dev/null and b/assets/Goryeo celadon in the shape of bird.png differ
diff --git a/assets/Photo of a business woman, silver hair.png b/assets/Photo of a business woman, silver hair.png
new file mode 100644
index 0000000..b9fd1e8
Binary files /dev/null and b/assets/Photo of a business woman, silver hair.png differ
diff --git a/assets/a black porcelain in the shape of pikachu.png b/assets/a black porcelain in the shape of pikachu.png
new file mode 100644
index 0000000..d0583e7
Binary files /dev/null and b/assets/a black porcelain in the shape of pikachu.png differ
diff --git a/assets/a portrait of an old monk, highly detailed.png b/assets/a portrait of an old monk, highly detailed.png
new file mode 100644
index 0000000..4b2b206
Binary files /dev/null and b/assets/a portrait of an old monk, highly detailed.png differ
diff --git a/assets/example.gif b/assets/example.gif
new file mode 100644
index 0000000..1960285
Binary files /dev/null and b/assets/example.gif differ
diff --git a/assets/improved_sr_arch.png b/assets/improved_sr_arch.png
new file mode 100644
index 0000000..bcf13c0
Binary files /dev/null and b/assets/improved_sr_arch.png differ
diff --git a/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png b/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png
new file mode 100644
index 0000000..b49fc80
Binary files /dev/null and b/assets/variation_A man with a face of avocado, in the drawing style of Rene Magritte..png differ
diff --git a/assets/variation_a black porcelain in the shape of pikachu.png b/assets/variation_a black porcelain in the shape of pikachu.png
new file mode 100644
index 0000000..a710672
Binary files /dev/null and b/assets/variation_a black porcelain in the shape of pikachu.png differ
diff --git a/configs/decoder_900M_vit_l.yaml b/configs/decoder_900M_vit_l.yaml
new file mode 100644
index 0000000..4a99ee7
--- /dev/null
+++ b/configs/decoder_900M_vit_l.yaml
@@ -0,0 +1,39 @@
+model:
+ type: t2i-decoder
+ diffusion_sampler: uniform
+ hparams:
+ image_size: 64
+ num_channels: 320
+ num_res_blocks: 3
+ channel_mult: ''
+ attention_resolutions: 32,16,8
+ num_heads: -1
+ num_head_channels: 64
+ num_heads_upsample: -1
+ use_scale_shift_norm: true
+ dropout: 0.1
+ clip_dim: 768
+ clip_emb_mult: 4
+ text_ctx: 77
+ xf_width: 1536
+ xf_layers: 0
+ xf_heads: 0
+ xf_final_ln: false
+ xf_padding: false
+ resblock_updown: true
+ learn_sigma: true
+ cache_text_emb: false
+ text_drop: 0.3
+ clip_emb_type: image
+ clip_emb_drop: 0.1
+ use_plm: true
+
+diffusion:
+ steps: 1000
+ learn_sigma: true
+ sigma_small: false
+ noise_schedule: squaredcos_cap_v2
+ use_kl: false
+ predict_xstart: false
+ rescale_learned_sigmas: true
+ timestep_respacing: ''
diff --git a/configs/improved_sr_64_256_1.4B.yaml b/configs/improved_sr_64_256_1.4B.yaml
new file mode 100644
index 0000000..282d3cb
--- /dev/null
+++ b/configs/improved_sr_64_256_1.4B.yaml
@@ -0,0 +1,27 @@
+model:
+ type: improved_sr_64_256
+ diffusion_sampler: uniform
+ hparams:
+ channels: 320
+ depth: 3
+ channels_multiple:
+ - 1
+ - 2
+ - 3
+ - 4
+ dropout: 0.0
+
+diffusion:
+ steps: 1000
+ learn_sigma: false
+ sigma_small: true
+ noise_schedule: squaredcos_cap_v2
+ use_kl: false
+ predict_xstart: false
+ rescale_learned_sigmas: true
+ timestep_respacing: '7'
+
+
+sampling:
+ timestep_respacing: '7' # fix
+ clip_denoise: true
diff --git a/configs/prior_1B_vit_l.yaml b/configs/prior_1B_vit_l.yaml
new file mode 100644
index 0000000..0b9ddba
--- /dev/null
+++ b/configs/prior_1B_vit_l.yaml
@@ -0,0 +1,23 @@
+model:
+ type: prior
+ diffusion_sampler: uniform
+ hparams:
+ text_ctx: 77
+ xf_width: 2048
+ xf_layers: 20
+ xf_heads: 32
+ xf_final_ln: true
+ xf_padding: false
+ text_drop: 0.2
+ clip_dim: 768
+ clip_xf_width: 768
+
+diffusion:
+ steps: 1000
+ learn_sigma: false
+ sigma_small: true
+ noise_schedule: squaredcos_cap_v2
+ use_kl: false
+ predict_xstart: true
+ rescale_learned_sigmas: false
+ timestep_respacing: ''
diff --git a/demo/components.py b/demo/components.py
new file mode 100644
index 0000000..a6c127a
--- /dev/null
+++ b/demo/components.py
@@ -0,0 +1,333 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import time
+import sys
+import os
+import threading
+import logging
+from queue import Queue
+from PIL import Image
+
+import gradio as gr
+import numpy as np
+import torch
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from karlo.sampler.template import CKPT_PATH, BaseSampler
+from karlo.sampler.t2i import T2ISampler
+from karlo.sampler.i2i import I2ISampler
+from karlo.utils.util import set_seed
+
+
+def tensor_to_images(tensor: torch.Tensor, output_res=(1024, 1024)):
+ assert tensor.ndim == 4
+ tensor = torch.clone(tensor)
+ # NCHW -> NHWC
+ images = torch.permute(tensor * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()
+ concat_image = np.concatenate(images, axis=1)
+ target_size = (output_res[1] * tensor.shape[0], output_res[0])
+ concat_image = Image.fromarray(concat_image).resize(
+ target_size, resample=Image.NEAREST
+ )
+ return images, concat_image
+
+
+class GradioSampler:
+ def __init__(
+ self,
+ root_dir,
+ max_bsz,
+ progressive,
+ sampling_type: str,
+ ):
+ self._root_dir = root_dir
+ self._max_bsz = max_bsz
+ self._progressive = progressive
+ self._sampling_type = sampling_type
+
+ self.load_ckpt()
+ self.set_options_from_sampler()
+
+ self.result_queue = Queue()
+
+ def load_ckpt(self):
+ base_sampler = BaseSampler(root_dir=self._root_dir)
+ base_sampler.load_clip(clip_path="ViT-L-14.pt")
+ base_sampler.load_prior(
+ f"{CKPT_PATH['prior']}",
+ clip_stat_path="ViT-L-14_stats.th",
+ )
+ base_sampler.load_decoder(f"{CKPT_PATH['decoder']}")
+ base_sampler.load_sr_64_256(f"{CKPT_PATH['sr_256']}")
+
+ self.t2i_sampler = T2ISampler(
+ root_dir=self._root_dir, sampling_type=self._sampling_type
+ )
+ self.i2i_sampler = I2ISampler(
+ root_dir=self._root_dir, sampling_type=self._sampling_type
+ )
+
+ self.t2i_sampler._clip = base_sampler._clip
+ self.t2i_sampler._tokenizer = base_sampler._tokenizer
+ self.t2i_sampler._prior = base_sampler._prior
+ self.t2i_sampler._decoder = base_sampler._decoder
+ self.t2i_sampler._sr_64_256 = base_sampler._sr_64_256
+
+ self.i2i_sampler._clip = base_sampler._clip
+ self.i2i_sampler._tokenizer = base_sampler._tokenizer
+ self.i2i_sampler._prior = base_sampler._prior
+ self.i2i_sampler._decoder = base_sampler._decoder
+ self.i2i_sampler._sr_64_256 = base_sampler._sr_64_256
+
+ self.ckpt_info = f"""
+ * **prior**: `{self._root_dir}/{CKPT_PATH['prior']}`
+ * **decoder**: `{self._root_dir}/{CKPT_PATH['decoder']}`
+ * **sr_64_256**: `{self._root_dir}/{CKPT_PATH['sr_256']}`
+ """
+
+ def set_options_from_sampler(self):
+ self.global_options = {"seed": 0, "max_bsz": self._max_bsz}
+
+ self.prior_options = {
+ "sm": self.t2i_sampler._prior_sm,
+ "cf_scale": self.t2i_sampler._prior_cf_scale,
+ }
+ self.decoder_options = {
+ "sm": self.t2i_sampler._decoder_sm,
+ "cf_scale": self.t2i_sampler._decoder_cf_scale,
+ }
+ self.sr_64_256_options = {
+ "sm": self.t2i_sampler._sr_sm,
+ }
+
+ def make_global_options(self):
+ gr.Markdown("Global Options")
+ with gr.Row():
+ return [
+ gr.Slider(
+ label="seed",
+ value=self.global_options["seed"],
+ minimum=np.iinfo(np.uint32).min,
+ maximum=np.iinfo(np.uint32).max,
+ step=1,
+ ),
+ gr.Slider(
+ label="maximum batch size",
+ value=self.global_options["max_bsz"],
+ minimum=1,
+ maximum=4,
+ step=1,
+ ),
+ ]
+
+ def make_prior_options(self):
+ gr.Markdown("Prior Options")
+ return [
+ gr.Textbox(
+ label="sampling method",
+ value=self.prior_options["sm"],
+ ),
+ gr.Slider(
+ label="Classifier-free guidance scales",
+ value=self.prior_options["cf_scale"],
+ minimum=0.1,
+ maximum=24,
+ ),
+ ]
+
+ def make_decoder_options(self):
+ gr.Markdown("Decoder Options")
+ with gr.Row():
+ return [
+ gr.Textbox(
+ label="sampling method",
+ value=self.decoder_options["sm"],
+ ),
+ gr.Slider(
+ label="Classifier-free guidance scales",
+ value=self.decoder_options["cf_scale"],
+ minimum=0.1,
+ maximum=24,
+ ),
+ ]
+
+ def make_sr_64_256_options(self):
+ return [gr.Variable(self.sr_64_256_options["sm"])]
+
+ def make_basic_options(self):
+ self.global_options_gr = self.make_global_options()
+ self.prior_optios_gr = self.make_prior_options()
+ self.decoder_options_gr = self.make_decoder_options()
+ self.sr_64_256_options_gr = self.make_sr_64_256_options()
+
+ def seed(self, seed):
+ set_seed(seed)
+
+ def _sample(self, output_generator):
+ for k, out in enumerate(output_generator):
+ self.result_queue.put((out, False))
+ self.result_queue.put((None, True))
+
+ def t2i_sample(
+ self,
+ text_input,
+ prior_sm,
+ prior_cf_scale,
+ decoder_sm,
+ decoder_cf_scale,
+ sr_sm,
+ seed,
+ max_bsz,
+ ):
+ t0 = time.time()
+ assert hasattr(self.t2i_sampler, "_prior_sm")
+ assert hasattr(self.t2i_sampler, "_prior_cf_scale")
+ assert hasattr(self.t2i_sampler, "_decoder_sm")
+ assert hasattr(self.t2i_sampler, "_decoder_cf_scale")
+ assert hasattr(self.t2i_sampler, "_sr_sm")
+
+ print("-" * 100)
+ print(f"text_input: {text_input}")
+ print(f"prior_sm: {prior_sm}")
+ print(f"prior_cf_scale: {prior_cf_scale}")
+ print(f"decoder_sm: {decoder_sm}")
+ print(f"decoder_cf_scale: {decoder_cf_scale}")
+ print(f"sr_sm: {sr_sm}")
+ print(f"seed: {seed}")
+ print(f"max_bsz: {max_bsz}")
+
+ self.t2i_sampler._prior_sm = prior_sm
+ self.t2i_sampler._prior_cf_scale = prior_cf_scale
+
+ self.t2i_sampler._decoder_sm = decoder_sm
+ self.t2i_sampler._decoder_cf_scale = decoder_cf_scale
+
+ self.t2i_sampler._sr_sm = sr_sm
+
+ self.seed(seed)
+
+ output_generator = self.t2i_sampler(
+ prompt=text_input,
+ bsz=max_bsz,
+ progressive_mode=self._progressive,
+ )
+
+ thread = threading.Thread(target=self._sample, args=(output_generator,))
+ thread.start()
+ done = False
+
+ while not done:
+ if self.result_queue.empty():
+ time.sleep(0.1)
+ else:
+ while not self.result_queue.empty():
+                    _out, done = self.result_queue.get(0)  # non-blocking get; draining the queue keeps only the latest result to display
+ if not done:
+ out = _out
+ images, concat_image = tensor_to_images(out, (256, 256))
+ yield (text_input, images), concat_image
+
+ thread.join()
+ yield (text_input, images), concat_image
+
+ t1 = time.time()
+ execution_time = t1 - t0
+ logging.info(f"Generation done. {text_input} -- {execution_time:.6f}secs")
+ print("-" * 100)
+
+ def i2i_sample(
+ self,
+ image_input,
+ decoder_sm,
+ decoder_cf_scale,
+ sr_sm,
+ seed,
+ max_bsz,
+ ):
+ t0 = time.time()
+ assert hasattr(self.i2i_sampler, "_decoder_sm")
+ assert hasattr(self.i2i_sampler, "_decoder_cf_scale")
+ assert hasattr(self.i2i_sampler, "_sr_sm")
+
+ print("-" * 100)
+ print(f"decoder_sm: {decoder_sm}")
+ print(f"decoder_cf_scale: {decoder_cf_scale}")
+ print(f"sr_sm: {sr_sm}")
+ print(f"seed: {seed}")
+ print(f"max_bsz: {max_bsz}")
+
+ self.i2i_sampler._decoder_sm = decoder_sm
+ self.i2i_sampler._decoder_cf_scale = decoder_cf_scale
+
+ self.i2i_sampler._sr_sm = sr_sm
+
+ self.seed(seed)
+
+ output_generator = self.i2i_sampler(
+ image=image_input,
+ bsz=max_bsz,
+ progressive_mode=self._progressive,
+ )
+
+ thread = threading.Thread(target=self._sample, args=(output_generator,))
+ thread.start()
+ done = False
+
+ while not done:
+ if self.result_queue.empty():
+ time.sleep(0.1)
+ else:
+ while not self.result_queue.empty():
+                    _out, done = self.result_queue.get(0)  # non-blocking get; draining the queue keeps only the latest result to display
+ if not done:
+ out = _out
+ images, concat_image = tensor_to_images(out, (256, 256))
+ yield ("", images), concat_image
+
+ thread.join()
+ yield ("", images), concat_image
+
+ t1 = time.time()
+ execution_time = t1 - t0
+ logging.info(f"Variation done. {execution_time:.6f}secs")
+ print("-" * 100)
+
+
+class ImageSelecter:
+ @classmethod
+ def make_basic_ui(cls, max_bsz):
+ with gr.Box():
+ i2i_select_idx = gr.Radio(
+ choices=[str(i) for i in range(0, max_bsz)],
+ value="0",
+ label="Image index",
+ )
+ i2i_select_button = gr.Button(
+ "Select for Image Variation", variant="primary"
+ )
+ return {
+ "i2i_select_idx": i2i_select_idx,
+ "i2i_select_button": i2i_select_button,
+ }
+
+ @classmethod
+ def select_fn(cls, stash, idx):
+ if stash is not None:
+ return Image.fromarray(stash[1][int(idx)].copy())
+
+ @classmethod
+ def setup_button_click(
+ cls,
+ selector_ui,
+ stash,
+ i2i_input_images,
+ ):
+ selector_ui["i2i_select_button"].click(
+ fn=cls.select_fn,
+ inputs=[stash, selector_ui["i2i_select_idx"]],
+ outputs=[i2i_input_images],
+ )
diff --git a/demo/product_demo.py b/demo/product_demo.py
new file mode 100644
index 0000000..6e82a79
--- /dev/null
+++ b/demo/product_demo.py
@@ -0,0 +1,125 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import argparse
+import logging
+import gradio as gr
+import os
+import sys
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from karlo import __version__ as karlo_ver
+from demo.components import GradioSampler, ImageSelecter
+
+
+class GradioDemo:
+ def __init__(
+ self,
+ root_dir: str,
+ max_bsz: int,
+ progressive: str,
+ sampling_type: str,
+ ):
+ sampler = GradioSampler(
+ root_dir=root_dir,
+ max_bsz=max_bsz,
+ progressive=progressive,
+ sampling_type=sampling_type,
+ )
+
+ demo = gr.Blocks()
+ with demo:
+ gr.Markdown(f"# Karlo Demo {karlo_ver}")
+ with gr.Box():
+ gr.Markdown("## Generate 64px images + Upscaling to 256px")
+
+ with gr.Tabs():
+ with gr.TabItem("Image Generation"):
+ t2i_text_input = gr.Textbox(
+ lines=1,
+ placeholder="Type text prompt...",
+ label="Text prompts",
+ )
+ t2i_button = gr.Button("Generate", variant="primary")
+ with gr.TabItem("Image Variation"):
+ i2i_img_input = gr.Image(label="Image input", type="pil")
+ i2i_button = gr.Button("Generate", variant="primary")
+
+ with gr.Box():
+ outputs = gr.Image(label="Generated", type="pil")
+ stash = gr.Variable()
+ with gr.Row():
+ selector_ui = ImageSelecter.make_basic_ui(max_bsz=max_bsz)
+
+ with gr.Box():
+ with gr.Accordion(label="Advanced Options", open=False):
+ sampler.make_basic_options()
+
+ with gr.Box():
+ with gr.Accordion(label="Checkpoint Information", open=False):
+ gr.Markdown(sampler.ckpt_info)
+
+ t2i_button.click(
+ fn=sampler.t2i_sample,
+ inputs=[t2i_text_input]
+ + sampler.prior_optios_gr
+ + sampler.decoder_options_gr
+ + sampler.sr_64_256_options_gr
+ + sampler.global_options_gr,
+ outputs=[stash, outputs],
+ )
+ i2i_button.click(
+ fn=sampler.i2i_sample,
+ inputs=[i2i_img_input]
+ + sampler.decoder_options_gr
+ + sampler.sr_64_256_options_gr
+ + sampler.global_options_gr,
+ outputs=[stash, outputs],
+ )
+
+ ImageSelecter.setup_button_click(selector_ui, stash, i2i_img_input)
+
+ demo.queue()
+ self.demo = demo
+
+
+def default_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--root-dir", type=str, default=None)
+ parser.add_argument("--max_bsz", type=int, default=4)
+ parser.add_argument(
+ "--progressive", type=str, default="loop", choices=("loop", "stage", "final")
+ )
+ parser.add_argument("--host", type=str, default="localhost")
+ parser.add_argument("--port", type=int, default=6006)
+
+ parser.add_argument(
+ "--sampling-type",
+ type=str,
+ default="fast",
+ choices=("fast", "default"),
+ )
+
+ return parser
+
+
+if __name__ == "__main__":
+ parser = default_parser()
+ args = parser.parse_args()
+ logging.getLogger().setLevel(logging.INFO)
+
+ assert (
+ args.root_dir is not None
+ ), "--root-dir argument should be specified to load the pretrained ckpt"
+
+ """Making Gradio"""
+ gradio_demo = GradioDemo(
+ root_dir=args.root_dir,
+ max_bsz=args.max_bsz,
+ progressive=args.progressive,
+ sampling_type=args.sampling_type,
+ )
+ gradio_demo.demo.launch(server_name=args.host, server_port=args.port)
diff --git a/example.py b/example.py
new file mode 100644
index 0000000..531ce8f
--- /dev/null
+++ b/example.py
@@ -0,0 +1,86 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import os
+import argparse
+import logging
+import time
+from datetime import datetime
+
+import torch
+from PIL import Image
+
+from karlo.sampler.t2i import T2ISampler
+from karlo.utils.util import set_seed
+
+
+def default_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--root-dir", type=str, required=True, help="path for model checkpoints"
+ )
+ parser.add_argument("--max-bsz", type=int, default=1, help="#images to generate")
+ parser.add_argument(
+ "--output-dir",
+ type=str,
+ default="outputs",
+ help="output path for generated images",
+ )
+ parser.add_argument(
+ "--sampling-type",
+ type=str,
+ default="fast",
+ choices=("fast", "default"),
+ )
+ parser.add_argument(
+ "--prompt", type=str, default="A photo of a baby puppy waiting for her mom."
+ )
+ parser.add_argument("--seed", type=int, default=0)
+
+ return parser
+
+
+if __name__ == "__main__":
+ parser = default_parser()
+ args = parser.parse_args()
+
+ set_seed(args.seed)
+ logging.getLogger().setLevel(logging.INFO)
+
+ save_dir = os.path.join(args.output_dir, datetime.now().strftime("%d%m%Y_%H%M%S"))
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ model = T2ISampler.from_pretrained(
+ root_dir=args.root_dir,
+ clip_model_path="ViT-L-14.pt",
+ clip_stat_path="ViT-L-14_stats.th",
+ sampling_type=args.sampling_type,
+ )
+
+ for i in range(5):
+ t1 = time.time()
+
+ images = iter(
+ model(
+ prompt=args.prompt,
+ bsz=args.max_bsz,
+ progressive_mode="final",
+ )
+ ).__next__()
+
+ # NCHW, [0, 1], float32 -> NHWC, [0, 255], uint8
+ images = (
+ torch.permute(images * 255.0, [0, 2, 3, 1]).type(torch.uint8).cpu().numpy()
+ )
+
+ t2 = time.time()
+ execution_time = t2 - t1
+ logging.info(f"Iteration {i} -- {execution_time:.6f}secs")
+
+ # Select the first one
+ image = Image.fromarray(images[0])
+ image_name = "_".join(args.prompt.split(" "))
+ image.save(f"{save_dir}/{image_name}_{i:02d}.jpg")
diff --git a/karlo/__init__.py b/karlo/__init__.py
new file mode 100644
index 0000000..d47bcf7
--- /dev/null
+++ b/karlo/__init__.py
@@ -0,0 +1 @@
+__version__ = "1.0.alpha"
diff --git a/karlo/models/__init__.py b/karlo/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/karlo/models/clip.py b/karlo/models/clip.py
new file mode 100644
index 0000000..961d815
--- /dev/null
+++ b/karlo/models/clip.py
@@ -0,0 +1,182 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+# ------------------------------------------------------------------------------------
+# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
+# ------------------------------------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import clip
+
+from clip.model import CLIP, convert_weights
+from clip.simple_tokenizer import SimpleTokenizer, default_bpe
+
+
+"""===== Monkey-Patching original CLIP for JIT compile ====="""
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = F.layer_norm(
+ x.type(torch.float32),
+ self.normalized_shape,
+ self.weight,
+ self.bias,
+ self.eps,
+ )
+ return ret.type(orig_type)
+
+
+clip.model.LayerNorm = LayerNorm
+delattr(clip.model.CLIP, "forward")
+
+"""===== End of Monkey-Patching ====="""
+
+
+class CustomizedCLIP(CLIP):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ @torch.jit.export
+ def encode_image(self, image):
+ return self.visual(image)
+
+ @torch.jit.export
+ def encode_text(self, text):
+ # re-define this function to return unpooled text features
+
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ x_seq = x
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x_out, x_seq
+
+ @torch.jit.ignore
+ def forward(self, image, text):
+ super().forward(image, text)
+
+ @classmethod
+ def load_from_checkpoint(cls, ckpt_path: str):
+ state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()
+
+ vit = "visual.proj" in state_dict
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len(
+ [
+ k
+ for k in state_dict.keys()
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
+ ]
+ )
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round(
+ (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
+ )
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+ if k.startswith(f"visual.layer{b}")
+ )
+ )
+ for b in [1, 2, 3, 4]
+ ]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round(
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
+ )
+ vision_patch_size = None
+ assert (
+ output_width**2 + 1
+ == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ )
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+ if k.startswith("transformer.resblocks")
+ )
+ )
+
+ model = cls(
+ embed_dim,
+ image_resolution,
+ vision_layers,
+ vision_width,
+ vision_patch_size,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ model.eval()
+ model.float()
+ return model
+
+
+class CustomizedTokenizer(SimpleTokenizer):
+ def __init__(self):
+ super().__init__(bpe_path=default_bpe())
+
+ self.sot_token = self.encoder["<|startoftext|>"]
+ self.eot_token = self.encoder["<|endoftext|>"]
+
+ def padded_tokens_and_mask(self, texts, text_ctx):
+ assert isinstance(texts, list) and all(
+ isinstance(elem, str) for elem in texts
+ ), "texts should be a list of strings"
+
+ all_tokens = [
+ [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
+ ]
+
+ mask = [
+ [True] * min(text_ctx, len(tokens))
+ + [False] * max(text_ctx - len(tokens), 0)
+ for tokens in all_tokens
+ ]
+ mask = torch.tensor(mask, dtype=torch.bool)
+ result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > text_ctx:
+ tokens = tokens[:text_ctx]
+ tokens[-1] = self.eot_token
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ return result, mask
diff --git a/karlo/models/decoder_model.py b/karlo/models/decoder_model.py
new file mode 100644
index 0000000..1654182
--- /dev/null
+++ b/karlo/models/decoder_model.py
@@ -0,0 +1,186 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import copy
+import torch
+
+from ..modules import create_gaussian_diffusion
+from ..modules.unet import PLMImUNet
+
+
+class Text2ImProgressiveModel(torch.nn.Module):
+ def __init__(
+ self,
+ config,
+ tokenizer,
+ ):
+ super().__init__()
+
+ self._conf = config
+ self._model_conf = config.model.hparams
+ self._diffusion_kwargs = dict(
+ steps=config.diffusion.steps,
+ learn_sigma=config.diffusion.learn_sigma,
+ sigma_small=config.diffusion.sigma_small,
+ noise_schedule=config.diffusion.noise_schedule,
+ use_kl=config.diffusion.use_kl,
+ predict_xstart=config.diffusion.predict_xstart,
+ rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
+ timestep_respacing=config.diffusion.timestep_respacing,
+ )
+ self._tokenizer = tokenizer
+
+ self.model = self.create_plm_dec_model()
+
+ cf_token, cf_mask = self.set_cf_text_tensor()
+ self.register_buffer("cf_token", cf_token, persistent=False)
+ self.register_buffer("cf_mask", cf_mask, persistent=False)
+
+ @classmethod
+ def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+ model = cls(config, tokenizer)
+ model.load_state_dict(ckpt, strict=strict)
+ return model
+
+ def create_plm_dec_model(self):
+ image_size = self._model_conf.image_size
+ if self._model_conf.channel_mult == "":
+ if image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+ else:
+ channel_mult = tuple(
+ int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
+ )
+ assert 2 ** (len(channel_mult) + 2) == image_size
+
+ attention_ds = []
+ for res in self._model_conf.attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return PLMImUNet(
+ text_ctx=self._model_conf.text_ctx,
+ xf_width=self._model_conf.xf_width,
+ in_channels=3,
+ model_channels=self._model_conf.num_channels,
+ out_channels=6 if self._model_conf.learn_sigma else 3,
+ num_res_blocks=self._model_conf.num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=self._model_conf.dropout,
+ channel_mult=channel_mult,
+ num_heads=self._model_conf.num_heads,
+ num_head_channels=self._model_conf.num_head_channels,
+ num_heads_upsample=self._model_conf.num_heads_upsample,
+ use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
+ resblock_updown=self._model_conf.resblock_updown,
+ clip_dim=self._model_conf.clip_dim,
+ clip_emb_mult=self._model_conf.clip_emb_mult,
+ clip_emb_type=self._model_conf.clip_emb_type,
+ clip_emb_drop=self._model_conf.clip_emb_drop,
+ )
+
+ def set_cf_text_tensor(self):
+ return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
+
+ def get_sample_fn(self, timestep_respacing):
+ use_ddim = timestep_respacing.startswith(("ddim", "fast"))
+
+ diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
+ diffusion_kwargs.update(timestep_respacing=timestep_respacing)
+ diffusion = create_gaussian_diffusion(**diffusion_kwargs)
+ sample_fn = (
+ diffusion.ddim_sample_loop_progressive
+ if use_ddim
+ else diffusion.p_sample_loop_progressive
+ )
+
+ return sample_fn
+
+ def forward(
+ self,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat=None,
+ cf_guidance_scales=None,
+ timestep_respacing=None,
+ ):
+ # cfg should be enabled in inference
+ assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
+ assert img_feat is not None
+
+ bsz = txt_feat.shape[0]
+ img_sz = self._model_conf.image_size
+
+ def guided_model_fn(x_t, ts, **kwargs):
+ half = x_t[: len(x_t) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.model(combined, ts, **kwargs)
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
+ cond_eps - uncond_eps
+ )
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ cf_feat = self.model.cf_param.unsqueeze(0)
+ cf_feat = cf_feat.expand(bsz // 2, -1)
+ feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)
+
+ cond = {
+ "y": feat,
+ "txt_feat": txt_feat,
+ "txt_feat_seq": txt_feat_seq,
+ "mask": mask,
+ }
+ sample_fn = self.get_sample_fn(timestep_respacing)
+ sample_outputs = sample_fn(
+ guided_model_fn,
+ (bsz, 3, img_sz, img_sz),
+ noise=None,
+ device=txt_feat.device,
+ clip_denoised=True,
+ model_kwargs=cond,
+ )
+
+ for out in sample_outputs:
+ sample = out["sample"]
+ yield sample if cf_guidance_scales is None else sample[
+ : sample.shape[0] // 2
+ ]
+
+
+class Text2ImModel(Text2ImProgressiveModel):
+ def forward(
+ self,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat=None,
+ cf_guidance_scales=None,
+ timestep_respacing=None,
+ ):
+ last_out = None
+ for out in super().forward(
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ cf_guidance_scales,
+ timestep_respacing,
+ ):
+ last_out = out
+ return last_out
diff --git a/karlo/models/prior_model.py b/karlo/models/prior_model.py
new file mode 100644
index 0000000..9f15981
--- /dev/null
+++ b/karlo/models/prior_model.py
@@ -0,0 +1,131 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import copy
+import torch
+
+from ..modules import create_gaussian_diffusion
+from ..modules.xf import PriorTransformer
+
+
+class PriorDiffusionModel(torch.nn.Module):
+ def __init__(self, config, tokenizer, clip_mean, clip_std):
+ super().__init__()
+
+ self._conf = config
+ self._model_conf = config.model.hparams
+ self._diffusion_kwargs = dict(
+ steps=config.diffusion.steps,
+ learn_sigma=config.diffusion.learn_sigma,
+ sigma_small=config.diffusion.sigma_small,
+ noise_schedule=config.diffusion.noise_schedule,
+ use_kl=config.diffusion.use_kl,
+ predict_xstart=config.diffusion.predict_xstart,
+ rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
+ timestep_respacing=config.diffusion.timestep_respacing,
+ )
+ self._tokenizer = tokenizer
+
+ self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
+ self.register_buffer("clip_std", clip_std[None, :], persistent=False)
+
+ causal_mask = self.get_causal_mask()
+ self.register_buffer("causal_mask", causal_mask, persistent=False)
+
+ self.model = PriorTransformer(
+ text_ctx=self._model_conf.text_ctx,
+ xf_width=self._model_conf.xf_width,
+ xf_layers=self._model_conf.xf_layers,
+ xf_heads=self._model_conf.xf_heads,
+ xf_final_ln=self._model_conf.xf_final_ln,
+ xf_padding=self._model_conf.xf_padding,
+ clip_dim=self._model_conf.clip_dim,
+ clip_xf_width=self._model_conf.clip_xf_width,
+ )
+
+ cf_token, cf_mask = self.set_cf_text_tensor()
+ self.register_buffer("cf_token", cf_token, persistent=False)
+ self.register_buffer("cf_mask", cf_mask, persistent=False)
+
+ @classmethod
+ def load_from_checkpoint(
+ cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
+ ):
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+ model = cls(config, tokenizer, clip_mean, clip_std)
+ model.load_state_dict(ckpt, strict=strict)
+ return model
+
+ def set_cf_text_tensor(self):
+ return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
+
+ def get_sample_fn(self, timestep_respacing):
+ use_ddim = timestep_respacing.startswith(("ddim", "fast"))
+
+ diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
+ diffusion_kwargs.update(timestep_respacing=timestep_respacing)
+ diffusion = create_gaussian_diffusion(**diffusion_kwargs)
+ sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
+
+ return sample_fn
+
+ def get_causal_mask(self):
+ seq_len = self._model_conf.text_ctx + 4
+ mask = torch.empty(seq_len, seq_len)
+ mask.fill_(float("-inf"))
+ mask.triu_(1)
+ mask = mask[None, ...]
+ return mask
+
+ def forward(
+ self,
+ txt_feat,
+ txt_feat_seq,
+ mask,
+ cf_guidance_scales=None,
+ timestep_respacing=None,
+ denoised_fn=True,
+ ):
+ # cfg should be enabled in inference
+ assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
+
+ bsz_ = txt_feat.shape[0]
+ bsz = bsz_ // 2
+
+ def guided_model_fn(x_t, ts, **kwargs):
+ half = x_t[: len(x_t) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.model(combined, ts, **kwargs)
+ eps, rest = (
+ model_out[:, : int(x_t.shape[1])],
+ model_out[:, int(x_t.shape[1]) :],
+ )
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
+ cond_eps - uncond_eps
+ )
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ cond = {
+ "text_emb": txt_feat,
+ "text_enc": txt_feat_seq,
+ "mask": mask,
+ "causal_mask": self.causal_mask,
+ }
+ sample_fn = self.get_sample_fn(timestep_respacing)
+ sample = sample_fn(
+ guided_model_fn,
+ (bsz_, self.model.clip_dim),
+ noise=None,
+ device=txt_feat.device,
+ clip_denoised=False,
+ denoised_fn=lambda x: torch.clamp(x, -10, 10),
+ model_kwargs=cond,
+ )
+ sample = (sample * self.clip_std) + self.clip_mean
+
+ return sample[:bsz]
diff --git a/karlo/models/sr_256_1k.py b/karlo/models/sr_256_1k.py
new file mode 100644
index 0000000..690b2a5
--- /dev/null
+++ b/karlo/models/sr_256_1k.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+from .sr_64_256 import ImprovedSupRes64to256ProgressiveModel
+
+
+class SupRes256to1kProgressive(ImprovedSupRes64to256ProgressiveModel):
+ pass # no difference currently
diff --git a/karlo/models/sr_64_256.py b/karlo/models/sr_64_256.py
new file mode 100644
index 0000000..c300ad7
--- /dev/null
+++ b/karlo/models/sr_64_256.py
@@ -0,0 +1,88 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import copy
+import torch
+
+from ..modules.unet import SuperResUNetModel
+from ..modules import create_gaussian_diffusion
+
+
+class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
+ """
+ The ImprovedSR model fine-tunes the pretrained DDPM-based SR model using adversarial and perceptual losses.
+ Specifically, the low-resolution sample is iteratively recovered over 6 steps with the frozen pretrained SR model.
+ In one additional final step, a separately fine-tuned model recovers high-frequency details.
+ This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self._config = config
+ self._diffusion_kwargs = dict(
+ steps=config.diffusion.steps,
+ learn_sigma=config.diffusion.learn_sigma,
+ sigma_small=config.diffusion.sigma_small,
+ noise_schedule=config.diffusion.noise_schedule,
+ use_kl=config.diffusion.use_kl,
+ predict_xstart=config.diffusion.predict_xstart,
+ rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
+ )
+
+ self.model_first_steps = SuperResUNetModel(
+ in_channels=3, # auto-changed to 6 inside the model
+ model_channels=config.model.hparams.channels,
+ out_channels=3,
+ num_res_blocks=config.model.hparams.depth,
+ attention_resolutions=(), # no attention
+ dropout=config.model.hparams.dropout,
+ channel_mult=config.model.hparams.channels_multiple,
+ resblock_updown=True,
+ use_middle_attention=False,
+ )
+ self.model_last_step = SuperResUNetModel(
+ in_channels=3, # auto-changed to 6 inside the model
+ model_channels=config.model.hparams.channels,
+ out_channels=3,
+ num_res_blocks=config.model.hparams.depth,
+ attention_resolutions=(), # no attention
+ dropout=config.model.hparams.dropout,
+ channel_mult=config.model.hparams.channels_multiple,
+ resblock_updown=True,
+ use_middle_attention=False,
+ )
+
+ @classmethod
+ def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+ model = cls(config)
+ model.load_state_dict(ckpt, strict=strict)
+ return model
+
+ def get_sample_fn(self, timestep_respacing):
+ diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
+ diffusion_kwargs.update(timestep_respacing=timestep_respacing)
+ diffusion = create_gaussian_diffusion(**diffusion_kwargs)
+ return diffusion.p_sample_loop_progressive_for_improved_sr
+
+ def forward(self, low_res, timestep_respacing="7", **kwargs):
+ assert (
+ timestep_respacing == "7"
+ ), "a different respacing may work, but it is not guaranteed"
+
+ sample_fn = self.get_sample_fn(timestep_respacing)
+ sample_outputs = sample_fn(
+ self.model_first_steps,
+ self.model_last_step,
+ shape=low_res.shape,
+ clip_denoised=True,
+ model_kwargs=dict(low_res=low_res),
+ **kwargs,
+ )
+ for x in sample_outputs:
+ sample = x["sample"]
+ yield sample
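+
+
+# Usage sketch (hypothetical variable names; output shapes follow the caller's low_res):
+# forward() is a generator that yields one denoised batch per reverse step; the first
+# 6 steps run the frozen model_first_steps and the final step runs the fine-tuned
+# model_last_step, so callers typically keep only the last yield.
+#
+#   sr = ImprovedSupRes64to256ProgressiveModel.load_from_checkpoint(config, ckpt_path)
+#   for out in sr(low_res, timestep_respacing="7"):
+#       sample = out  # the last assignment is the final high-fidelity sample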
diff --git a/karlo/modules/__init__.py b/karlo/modules/__init__.py
new file mode 100644
index 0000000..11d4358
--- /dev/null
+++ b/karlo/modules/__init__.py
@@ -0,0 +1,49 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+
+from .diffusion import gaussian_diffusion as gd
+from .diffusion.respace import (
+ SpacedDiffusion,
+ space_timesteps,
+)
+
+
+def create_gaussian_diffusion(
+ steps,
+ learn_sigma,
+ sigma_small,
+ noise_schedule,
+ use_kl,
+ predict_xstart,
+ rescale_learned_sigmas,
+ timestep_respacing,
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if not timestep_respacing:
+ timestep_respacing = [steps]
+
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type,
+ )
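+
+
+# Minimal usage sketch (illustrative hyper-parameters, not the shipped configs):
+#
+#   diffusion = create_gaussian_diffusion(
+#       steps=1000,
+#       learn_sigma=True,
+#       sigma_small=False,
+#       noise_schedule="squaredcos_cap_v2",
+#       use_kl=False,
+#       predict_xstart=False,
+#       rescale_learned_sigmas=True,
+#       timestep_respacing="ddim25",
+#   )
+#   # -> a SpacedDiffusion that retains 25 of the original 1000 timesteps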
diff --git a/karlo/modules/diffusion/gaussian_diffusion.py b/karlo/modules/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000..6a111aa
--- /dev/null
+++ b/karlo/modules/diffusion/gaussian_diffusion.py
@@ -0,0 +1,828 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(
+ beta_start, beta_end, warmup_time, dtype=np.float64
+ )
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
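+
+# Rough numbers for intuition (cosine schedule, 1000 steps): the earliest beta is
+# on the order of 4e-5, betas grow with t, and the last few are clipped at
+# max_beta=0.999.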
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion(th.nn.Module):
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ ):
+ super(GaussianDiffusion, self).__init__()
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+ alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
+ assert alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
+ log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
+ sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
+ sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ posterior_log_variance_clipped = np.log(
+ np.append(posterior_variance[1], posterior_variance[1:])
+ )
+ posterior_mean_coef1 = (
+ betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+ )
+ posterior_mean_coef2 = (
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+ )
+
+ self.register_buffer("betas", th.from_numpy(betas), persistent=False)
+ self.register_buffer(
+ "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
+ )
+ self.register_buffer(
+ "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
+ )
+ self.register_buffer(
+ "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
+ )
+
+ self.register_buffer(
+ "sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ th.from_numpy(sqrt_one_minus_alphas_cumprod),
+ persistent=False,
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod",
+ th.from_numpy(log_one_minus_alphas_cumprod),
+ persistent=False,
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod",
+ th.from_numpy(sqrt_recip_alphas_cumprod),
+ persistent=False,
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ th.from_numpy(sqrt_recipm1_alphas_cumprod),
+ persistent=False,
+ )
+
+ self.register_buffer(
+ "posterior_variance", th.from_numpy(posterior_variance), persistent=False
+ )
+ self.register_buffer(
+ "posterior_log_variance_clipped",
+ th.from_numpy(posterior_log_variance_clipped),
+ persistent=False,
+ )
+ self.register_buffer(
+ "posterior_mean_coef1",
+ th.from_numpy(posterior_mean_coef1),
+ persistent=False,
+ )
+ self.register_buffer(
+ "posterior_mean_coef2",
+ th.from_numpy(posterior_mean_coef2),
+ persistent=False,
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ **ignore_kwargs,
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
+ # model_var_values lies in [-1, 1] and interpolates (in log space) between [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ th.cat([self.posterior_variance[1][None], self.betas[1:]]),
+ th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ )
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = (
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ )
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t
+ )
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_t.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(
+ cond_fn, out, x, t, model_kwargs=model_kwargs
+ )
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
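+
+ # Usage sketch (shapes illustrative): for a 64x64 RGB decoder one might call
+ #   sample = diffusion.p_sample_loop(model, (batch_size, 3, 64, 64), device=device)
+ # and receive only the final denoised batch; p_sample_loop_progressive below
+ # additionally yields every intermediate step.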
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for idx, i in enumerate(indices):
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def p_sample_loop_progressive_for_improved_sr(
+ self,
+ model,
+ model_aux,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ A modified version of p_sample_loop_progressive for sampling from the improved SR model.
+ """
+
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for idx, i in enumerate(indices):
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model_aux if len(indices) - 1 == idx else model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ + th.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D tensor for a batch of indices.
+ :param arr: the 1-D tensor (a registered buffer) to index into.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = arr.to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
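+
+
+# Broadcasting sketch: with arr of shape [T], timesteps of shape [N], and
+# broadcast_shape (N, C, H, W), the gathered values are reshaped to [N, 1, 1, 1]
+# and broadcast (via the added zero tensor) to a full [N, C, H, W] tensor.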
diff --git a/karlo/modules/diffusion/respace.py b/karlo/modules/diffusion/respace.py
new file mode 100644
index 0000000..71fbd62
--- /dev/null
+++ b/karlo/modules/diffusion/respace.py
@@ -0,0 +1,111 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {desired_count} steps with an integer stride"
+ )
+ elif section_counts == "fast27":
+ steps = space_timesteps(num_timesteps, "10,10,3,2,2")
+ # Help reduce DDIM artifacts from noisiest timesteps.
+ steps.remove(num_timesteps - 1)
+ steps.add(num_timesteps - 3)
+ return steps
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
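+
+# Example (illustrative): space_timesteps(1000, "ddim25") keeps the 25 evenly
+# strided original steps {0, 40, 80, ..., 960}, while space_timesteps(1000, "25")
+# spreads 25 steps over the single section [0, 1000).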
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ kwargs["betas"] = th.tensor(new_betas).numpy()
+ super().__init__(**kwargs)
+ self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
+
+ def p_mean_variance(self, model, *args, **kwargs):
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ def wrapped(x, ts, **kwargs):
+ return model(
+ x, self.timestep_map[ts].to(device=ts.device, dtype=ts.dtype), **kwargs
+ )
+
+ return wrapped
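+
+ # Note: the wrapper re-maps a respaced index ts in [0, len(use_timesteps)) back to
+ # its original training timestep via timestep_map, e.g. with 1000 steps respaced
+ # to "ddim25" a spaced index of 1 is fed to the model as original timestep 40.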
diff --git a/karlo/modules/nn.py b/karlo/modules/nn.py
new file mode 100644
index 0000000..2eef3f5
--- /dev/null
+++ b/karlo/modules/nn.py
@@ -0,0 +1,114 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+import math
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GroupNorm32(nn.GroupNorm):
+ def __init__(self, num_groups, num_channels, swish, eps=1e-5):
+ super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+ self.swish = swish
+
+ def forward(self, x):
+ y = super().forward(x.float()).to(x.dtype)
+ if self.swish == 1.0:
+ y = F.silu(y)
+ elif self.swish:
+ y = y * F.sigmoid(y * float(self.swish))
+ return y
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def normalization(channels, swish=0.0):
+ """
+ Make a standard normalization layer, with an optional swish activation.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(
+ -math.log(max_period)
+ * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
+ / half
+ )
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
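+
+# Shape sketch: timestep_embedding(th.arange(8), dim=320) returns an [8, 320] tensor
+# whose first 160 columns are cosines and last 160 are sines of geometrically spaced
+# frequencies; a zero column is appended when dim is odd.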
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
diff --git a/karlo/modules/resample.py b/karlo/modules/resample.py
new file mode 100644
index 0000000..485421a
--- /dev/null
+++ b/karlo/modules/resample.py
@@ -0,0 +1,68 @@
+# ------------------------------------------------------------------------------------
+# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+from abc import abstractmethod
+
+import torch as th
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(th.nn.Module):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to create the tensors on.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / th.sum(w)
+ indices = p.multinomial(batch_size, replacement=True)
+ weights = 1 / (len(p) * p[indices])
+ return indices, weights
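+
+ # With UniformSampler below, weights() is all ones, so p is uniform over the
+ # timesteps and every returned importance weight equals 1.0.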
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ super(UniformSampler, self).__init__()
+ self.diffusion = diffusion
+ self.register_buffer(
+ "_weights", th.ones([diffusion.num_timesteps]), persistent=False
+ )
+
+ def weights(self):
+ return self._weights
diff --git a/karlo/modules/unet.py b/karlo/modules/unet.py
new file mode 100644
index 0000000..245e893
--- /dev/null
+++ b/karlo/modules/unet.py
@@ -0,0 +1,791 @@
+# ------------------------------------------------------------------------------------
+# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+import math
+from abc import abstractmethod
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nn import (
+ avg_pool_nd,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from .xf import LayerNorm
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, encoder_out=None, mask=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, AttentionBlock):
+ x = layer(x, encoder_out, mask=mask)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels, swish=1.0),
+ nn.Identity(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(
+ self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
+ ),
+ nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class ResBlockNoTimeEmbedding(nn.Module):
+ """
+ A residual block without a timestep embedding.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+
+ self.in_layers = nn.Sequential(
+ normalization(channels, swish=1.0),
+ nn.Identity(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels, swish=1.0),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb=None):
+ """
+ Apply the block to a Tensor, NOT conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert emb is None
+
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ encoder_channels=None,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels, swish=0.0)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ self.attention = QKVAttention(self.num_heads)
+
+ if encoder_channels is not None:
+ self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x, encoder_out=None, mask=None):
+ b, c, *spatial = x.shape
+ qkv = self.qkv(self.norm(x).view(b, c, -1))
+ if encoder_out is not None:
+ encoder_out = self.encoder_kv(encoder_out)
+ h = self.attention(qkv, encoder_out, mask=mask)
+ else:
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return x + h.reshape(b, c, *spatial)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, encoder_kv=None, mask=None):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ if encoder_kv is not None:
+ assert encoder_kv.shape[1] == self.n_heads * ch * 2
+ ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+ k = th.cat([ek, k], dim=-1)
+ v = th.cat([ev, v], dim=-1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
+ if mask is not None:
+ mask = F.pad(mask, (0, length), value=0.0)
+ mask = (
+ mask.unsqueeze(1)
+ .expand(-1, self.n_heads, -1)
+ .reshape(bs * self.n_heads, 1, -1)
+ )
+ weight = weight + mask
+ weight = th.softmax(weight, dim=-1)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
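+
+ # Shape sketch: qkv is [N, heads * 3 * ch, T]; after reshaping to
+ # [N * heads, 3 * ch, T] and splitting, the attention output is [N, heads * ch, T].
+ # When encoder_kv is given, its keys/values are prepended along the token axis so
+ # image positions can also attend to the text-encoder sequence.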
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ clip_dim=None,
+ use_checkpoint=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ use_middle_attention=True,
+ resblock_updown=False,
+ encoder_channels=None,
+ use_time_embedding=True,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.clip_dim = clip_dim
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.use_middle_attention = use_middle_attention
+ self.use_time_embedding = use_time_embedding
+
+ if self.use_time_embedding:
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.clip_dim is not None:
+ self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
+ else:
+ time_embed_dim = None
+
+ CustomResidualBlock = (
+ ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
+ )
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ *(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ ),
+ )
+ if self.use_middle_attention
+ else tuple(), # add AttentionBlock or not
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ CustomResidualBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch, swish=1.0),
+ nn.Identity(),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+
+ def forward(self, x, timesteps, y=None):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param y: an [N x clip_dim] Tensor of CLIP representations, required iff
+ `clip_dim` is specified (clip-rep-conditional model).
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.clip_dim is not None
+ ), "must specify y if and only if the model is clip-rep-conditional"
+
+ hs = []
+ if self.use_time_embedding:
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+ if self.clip_dim is not None:
+ emb = emb + self.clip_emb(y)
+ else:
+ emb = None
+
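+ # Standard U-Net pass: cache every encoder activation in `hs` and
+ # concatenate it with the decoder input at the matching resolution.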
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+
+ return self.out(h)
+
+
+class SuperResUNetModel(UNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+ """
+
+ def __init__(self, *args, **kwargs):
+ if "in_channels" in kwargs:
+ kwargs = dict(kwargs)
+ kwargs["in_channels"] = kwargs["in_channels"] * 2
+ else:
+ # Curse you, Python. Or really, just curse positional arguments :|.
+ args = list(args)
+ args[1] = args[1] * 2
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
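+ # Resize the low-resolution conditioning image to the target resolution and
+ # concatenate it along the channel axis (hence the doubled `in_channels`).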
+ _, _, new_height, new_width = x.shape
+ upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
+ x = th.cat([x, upsampled], dim=1)
+ return super().forward(x, timesteps, **kwargs)
+
+
+class PLMImUNet(UNetModel):
+ """
+ A UNetModel that conditions on pre-computed text features from a
+ pre-trained language model (the CLIP text encoder).
+
+ Expects extra forward kwargs `txt_feat` (pooled text feature),
+ `txt_feat_seq` (per-token text features) and `mask` (text token mask).
+
+ :param text_ctx: number of text tokens to expect.
+ :param xf_width: width of the text feature sequence fed to the attention blocks.
+ :param clip_emb_mult: if specified, number of extra tokens projected from the
+ CLIP image embedding `y` and prepended to the text feature sequence.
+ :param clip_emb_type: which CLIP representation is used as `y` ("image" by default).
+ :param clip_emb_drop: probability of replacing the CLIP embedding with a learned
+ null embedding (classifier-free guidance dropout).
+ """
+
+ def __init__(
+ self,
+ text_ctx,
+ xf_width,
+ *args,
+ clip_emb_mult=None,
+ clip_emb_type="image",
+ clip_emb_drop=0.0,
+ **kwargs,
+ ):
+ self.text_ctx = text_ctx
+ self.xf_width = xf_width
+ self.clip_emb_mult = clip_emb_mult
+ self.clip_emb_type = clip_emb_type
+ self.clip_emb_drop = clip_emb_drop
+
+ if not xf_width:
+ super().__init__(*args, **kwargs, encoder_channels=None)
+ else:
+ super().__init__(*args, **kwargs, encoder_channels=xf_width)
+
+ # Project text encoded feat seq from pre-trained LM
+ self.text_seq_proj = nn.Sequential(
+ nn.Linear(self.clip_dim, xf_width),
+ LayerNorm(xf_width),
+ )
+ # Project CLIP text feat
+ self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)
+
+ if self.clip_emb_mult is not None:
+ assert (
+ self.clip_dim is not None
+ ), "CLIP representation dim should be specified"
+ self.clip_tok_proj = nn.Linear(
+ self.clip_dim, self.xf_width * self.clip_emb_mult
+ )
+ if self.clip_emb_drop > 0:
+ self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))
+
+ def proc_clip_emb_drop(self, feat):
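+ # With probability `clip_emb_drop`, replace a sample's CLIP embedding with
+ # the learned null embedding `cf_param` (classifier-free guidance dropout).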
+ if self.clip_emb_drop > 0:
+ bsz, feat_dim = feat.shape
+ assert (
+ feat_dim == self.clip_dim
+ ), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}"
+ drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop
+ feat = th.where(
+ drop_idx[..., None], self.cf_param[None].type_as(feat), feat
+ )
+ return feat
+
+ def forward(
+ self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None
+ ):
+ bsz = x.shape[0]
+ hs = []
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+ if self.clip_dim is not None:
+ emb = emb + self.clip_emb(y)
+
+ xf_out = self.text_seq_proj(txt_feat_seq)
+ xf_out = xf_out.permute(0, 2, 1)
+ emb = emb + self.text_feat_proj(txt_feat)
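+ # Project the CLIP image embedding `y` into `clip_emb_mult` extra tokens,
+ # prepend them to the text feature sequence, and extend the token mask
+ # accordingly before converting it to an additive attention mask.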
+ if self.clip_emb_mult is not None:
+ xf_out = th.cat(
+ [
+ self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
+ xf_out,
+ ],
+ dim=2,
+ )
+ mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
+ mask = th.where(mask, 0.0, float("-inf"))
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, xf_out, mask=mask)
+ hs.append(h)
+ h = self.middle_block(h, emb, xf_out, mask=mask)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, xf_out, mask=mask)
+ h = self.out(h)
+
+ return h
diff --git a/karlo/modules/xf.py b/karlo/modules/xf.py
new file mode 100644
index 0000000..0f3ec1e
--- /dev/null
+++ b/karlo/modules/xf.py
@@ -0,0 +1,246 @@
+# ------------------------------------------------------------------------------------
+# Adapted from the repos below:
+# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
+# (b) CLIP ViT (https://github.com/openai/CLIP/)
+# ------------------------------------------------------------------------------------
+
+import math
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nn import timestep_embedding
+
+
+def convert_module_to_f16(param):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ param.weight.data = param.weight.data.half()
+ if param.bias is not None:
+ param.bias.data = param.bias.data.half()
+
+
+class LayerNorm(nn.LayerNorm):
+ """
+ Implementation that supports fp16 inputs but fp32 gains/biases.
+ """
+
+ def forward(self, x: th.Tensor):
+ return super().forward(x.float()).to(x.dtype)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, n_ctx, width, heads):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.heads = heads
+ self.c_qkv = nn.Linear(width, width * 3)
+ self.c_proj = nn.Linear(width, width)
+ self.attention = QKVMultiheadAttention(heads, n_ctx)
+
+ def forward(self, x, mask=None):
+ x = self.c_qkv(x)
+ x = self.attention(x, mask=mask)
+ x = self.c_proj(x)
+ return x
+
+
+class MLP(nn.Module):
+ def __init__(self, width):
+ super().__init__()
+ self.width = width
+ self.c_fc = nn.Linear(width, width * 4)
+ self.c_proj = nn.Linear(width * 4, width)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ return self.c_proj(self.gelu(self.c_fc(x)))
+
+
+class QKVMultiheadAttention(nn.Module):
+ def __init__(self, n_heads: int, n_ctx: int):
+ super().__init__()
+ self.n_heads = n_heads
+ self.n_ctx = n_ctx
+
+ def forward(self, qkv, mask=None):
+ bs, n_ctx, width = qkv.shape
+ attn_ch = width // self.n_heads // 3
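+ # Scale q and k each by 1/sqrt(sqrt(d)) so their product carries the usual
+ # 1/sqrt(d) attention scaling; splitting the factor is friendlier to fp16.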
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
+ qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
+ q, k, v = th.split(qkv, attn_ch, dim=-1)
+ weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
+ wdtype = weight.dtype
+ if mask is not None:
+ weight = weight + mask[:, None, ...]
+ weight = th.softmax(weight, dim=-1).type(wdtype)
+ return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ ):
+ super().__init__()
+
+ self.attn = MultiheadAttention(
+ n_ctx,
+ width,
+ heads,
+ )
+ self.ln_1 = LayerNorm(width)
+ self.mlp = MLP(width)
+ self.ln_2 = LayerNorm(width)
+
+ def forward(self, x, mask=None):
+ x = x + self.attn(self.ln_1(x), mask=mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ n_ctx: int,
+ width: int,
+ layers: int,
+ heads: int,
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(
+ n_ctx,
+ width,
+ heads,
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x, mask=None):
+ for block in self.resblocks:
+ x = block(x, mask=mask)
+ return x
+
+
+class PriorTransformer(nn.Module):
+ """
+ A causal transformer that predicts a CLIP image embedding from a CLIP text
+ embedding and the encoded text token sequence (the diffusion prior).
+
+ :param text_ctx: number of text tokens to expect.
+ :param xf_width: width of the transformer.
+ :param xf_layers: depth of the transformer.
+ :param xf_heads: heads in the transformer.
+ :param xf_final_ln: use a LayerNorm after the output layer.
+ :param xf_padding: if True, learn a padding embedding for masked positions.
+ :param clip_dim: dimensionality of the CLIP embeddings.
+ :param clip_xf_width: width of the per-token CLIP text encoder features.
+ """
+
+ def __init__(
+ self,
+ text_ctx,
+ xf_width,
+ xf_layers,
+ xf_heads,
+ xf_final_ln,
+ xf_padding,
+ clip_dim,
+ clip_xf_width,
+ ):
+ super().__init__()
+
+ self.text_ctx = text_ctx
+ self.xf_width = xf_width
+ self.xf_layers = xf_layers
+ self.xf_heads = xf_heads
+ self.xf_padding = xf_padding
+ self.clip_dim = clip_dim
+ self.clip_xf_width = clip_xf_width
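+ # Four extra sequence positions follow the text tokens: pooled text
+ # embedding, timestep embedding, noisy CLIP image embedding, and a learned
+ # prediction ("prd") token.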
+ self.ext_len = 4
+
+ self.time_embed = nn.Sequential(
+ nn.Linear(xf_width, xf_width),
+ nn.SiLU(),
+ nn.Linear(xf_width, xf_width),
+ )
+ self.text_enc_proj = nn.Linear(clip_xf_width, xf_width)
+ self.text_emb_proj = nn.Linear(clip_dim, xf_width)
+ self.clip_img_proj = nn.Linear(clip_dim, xf_width)
+ self.out_proj = nn.Linear(xf_width, clip_dim)
+ self.transformer = Transformer(
+ text_ctx + self.ext_len,
+ xf_width,
+ xf_layers,
+ xf_heads,
+ )
+ if xf_final_ln:
+ self.final_ln = LayerNorm(xf_width)
+ else:
+ self.final_ln = None
+
+ self.positional_embedding = nn.Parameter(
+ th.empty(1, text_ctx + self.ext_len, xf_width)
+ )
+ self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
+
+ if self.xf_padding:
+ self.padding_embedding = nn.Parameter(
+ th.empty(text_ctx + self.ext_len, xf_width)
+ )
+ nn.init.normal_(self.padding_embedding, std=0.01)
+
+ nn.init.normal_(self.prd_emb, std=0.01)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ text_emb=None,
+ text_enc=None,
+ mask=None,
+ causal_mask=None,
+ ):
+ bsz = x.shape[0]
+ mask = F.pad(mask, (0, self.ext_len), value=True)
+
+ t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
+ text_enc = self.text_enc_proj(text_enc)
+ text_emb = self.text_emb_proj(text_emb)
+ x = self.clip_img_proj(x)
+
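+ # Sequence layout: [encoded text tokens | pooled text emb | timestep emb |
+ # noisy image emb | prd token]; the prediction is read from the final (prd)
+ # position after the transformer.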
+ input_seq = [
+ text_enc,
+ text_emb[:, None, :],
+ t_emb[:, None, :],
+ x[:, None, :],
+ self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
+ ]
+ input = th.cat(input_seq, dim=1)
+ input = input + self.positional_embedding.to(input.dtype)
+ if self.xf_padding:
+ input = th.where(
+ mask[..., None], input, self.padding_embedding[None].to(input.dtype)
+ )
+
+ mask = th.where(mask, 0.0, float("-inf"))
+ mask = (mask[:, None, :] + causal_mask).to(input.dtype)
+
+ out = self.transformer(input, mask=mask)
+ if self.final_ln is not None:
+ out = self.final_ln(out)
+
+ out = self.out_proj(out[:, -1])
+
+ return out
diff --git a/karlo/sampler/i2i.py b/karlo/sampler/i2i.py
new file mode 100644
index 0000000..da691fb
--- /dev/null
+++ b/karlo/sampler/i2i.py
@@ -0,0 +1,156 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+from typing import Iterator
+
+import torch
+import torchvision.transforms.functional as TVF
+from torchvision.transforms import InterpolationMode
+
+from .template import BaseSampler, CKPT_PATH
+
+
+class I2ISampler(BaseSampler):
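+ """
+ Image variation pipeline: the input image is CLIP-encoded and its image
+ embedding conditions the 64px decoder directly (no prior is used),
+ followed by 64->256 super-resolution.
+
+ Illustrative usage sketch (assumes a PIL image and the checkpoints
+ downloaded by setup.sh; the argument values are examples only):
+
+ sampler = I2ISampler.from_pretrained(
+ root_dir=os.environ["KARLO_ROOT_DIR"],
+ clip_model_path="ViT-L-14.pt",
+ clip_stat_path="ViT-L-14_stats.th",
+ )
+ images = next(sampler(image=pil_image, bsz=1, progressive_mode="final"))
+ """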
+ def __init__(
+ self,
+ root_dir: str,
+ sampling_type: str = "default",
+ ):
+ super().__init__(root_dir, sampling_type)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ root_dir: str,
+ clip_model_path: str,
+ clip_stat_path: str,
+ sampling_type: str = "default",
+ ):
+
+ model = cls(
+ root_dir=root_dir,
+ sampling_type=sampling_type,
+ )
+ model.load_clip(clip_model_path)
+ model.load_decoder(f"{CKPT_PATH['decoder']}")
+ model.load_sr_64_256(CKPT_PATH["sr_256"])
+
+ return model
+
+ def preprocess(
+ self,
+ image,
+ prompt: str,
+ bsz: int,
+ ):
+ prompts_batch = [prompt for _ in range(bsz)]
+ decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
+ decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+
+ # preprocess input image
+ image = TVF.normalize(
+ TVF.to_tensor(
+ TVF.resize(
+ image,
+ [224, 224],
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+ ),
+ mean=[0.48145466, 0.4578275, 0.40821073],
+ std=[0.26862954, 0.26130258, 0.27577711],
+ ).unsqueeze(0)
+ image_batch = image.repeat(bsz, 1, 1, 1).cuda()
+
+ """ Get CLIP text and image features """
+ clip_model = self._clip
+ tokenizer = self._tokenizer
+ max_txt_length = 77
+
+ tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
+ cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
+ if cf_token.shape != tok.shape:
+ cf_token = cf_token.expand(tok.shape[0], -1)
+ cf_mask = cf_mask.expand(tok.shape[0], -1)
+
+ tok = torch.cat([tok, cf_token], dim=0)
+ mask = torch.cat([mask, cf_mask], dim=0)
+
+ tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+ txt_feat, txt_feat_seq = clip_model.encode_text(tok)
+ img_feat = clip_model.encode_image(image_batch)
+
+ return (
+ prompts_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ )
+
+ def __call__(
+ self,
+ image,
+ bsz: int,
+ progressive_mode=None,
+ ) -> Iterator[torch.Tensor]:
+ assert progressive_mode in ("loop", "stage", "final")
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ (
+ prompts_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ ) = self.preprocess(
+ image=image,
+ prompt="",
+ bsz=bsz,
+ )
+
+ """ Generate 64x64px images """
+ images_64_outputs = self._decoder(
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ cf_guidance_scales=decoder_cf_scales_batch,
+ timestep_respacing=self._decoder_sm,
+ )
+
+ images_64 = None
+ for k, out in enumerate(images_64_outputs):
+ images_64 = out
+ if progressive_mode == "loop":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+ if progressive_mode == "stage":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+
+ images_64 = torch.clamp(images_64, -1, 1)
+
+ """ Upsample 64x64 to 256x256 """
+ images_256 = TVF.resize(
+ images_64,
+ [256, 256],
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+ images_256_outputs = self._sr_64_256(
+ images_256, timestep_respacing=self._sr_sm
+ )
+
+ for k, out in enumerate(images_256_outputs):
+ images_256 = out
+ if progressive_mode == "loop":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+ if progressive_mode == "stage":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+
+ yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
diff --git a/karlo/sampler/t2i.py b/karlo/sampler/t2i.py
new file mode 100644
index 0000000..ed189d9
--- /dev/null
+++ b/karlo/sampler/t2i.py
@@ -0,0 +1,156 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+from typing import Iterator
+
+import torch
+import torchvision.transforms.functional as TVF
+from torchvision.transforms import InterpolationMode
+
+from .template import BaseSampler, CKPT_PATH
+
+
+class T2ISampler(BaseSampler):
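+ """
+ Text-to-image pipeline: CLIP text encoding -> prior (text embedding ->
+ image embedding) -> 64px decoder -> 64->256 super-resolution.
+
+ Illustrative usage sketch (assumes the checkpoints downloaded by setup.sh
+ under $KARLO_ROOT_DIR; the prompt and batch size are arbitrary examples):
+
+ sampler = T2ISampler.from_pretrained(
+ root_dir=os.environ["KARLO_ROOT_DIR"],
+ clip_model_path="ViT-L-14.pt",
+ clip_stat_path="ViT-L-14_stats.th",
+ sampling_type="fast",
+ )
+ images = next(sampler("a painting of a fox", bsz=1, progressive_mode="final"))
+ """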
+ def __init__(
+ self,
+ root_dir: str,
+ sampling_type: str = "default",
+ ):
+ super().__init__(root_dir, sampling_type)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ root_dir: str,
+ clip_model_path: str,
+ clip_stat_path: str,
+ sampling_type: str = "default",
+ ):
+
+ model = cls(
+ root_dir=root_dir,
+ sampling_type=sampling_type,
+ )
+ model.load_clip(clip_model_path)
+ model.load_prior(
+ f"{CKPT_PATH['prior']}",
+ clip_stat_path=clip_stat_path,
+ )
+ model.load_decoder(f"{CKPT_PATH['decoder']}")
+ model.load_sr_64_256(CKPT_PATH["sr_256"])
+
+ return model
+
+ def preprocess(
+ self,
+ prompt: str,
+ bsz: int,
+ ):
+ """Setup prompts & cfg scales"""
+ prompts_batch = [prompt for _ in range(bsz)]
+
+ prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
+ prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+
+ decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
+ decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+
+ """ Get CLIP text feature """
+ clip_model = self._clip
+ tokenizer = self._tokenizer
+ max_txt_length = self._prior.model.text_ctx
+
+ tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
+ cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
+ if cf_token.shape != tok.shape:
+ cf_token = cf_token.expand(tok.shape[0], -1)
+ cf_mask = cf_mask.expand(tok.shape[0], -1)
+
+ tok = torch.cat([tok, cf_token], dim=0)
+ mask = torch.cat([mask, cf_mask], dim=0)
+
+ tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+ txt_feat, txt_feat_seq = clip_model.encode_text(tok)
+
+ return (
+ prompts_batch,
+ prior_cf_scales_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ )
+
+ def __call__(
+ self,
+ prompt: str,
+ bsz: int,
+ progressive_mode=None,
+ ) -> Iterator[torch.Tensor]:
+ assert progressive_mode in ("loop", "stage", "final")
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ (
+ prompts_batch,
+ prior_cf_scales_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ ) = self.preprocess(
+ prompt,
+ bsz,
+ )
+
+ """ Transform CLIP text feature into image feature """
+ img_feat = self._prior(
+ txt_feat,
+ txt_feat_seq,
+ mask,
+ prior_cf_scales_batch,
+ timestep_respacing=self._prior_sm,
+ )
+
+ """ Generate 64x64px images """
+ images_64_outputs = self._decoder(
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ cf_guidance_scales=decoder_cf_scales_batch,
+ timestep_respacing=self._decoder_sm,
+ )
+
+ images_64 = None
+ for k, out in enumerate(images_64_outputs):
+ images_64 = out
+ if progressive_mode == "loop":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+ if progressive_mode == "stage":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+
+ images_64 = torch.clamp(images_64, -1, 1)
+
+ """ Upsample 64x64 to 256x256 """
+ images_256 = TVF.resize(
+ images_64,
+ [256, 256],
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+ images_256_outputs = self._sr_64_256(
+ images_256, timestep_respacing=self._sr_sm
+ )
+
+ for k, out in enumerate(images_256_outputs):
+ images_256 = out
+ if progressive_mode == "loop":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+ if progressive_mode == "stage":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+
+ yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
diff --git a/karlo/sampler/template.py b/karlo/sampler/template.py
new file mode 100644
index 0000000..be3a5aa
--- /dev/null
+++ b/karlo/sampler/template.py
@@ -0,0 +1,140 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import os
+import logging
+import torch
+
+from omegaconf import OmegaConf
+
+from ..models.clip import CustomizedCLIP, CustomizedTokenizer
+from ..models.prior_model import PriorDiffusionModel
+from ..models.decoder_model import Text2ImProgressiveModel
+from ..models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
+
+
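+# Sampling presets: the *_sm entries are timestep-respacing strings (number of
+# denoising steps per stage) and the *_cf_scale entries are classifier-free
+# guidance scales; "fast" halves the number of decoder steps.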
+SAMPLING_CONF = {
+ "default": {
+ "prior_sm": "25",
+ "prior_n_samples": 1,
+ "prior_cf_scale": 4.0,
+ "decoder_sm": "50",
+ "decoder_cf_scale": 8.0,
+ "sr_sm": "7",
+ },
+ "fast": {
+ "prior_sm": "25",
+ "prior_n_samples": 1,
+ "prior_cf_scale": 4.0,
+ "decoder_sm": "25",
+ "decoder_cf_scale": 8.0,
+ "sr_sm": "7",
+ },
+}
+
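+# Checkpoint filenames resolved relative to the `root_dir` passed to the
+# samplers (see setup.sh for the corresponding download URLs).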
+CKPT_PATH = {
+ "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
+ "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
+ "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
+}
+
+
+class BaseSampler:
+ _PRIOR_CLASS = PriorDiffusionModel
+ _DECODER_CLASS = Text2ImProgressiveModel
+ _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
+
+ def __init__(
+ self,
+ root_dir: str,
+ sampling_type: str = "fast",
+ ):
+ self._root_dir = root_dir
+
+ conf = SAMPLING_CONF[sampling_type]
+ self._prior_sm = conf["prior_sm"]
+ self._prior_n_samples = conf["prior_n_samples"]
+ self._prior_cf_scale = conf["prior_cf_scale"]
+
+ assert self._prior_n_samples == 1
+
+ self._decoder_sm = conf["decoder_sm"]
+ self._decoder_cf_scale = conf["decoder_cf_scale"]
+
+ self._sr_sm = conf["sr_sm"]
+
+ def __repr__(self):
+ line = ""
+ line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
+ line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
+ line += f"SR(64->256), sampling method: {self._sr_sm}"
+
+ return line
+
+ def load_clip(self, clip_path: str):
+ clip = CustomizedCLIP.load_from_checkpoint(
+ os.path.join(self._root_dir, clip_path)
+ )
+ clip = torch.jit.script(clip)
+ clip.cuda()
+ clip.eval()
+
+ self._clip = clip
+ self._tokenizer = CustomizedTokenizer()
+
+ def load_prior(
+ self,
+ ckpt_path: str,
+ clip_stat_path: str,
+ ):
+ logging.info(f"Loading prior: {ckpt_path}")
+
+ config = OmegaConf.load("configs/prior_1B_vit_l.yaml")
+ clip_mean, clip_std = torch.load(
+ os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
+ )
+
+ prior = self._PRIOR_CLASS.load_from_checkpoint(
+ config,
+ self._tokenizer,
+ clip_mean,
+ clip_std,
+ os.path.join(self._root_dir, ckpt_path),
+ strict=True,
+ )
+ prior.cuda()
+ prior.eval()
+ logging.info("done.")
+
+ self._prior = prior
+
+ def load_decoder(self, ckpt_path: str):
+ logging.info(f"Loading decoder: {ckpt_path}")
+
+ config = OmegaConf.load("configs/decoder_900M_vit_l.yaml")
+ decoder = self._DECODER_CLASS.load_from_checkpoint(
+ config,
+ self._tokenizer,
+ os.path.join(self._root_dir, ckpt_path),
+ strict=True,
+ )
+ decoder.cuda()
+ decoder.eval()
+ logging.info("done.")
+
+ self._decoder = decoder
+
+ def load_sr_64_256(self, ckpt_path: str):
+ logging.info(f"Loading SR(64->256): {ckpt_path}")
+
+ config = OmegaConf.load("configs/improved_sr_64_256_1.4B.yaml")
+ sr = self._SR256_CLASS.load_from_checkpoint(
+ config, os.path.join(self._root_dir, ckpt_path), strict=True
+ )
+ sr.cuda()
+ sr.eval()
+ logging.info("done.")
+
+ self._sr_64_256 = sr
diff --git a/karlo/utils/util.py b/karlo/utils/util.py
new file mode 100644
index 0000000..7645ebe
--- /dev/null
+++ b/karlo/utils/util.py
@@ -0,0 +1,10 @@
+import random
+import torch
+import numpy as np
+
+
+def set_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..91591e6
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+torch>=1.10
+torchvision>=0.8.2
+black
+einops
+omegaconf
+matplotlib
+gradio>=3.5.0
+git+https://github.com/openai/CLIP.git
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..409f0ef
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,3 @@
+[flake8]
+max-line-length = 120
+ignore = E203, E226, E402, E731, W503, W504
diff --git a/setup.sh b/setup.sh
new file mode 100755
index 0000000..9c04926
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+pip install -r requirements.txt
+
+export KARLO_ROOT_DIR=$HOME/.cache/karlo/v1.0.alpha/
+
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/096db1af569b284eb76b3881534822d9/ViT-L-14.pt -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/0b62380a75e56f073e2844ab5199153d/ViT-L-14_stats.th -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/efdf6206d8ed593961593dc029a8affa/decoder-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/85626483eaca9f581e2a78d31ff905ca/prior-ckpt-step%3D01000000-of-01000000.ckpt -P $KARLO_ROOT_DIR
+wget https://arena.kakaocdn.net/brainrepo/models/karlo-public/v1.0.0.alpha/4226b831ae0279020d134281f3c31590/improved-sr-ckpt-step%3D1.2M.ckpt -P $KARLO_ROOT_DIR