From 49def20665a819ad203fa86d282af070b577bad4 Mon Sep 17 00:00:00 2001 From: Sachin Prasad Date: Thu, 5 Sep 2024 12:55:09 -0700 Subject: [PATCH] Fix timm conversion for rersnet (#1814) --- keras_nlp/src/utils/timm/convert_resnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras_nlp/src/utils/timm/convert_resnet.py b/keras_nlp/src/utils/timm/convert_resnet.py index de2224eb9..d1627e057 100644 --- a/keras_nlp/src/utils/timm/convert_resnet.py +++ b/keras_nlp/src/utils/timm/convert_resnet.py @@ -56,6 +56,8 @@ def convert_backbone_config(timm_config): stackwise_num_strides=[1, 2, 2, 2], block_type=block_type, use_pre_activation=use_pre_activation, + input_conv_filters=[64], + input_conv_kernel_sizes=[7], ) @@ -100,10 +102,10 @@ def port_batch_normalization(keras_layer_name, hf_weight_prefix): for stack_index in range(num_stacks): for block_idx in range(backbone.stackwise_num_blocks[stack_index]): if version == "v1": - keras_name = f"v1_stack{stack_index}_block{block_idx}" + keras_name = f"stack{stack_index}_block{block_idx}" hf_name = f"layer{stack_index+1}.{block_idx}" else: - keras_name = f"v2_stack{stack_index}_block{block_idx}" + keras_name = f"stack{stack_index}_block{block_idx}" hf_name = f"stages.{stack_index}.blocks.{block_idx}" if version == "v1":