From 3252142c16a767627f3d27ff034efff3d7d8dd56 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 14 Nov 2022 15:59:48 -0500 Subject: [PATCH 1/2] changes per Patrik's comments --- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/resnet.py | 12 ++++-------- src/diffusers/models/unet_1d.py | 10 +++++----- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b5356a2bc94e..0221d891f171 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -69,7 +69,7 @@ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", self.act = None if act_fn == "silu": self.act = nn.SiLU() - if act_fn == "mish": + elif act_fn == "mish": self.act = nn.Mish() if out_dim is not None: diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 99b6092aed06..52d056ae96fb 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -523,13 +523,9 @@ def forward(self, x): class ResidualTemporalBlock1D(nn.Module): def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) - self.blocks = nn.ModuleList( - [ - Conv1dBlock(inp_channels, out_channels, kernel_size), - Conv1dBlock(out_channels, out_channels, kernel_size), - ] - ) self.time_emb_act = nn.Mish() self.time_emb = nn.Linear(embed_dim, out_channels) @@ -548,8 +544,8 @@ def forward(self, x, t): """ t = self.time_emb_act(t) t = self.time_emb(t) - out = self.blocks[0](x) + rearrange_dims(t) - out = self.blocks[1](out) + out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_out(out) return out + self.residual_conv(x) diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index c4fa92275fe6..c974d0a82cb6 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -77,7 +77,7 @@ def __init__( time_embedding_type: str = "fourier", flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, - downscale_freq_shift: float = 0.0, + freq_shift: float = 0.0, down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), mid_block_type: Tuple[str] = "UNetMidBlock1D", @@ -86,7 +86,7 @@ def __init__( act_fn: str = None, norm_num_groups: int = 8, layers_per_block: int = 1, - always_downsample: bool = False, + downsample_each_block: bool = False, ): super().__init__() self.sample_size = sample_size @@ -99,7 +99,7 @@ def __init__( timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=downscale_freq_shift + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift ) timestep_input_dim = block_out_channels[0] @@ -134,7 +134,7 @@ def __init__( in_channels=input_channel, out_channels=output_channel, temb_channels=block_out_channels[0], - add_downsample=not is_final_block or always_downsample, + add_downsample=not is_final_block or downsample_each_block, ) self.down_blocks.append(down_block) @@ -146,7 +146,7 @@ def __init__( out_channels=block_out_channels[-1], embed_dim=block_out_channels[0], num_layers=layers_per_block, - add_downsample=always_downsample, + add_downsample=downsample_each_block, ) # up From edc2f1c8d448e9e523da7ea16a5f3aa771b49cd4 Mon Sep 17 00:00:00 2001 From: Ben Glickenhaus Date: Mon, 14 Nov 2022 16:16:46 -0500 Subject: [PATCH 2/2] update conversion script --- .../convert_models_diffuser_to_diffusers.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py index 4b4608358c17..9475f7da93fb 100644 --- a/scripts/convert_models_diffuser_to_diffusers.py +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -29,6 +29,19 @@ def unet(hor): block_out_channels=block_out_channels, up_block_types=up_block_types, layers_per_block=1, + use_timestep_embedding=True, + out_block_type="OutConv1DBlock", + norm_num_groups=8, + downsample_each_block=False, + in_channels=14, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + flip_sin_to_cos=False, + freq_shift=1, + sample_size=65536, + mid_block_type="MidResTemporalBlock1D", + act_fn="mish", ) hf_value_function = UNet1DModel(**config) print(f"length of state dict: {len(state_dict.keys())}") @@ -52,7 +65,16 @@ def value_function(): mid_block_type="ValueFunctionMidBlock1D", block_out_channels=(32, 64, 128, 256), layers_per_block=1, - always_downsample=True, + downsample_each_block=True, + sample_size=65536, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + use_timestep_embedding=True, + flip_sin_to_cos=False, + freq_shift=1, + norm_num_groups=8, + act_fn="mish", ) model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")