Add Aura SR v2 support #227

Merged (2 commits, Aug 4, 2024)
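For context, a minimal usage sketch of the feature this PR adds (assuming this repo's `src/backend` package layout; the model ID, device argument, and method names are taken from the diff below, while the file paths are placeholders):

from PIL import Image
from backend.upscale.aura_sr import AuraSR

# Load the AuraSR v2 weights (the model ID this PR switches to).
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")

image_in = Image.open("input.png")  # placeholder path
upscaled = aura_sr.upscale_4x(image_in)  # plain tiled 4x upscale
# Or the new overlapped variant that blends tile seams:
# upscaled = aura_sr.upscale_4x_overlapped(image_in, weight_type="checkboard")
upscaled.save("output.png")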
1 change: 1 addition & 0 deletions Readme.md
@@ -118,6 +118,7 @@ If we enable Tiny decoder(TAESD) we can save some memory(2GB approx) for example
- Experimental support for single file Safetensors SD 1.5 models(Civitai models), simply add local model path to configs/stable-diffusion-models.txt file.
- Add REST API support
- Add Aura SR (4x)/GigaGAN based upscaler support
- Add Aura SR v2 upscaler support

<a id="fast-inference-benchmarks"></a>

212 changes: 191 additions & 21 deletions src/backend/upscale/aura_sr.py
@@ -14,6 +14,8 @@

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from torchvision.utils import save_image
import math


def get_same_padding(size, kernel, dilation, stride):
@@ -186,6 +188,7 @@ def null_iterator():
while True:
yield None


def Downsample(dim, dim_out=None):
return nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
@@ -385,6 +388,7 @@ def forward(self, x):

return x


class EqualLinear(nn.Module):
def __init__(self, dim, dim_out, lr_mul=1, bias=True):
super().__init__()
@@ -714,11 +718,49 @@ def tile_image(image, chunk_size=64):
tiles = []
for i in range(h_chunks):
for j in range(w_chunks):
tile = image[
:,
i * chunk_size : (i + 1) * chunk_size,
j * chunk_size : (j + 1) * chunk_size,
]
tiles.append(tile)
return tiles, h_chunks, w_chunks
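# tile_image splits a (C, H, W) tensor into chunk_size x chunk_size tiles in
# row-major order; merge_tiles below reassembles the upscaled tiles, so the
# two functions are inverses up to the 4x scale factor.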


# This helps create a checkerboard weight pattern with some edge blending
def create_checkerboard_weights(tile_size):
x = torch.linspace(-1, 1, tile_size)
y = torch.linspace(-1, 1, tile_size)

x, y = torch.meshgrid(x, y, indexing="ij")
d = torch.sqrt(x * x + y * y)
sigma, mu = 0.5, 0.0
weights = torch.exp(-((d - mu) ** 2 / (2.0 * sigma**2)))

# saturate the values to make sure we get high weights in the center
weights = weights**8

return weights / weights.max() # Normalize to [0, 1]
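# In effect each pixel gets the radially symmetric weight
#   w(d) = exp(-(d**2) / (2 * 0.5**2)) ** 8, rescaled so its peak is 1,
# where d is the distance from the tile center on a [-1, 1] grid: weights
# stay high in the middle of a tile and fall off sharply toward its edges.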


def repeat_weights(weights, image_size):
tile_size = weights.shape[0]
repeats = (
math.ceil(image_size[0] / tile_size),
math.ceil(image_size[1] / tile_size),
)
return weights.repeat(repeats)[: image_size[0], : image_size[1]]


def create_offset_weights(weights, image_size):
tile_size = weights.shape[0]
offset = tile_size // 2
full_weights = repeat_weights(
weights, (image_size[0] + offset, image_size[1] + offset)
)
return full_weights[offset:, offset:]
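# create_offset_weights shifts the repeated weight map by half a tile, so the
# second (offset) upscale pass is weighted highest exactly where the first
# pass sits on a tile seam, and vice versa.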


def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
# Determine the shape of the output tensor
c = tiles[0].shape[0]
@@ -737,7 +779,7 @@ def merge_tiles(tiles, h_chunks, w_chunks, chunk_size=64):
w_start = j * chunk_size

tile_h, tile_w = tile.shape[1:]
merged[:, h_start : h_start + tile_h, w_start : w_start + tile_w] = tile

return merged

@@ -748,7 +790,12 @@ def __init__(self, config: dict[str, Any], device: str = "cuda"):
self.input_image_size = config["input_image_size"]

@classmethod
def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_safetensors: bool = True):
def from_pretrained(
cls,
model_id: str = "fal-ai/AuraSR",
use_safetensors: bool = True,
device: str = "cuda",
):
import json
import torch
from pathlib import Path
@@ -757,15 +804,17 @@ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_
# Check if model_id is a local file
if Path(model_id).is_file():
local_file = Path(model_id)
if local_file.suffix == ".safetensors":
use_safetensors = True
elif local_file.suffix == ".ckpt":
use_safetensors = False
else:
raise ValueError(f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files.")

raise ValueError(
f"Unsupported file format: {local_file.suffix}. Please use .safetensors or .ckpt files."
)

# For local files, we need to provide the config separately
config_path = local_file.with_name("config.json")
if not config_path.exists():
raise FileNotFoundError(
f"Config file not found: {config_path}. "
@@ -774,27 +823,38 @@ def from_pretrained(cls, model_id: str = "fal-ai/AuraSR",device: str="cuda",use_
f"If you're trying to load a model from Hugging Face, "
f"please provide the model ID instead of a file path."
)

config = json.loads(config_path.read_text())
hf_model_path = local_file.parent
else:
hf_model_path = Path(
snapshot_download(model_id, ignore_patterns=["*.ckpt"])
)
config = json.loads((hf_model_path / "config.json").read_text())

model = cls(config, device)

if use_safetensors:
try:
from safetensors.torch import load_file

checkpoint = load_file(
hf_model_path / "model.safetensors"
if not Path(model_id).is_file()
else model_id
)
except ImportError:
raise ImportError(
"The safetensors library is not installed. "
"Please install it with `pip install safetensors` "
"or use `use_safetensors=False` to load the model with PyTorch."
)
else:
checkpoint = torch.load(
hf_model_path / "model.ckpt"
if not Path(model_id).is_file()
else model_id
)

model.upsampler.load_state_dict(checkpoint, strict=True)
return model
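# Loading notes: a local path like "weights/model.safetensors" (hypothetical
# path) expects a config.json next to the checkpoint, while a plain model ID
# such as "fal/AuraSR-v2" is resolved through snapshot_download.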
@@ -806,29 +866,139 @@ def upscale_4x(self, image: Image.Image, max_batch_size=8) -> Image.Image:

image_tensor = tensor_transform(image).unsqueeze(0)
_, _, h, w = image_tensor.shape
pad_h = (
self.input_image_size - h % self.input_image_size
) % self.input_image_size
pad_w = (
self.input_image_size - w % self.input_image_size
) % self.input_image_size

# Pad the image
image_tensor = torch.nn.functional.pad(
image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
).squeeze(0)
tiles, h_chunks, w_chunks = tile_image(image_tensor, self.input_image_size)

# Batch processing of tiles
num_tiles = len(tiles)
batches = [
tiles[i : i + max_batch_size] for i in range(0, num_tiles, max_batch_size)
]
reconstructed_tiles = []

for batch in batches:
model_input = torch.stack(batch).to(device)
generator_output = self.upsampler(
lowres_image=model_input,
noise=torch.randn(model_input.shape[0], 128, device=device),
)
reconstructed_tiles.extend(
list(generator_output.clamp_(0, 1).detach().cpu())
)

merged_tensor = merge_tiles(
reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
)
unpadded = merged_tensor[:, : h * 4, : w * 4]

to_pil = transforms.ToPILImage()
return to_pil(unpadded)

# Tiled 4x upscaling with overlapping tiles to reduce seam artifacts
# weight_type options are 'checkboard' and 'constant'
@torch.no_grad()
def upscale_4x_overlapped(self, image, max_batch_size=8, weight_type="checkboard"):
tensor_transform = transforms.ToTensor()
device = self.upsampler.device

image_tensor = tensor_transform(image).unsqueeze(0)
_, _, h, w = image_tensor.shape

# Calculate paddings
pad_h = (
self.input_image_size - h % self.input_image_size
) % self.input_image_size
pad_w = (
self.input_image_size - w % self.input_image_size
) % self.input_image_size

# Pad the image
image_tensor = torch.nn.functional.pad(
image_tensor, (0, pad_w, 0, pad_h), mode="reflect"
).squeeze(0)

# Function to process tiles
def process_tiles(tiles, h_chunks, w_chunks):
num_tiles = len(tiles)
batches = [
tiles[i : i + max_batch_size]
for i in range(0, num_tiles, max_batch_size)
]
reconstructed_tiles = []

for batch in batches:
model_input = torch.stack(batch).to(device)
generator_output = self.upsampler(
lowres_image=model_input,
noise=torch.randn(model_input.shape[0], 128, device=device),
)
reconstructed_tiles.extend(
list(generator_output.clamp_(0, 1).detach().cpu())
)

return merge_tiles(
reconstructed_tiles, h_chunks, w_chunks, self.input_image_size * 4
)

# First pass
tiles1, h_chunks1, w_chunks1 = tile_image(image_tensor, self.input_image_size)
result1 = process_tiles(tiles1, h_chunks1, w_chunks1)

# Second pass with offset
offset = self.input_image_size // 2
image_tensor_offset = torch.nn.functional.pad(
image_tensor, (offset, offset, offset, offset), mode="reflect"
).squeeze(0)

tiles2, h_chunks2, w_chunks2 = tile_image(
image_tensor_offset, self.input_image_size
)
result2 = process_tiles(tiles2, h_chunks2, w_chunks2)

# unpad
offset_4x = offset * 4
result2_interior = result2[:, offset_4x:-offset_4x, offset_4x:-offset_4x]

if weight_type == "checkboard":
weight_tile = create_checkerboard_weights(self.input_image_size * 4)

weight_shape = result2_interior.shape[1:]
weights_1 = create_offset_weights(weight_tile, weight_shape)
weights_2 = repeat_weights(weight_tile, weight_shape)

normalizer = weights_1 + weights_2
weights_1 = weights_1 / normalizer
weights_2 = weights_2 / normalizer
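# After normalization the two weight maps sum to 1 at every pixel, so the
# blend below is a per-pixel convex combination of the two passes.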

weights_1 = weights_1.unsqueeze(0).repeat(3, 1, 1)
weights_2 = weights_2.unsqueeze(0).repeat(3, 1, 1)
elif weight_type == "constant":
weights_1 = torch.ones_like(result2_interior) * 0.5
weights_2 = weights_1
else:
raise ValueError(
"weight_type should be either 'gaussian' or 'constant' but got",
weight_type,
)

result1 = result1 * weights_2
result2 = result2_interior * weights_1

# Blend the two passes (the weights sum to 1 at every pixel)
result1 = result1 + result2

# Remove padding
unpadded = result1[:, : h * 4, : w * 4]

to_pil = transforms.ToPILImage()
return to_pil(unpadded)
2 changes: 1 addition & 1 deletion src/backend/upscale/aura_sr_upscale.py
@@ -4,6 +4,6 @@

def upscale_aura_sr(image_path: str):

aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device="cpu")
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2", device="cpu")
image_in = Image.open(image_path) # .resize((256, 256))
return aura_sr.upscale_4x(image_in)
2 changes: 1 addition & 1 deletion src/frontend/webui/upscaler_ui.py
@@ -52,7 +52,7 @@ def get_upscaler_ui() -> None:
with gr.Row():
upscale_mode = gr.Radio(
["EDSR", "SD", "AURA-SR"],
label="Upscale Mode (2x) | AURA-SR (4x)",
label="Upscale Mode (2x) | AURA-SR v2 (4x)",
info="Select upscale method, SD Upscale is experimental",
value="EDSR",
)