Skip to content

Commit 1826dce

Browse files
sachinprasadhsmattdangerw
authored andcommitted
Fix timm conversion for rersnet (#1814)
1 parent eaff91b commit 1826dce

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

keras_nlp/src/utils/timm/convert_resnet.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def convert_backbone_config(timm_config):
5555
stackwise_num_strides=[1, 2, 2, 2],
5656
block_type=block_type,
5757
use_pre_activation=use_pre_activation,
58+
input_conv_filters=[64],
59+
input_conv_kernel_sizes=[7],
5860
)
5961

6062

@@ -99,10 +101,10 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix):
99101
for stack_index in range(num_stacks):
100102
for block_idx in range(backbone.stackwise_num_blocks[stack_index]):
101103
if version == "v1":
102-
keras_name = f"v1_stack{stack_index}_block{block_idx}"
104+
keras_name = f"stack{stack_index}_block{block_idx}"
103105
hf_name = f"layer{stack_index+1}.{block_idx}"
104106
else:
105-
keras_name = f"v2_stack{stack_index}_block{block_idx}"
107+
keras_name = f"stack{stack_index}_block{block_idx}"
106108
hf_name = f"stages.{stack_index}.blocks.{block_idx}"
107109

108110
if version == "v1":

0 commit comments

Comments
 (0)