diff --git a/Readme.md b/Readme.md
index 1c655fa..cf7cd22 100644
--- a/Readme.md
+++ b/Readme.md
@@ -118,6 +118,7 @@ If we enable Tiny decoder(TAESD) we can save some memory(2GB approx) for example
 - Experimental support for single file Safetensors SD 1.5 models(Civitai models), simply add local model path to configs/stable-diffusion-models.txt file.
 - Add REST API support
 - Add Aura SR (4x)/GigaGAN based upscaler support
+- Add Aura SR v2 upscaler support
diff --git a/src/backend/upscale/aura_sr.py b/src/backend/upscale/aura_sr.py
index be6efa3..787a66f 100644
--- a/src/backend/upscale/aura_sr.py
+++ b/src/backend/upscale/aura_sr.py
@@ -14,6 +14,8 @@ from einops import rearrange, repeat, reduce
 from einops.layers.torch import Rearrange
 
+from torchvision.utils import save_image
+import math
 
 def get_same_padding(size, kernel, dilation, stride):
@@ -186,6 +188,7 @@ def null_iterator():
     while True:
         yield None
 
+
 def Downsample(dim, dim_out=None):
     return nn.Sequential(
         Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
@@ -385,6 +388,7 @@ def forward(self, x):
 
         return x
 
+
 class EqualLinear(nn.Module):
     def __init__(self, dim, dim_out, lr_mul=1, bias=True):
         super().__init__()
@@ -714,11 +718,49 @@ def tile_image(image, chunk_size=64):
     tiles = []
     for i in range(h_chunks):
         for j in range(w_chunks):
-            tile = image[:, i * chunk_size:(i + 1) * chunk_size, j * chunk_size:(j + 1) * chunk_size]
+            tile = image[
+                :,
+                i * chunk_size : (i + 1) * chunk_size,
+                j * chunk_size : (j + 1) * chunk_size,
+            ]
             tiles.append(tile)
     return tiles, h_chunks, w_chunks
 
 
+# This helps create a checkerboard pattern with some edge blending
+def create_checkerboard_weights(tile_size):
+    x = torch.linspace(-1, 1, tile_size)
+    y = torch.linspace(-1, 1, tile_size)
+
+    x, y = torch.meshgrid(x, y, indexing="ij")
+    d = torch.sqrt(x * x + y * y)
+    sigma, mu = 0.5, 0.0
+    weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))
+
+    # saturate the values to ensure high weights in the center
+    weights = weights**8
+
+    return weights / weights.max()  # Normalize to [0, 1]
+
+
+def repeat_weights(weights, image_size):
+    tile_size = weights.shape[0]
+    repeats = (
+        math.ceil(image_size[0] / tile_size),
+        math.ceil(image_size[1] / tile_size),
+    )
+    return weights.repeat(repeats)[: image_size[0], : image_size[1]]
+
+
+def create_offset_weights(weights, image_size):
+    tile_size = weights.shape[0]
+    offset = tile_size // 2
+    full_weights = repeat_weights(
+        weights, (image_size[0] + offset, image_size[1] + offset)
+    )
+    return full_weights[offset:, offset:]
+
+
 def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
     # Determine the shape of the output tensor
     c = tiles[0].shape[0]
@@ -737,7 +779,7 @@ def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
             w_start = j * chunk_size
 
             tile_h, tile_w = tile.shape[1:]
-            merged[:, h_start:h_start+tile_h, w_start:w_start+tile_w] = tile
+            merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile
 
     return merged
 
@@ -748,7 +790,12 @@ def __init__(self, config: dict[str, Any], device: str = "cuda"):
         self.input_image_size = config["input_image_size"]
 
     @classmethod
-    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_safetensors: bool = True):
+    def from_pretrained(
+        cls,
+        model_id: str = "fal-ai/AuraSR",
+        use_safetensors: bool = True,
+        device: str = "cuda",
+    ):
         import json
         import torch
         from pathlib import Path
@@ -757,15 +804,17 @@ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_
         # Check if model_id is a local file
        if Path(model_id).is_file():
            local_file = Path(model_id)
-            if local_file.suffix == '.safetensors':
+            if local_file.suffix == ".safetensors":
                use_safetensors = True
-            elif local_file.suffix == '.ckpt':
+            elif local_file.suffix == ".ckpt":
                use_safetensors = False
            else:
-                raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")
-
+                raise ValueError(
+                    f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
+                )
+
            # For local files, we need to provide the config separately
-            config_path = local_file.with_name('config.json')
+            config_path = local_file.with_name("config.json")
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
@@ -774,19 +823,26 @@ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )
-
+
            config = json.loads(config_path.read_text())
            hf_model_path = local_file.parent
        else:
-            hf_model_path = Path(snapshot_download(model_id))
+            hf_model_path = Path(
+                snapshot_download(model_id, ignore_patterns=["*.ckpt"])
+            )
            config = json.loads((hf_model_path / "config.json").read_text())
 
-        model = cls(config,device)
+        model = cls(config, device)
 
        if use_safetensors:
            try:
                from safetensors.torch import load_file
-                checkpoint = load_file(hf_model_path / "model.safetensors" if not Path(model_id).is_file() else model_id)
+
+                checkpoint = load_file(
+                    hf_model_path / "model.safetensors"
+                    if not Path(model_id).is_file()
+                    else model_id
+                )
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
@@ -794,7 +850,11 @@ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
        else:
-            checkpoint = torch.load(hf_model_path / "model.ckpt" if not Path(model_id).is_file() else model_id)
+            checkpoint = torch.load(
+                hf_model_path / "model.ckpt"
+                if not Path(model_id).is_file()
+                else model_id
+            )
 
        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model
@@ -806,29 +866,139 @@ def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:
 
        image_tensor = tensor_transform(image).unsqueeze(0)
        _, _, h, w = image_tensor.shape
-        pad_h = (self.input_image_size - h % self.input_image_size) % self.input_image_size
-        pad_w = (self.input_image_size - w % self.input_image_size) % self.input_image_size
+        pad_h = (
+            self.input_image_size - h % self.input_image_size
+        ) % self.input_image_size
+        pad_w = (
+            self.input_image_size - w % self.input_image_size
+        ) % self.input_image_size
 
        # Pad the image
-        image_tensor = torch.nn.functional.pad(image_tensor, (0, pad_w, 0, pad_h), mode='reflect').squeeze(0)
+        image_tensor = torch.nn.functional.pad(
+            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
+        ).squeeze(0)
 
        tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)
 
        # Batch processing of tiles
        num_tiles = len(tiles)
-        batches = [tiles[i:i + max_batch_size] for i in range(0, num_tiles, max_batch_size)]
+        batches = [
+            tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
+        ]
        reconstructed_tiles = []
 
        for batch in batches:
            model_input = torch.stack(batch).to(device)
            generator_output = self.upsampler(
                lowres_image=model_input,
-                noise=torch.randn(model_input.shape[0], 128, device=device)
+                noise=torch.randn(model_input.shape[0], 128, device=device),
+            )
+            reconstructed_tiles.extend(
+                list(generator_output.clamp_(0, 1).detach().cpu())
            )
-            reconstructed_tiles.extend(list(generator_output.clamp_(0, 1).detach().cpu()))
 
-        merged_tensor = merge_tiles(reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4)
-        unpadded = merged_tensor[:, :h * 4, :w * 4]
+        merged_tensor = merge_tiles(
+            reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
+        )
+        unpadded = merged_tensor[:, : h * 4, : w * 4]
 
        to_pil = transforms.ToPILImage()
        return to_pil(unpadded)
 
+    # Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
+    # weight_type options are 'checkboard' and 'constant'
+    @torch.no_grad()
+    def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
+        tensor_transform = transforms.ToTensor()
+        device = self.upsampler.device
+
+        image_tensor = tensor_transform(image).unsqueeze(0)
+        _, _, h, w = image_tensor.shape
+
+        # Calculate paddings
+        pad_h = (
+            self.input_image_size - h % self.input_image_size
+        ) % self.input_image_size
+        pad_w = (
+            self.input_image_size - w % self.input_image_size
+        ) % self.input_image_size
+
+        # Pad the image
+        image_tensor = torch.nn.functional.pad(
+            image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
+        ).squeeze(0)
+
+        # Function to process tiles
+        def process_tiles(tiles, h_chunks, w_chunks):
+            num_tiles = len(tiles)
+            batches = [
+                tiles[i : i + max_batch_size]
+                for i in range(0, num_tiles, max_batch_size)
+            ]
+            reconstructed_tiles = []
+
+            for batch in batches:
+                model_input = torch.stack(batch).to(device)
+                generator_output = self.upsampler(
+                    lowres_image=model_input,
+                    noise=torch.randn(model_input.shape[0], 128, device=device),
+                )
+                reconstructed_tiles.extend(
+                    list(generator_output.clamp_(0, 1).detach().cpu())
+                )
+
+            return merge_tiles(
+                reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
+            )
+
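+        # The image is tiled twice below: a regular grid pass and a pass
+        # shifted by half a tile, so every seam in one grid falls inside a
+        # tile of the other and can be blended away by the weights.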
+        # First pass
+        tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
+        result1 = process_tiles(tiles1, h_chunks1, w_chunks1)
+
+        # Second pass with offset
+        offset = self.input_image_size // 2
+        image_tensor_offset = torch.nn.functional.pad(
+            image_tensor, (offset, offset, offset, offset), mode="reflect"
+        ).squeeze(0)
+
+        tiles2, h_chunks2, w_chunks2 = tile_image(
+            image_tensor_offset, self.input_image_size
+        )
+        result2 = process_tiles(tiles2, h_chunks2, w_chunks2)
+
+        # unpad: crop the offset result back to the unshifted frame
+        offset_4x = offset * 4
+        result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]
+
+        if weight_type == "checkboard":
+            weight_tile = create_checkerboard_weights(self.input_image_size * 4)
+
+            weight_shape = result2_interior.shape[1:]
+            weights_1 = create_offset_weights(weight_tile, weight_shape)
+            weights_2 = repeat_weights(weight_tile, weight_shape)
+
+            normalizer = weights_1 + weights_2
+            weights_1 = weights_1 / normalizer
+            weights_2 = weights_2 / normalizer
+
+            weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
+            weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
+        elif weight_type == "constant":
+            weights_1 = torch.ones_like(result2_interior) * 0.5
+            weights_2 = weights_1
+        else:
+            raise ValueError(
+                "weight_type should be either 'checkboard' or 'constant' but got",
+                weight_type,
+            )
+
+        result1 = result1 * weights_2
+        result2 = result2_interior * weights_1
+
+        # Blend the two passes (the weights sum to 1, so this is a weighted average)
+        result1 = result1 + result2
+
+        # Remove padding
+        unpadded = result1[:, : h * 4, : w * 4]
+
+        to_pil = transforms.ToPILImage()
+        return to_pil(unpadded)
diff --git a/src/backend/upscale/aura_sr_upscale.py b/src/backend/upscale/aura_sr_upscale.py
index 932487c..5bebb1c 100644
--- a/src/backend/upscale/aura_sr_upscale.py
+++ b/src/backend/upscale/aura_sr_upscale.py
@@ -4,6 +4,6 @@
 
 
 def upscale_aura_sr(image_path: str):
-    aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device="cpu")
+    aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
    image_in = Image.open(image_path)  # .resize((256, 256))
    return aura_sr.upscale_4x(image_in)
diff --git a/src/frontend/webui/upscaler_ui.py b/src/frontend/webui/upscaler_ui.py
index ec58312..6c3f7fa 100644
--- a/src/frontend/webui/upscaler_ui.py
+++ b/src/frontend/webui/upscaler_ui.py
@@ -52,7 +52,7 @@ def get_upscaler_ui() -> None:
        with gr.Row():
            upscale_mode = gr.Radio(
                ["EDSR", "SD", "AURA-SR"],
-                label="Upscale Mode (2x) | AURA-SR (4x)",
+                label="Upscale Mode (2x) | AURA-SR v2 (4x)",
                info="Select upscale method, SD Upscale is experimental",
                value="EDSR",
            )
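
Reviewer note: a minimal usage sketch of the two upscale paths this diff touches, assuming the package root is on PYTHONPATH; the import path and file names below are illustrative, not part of the change.

    from PIL import Image

    from backend.upscale.aura_sr import AuraSR

    # Same call the backend now makes: AuraSR v2 weights, CPU inference.
    aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")

    image_in = Image.open("input.png")

    # Original tiled path; tile borders may show visible seams.
    upscaled = aura_sr.upscale_4x(image_in, max_batch_size=4)

    # New overlapped path: two tiling passes offset by half a tile,
    # blended with checkerboard-style weights ('checkboard' in the code).
    upscaled_blended = aura_sr.upscale_4x_overlapped(
        image_in, max_batch_size=4, weight_type="checkboard"
    )
    upscaled_blended.save("output_4x.png")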