Allow kwargs in TimmUniversalEncoder #954

Closed
DimitrisMantas opened this issue Oct 19, 2024 · 5 comments · Fixed by #960
Comments

@DimitrisMantas
Contributor

In my opinion, the most attractive aspect of certain timm encoders, such as ResNet-18, which is also available in torchvision, is that timm generally allows various additional configuration parameters to be passed to the model constructor, such as anti-aliasing, attention, and stochastic depth.

However, smp does not support this feature at the moment: TimmUniversalEncoder builds a local dict of kwargs that it passes to timm and does not accept any others in its initializer.

A very easy fix would be to let the initializer accept its own kwargs and merge them into that local dict before making any calls to timm.
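As a sketch of the merge being proposed here (the names below are illustrative, not the actual smp implementation):

```python
def merge_kwargs(defaults, overrides):
    """Combine the encoder's built-in timm kwargs with user-supplied ones.

    User-supplied values take precedence on key collisions.
    """
    merged = dict(defaults)
    merged.update(overrides)
    return merged


# Hypothetical example: the encoder's local defaults plus user kwargs
defaults = {"features_only": True, "output_stride": 32}
user = {"aa_layer": "BlurPool2d", "output_stride": 16}
print(merge_kwargs(defaults, user))
# {'features_only': True, 'output_stride': 16, 'aa_layer': 'BlurPool2d'}
```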

@JulienMaille
Contributor

Sorry for the OT comment but may I ask how you add anti-aliasing to ResNet through timm?

@DimitrisMantas
Contributor Author

The ResNet constructor accepts an aa_layer argument (https://github.com/huggingface/pytorch-image-models/blob/310ffa32c5758474b0a4481e5db1494dd419aa23/timm/models/resnet.py#L405), which you can set to timm.layers.BlurPool2d (https://github.com/huggingface/pytorch-image-models/blob/310ffa32c5758474b0a4481e5db1494dd419aa23/timm/layers/blur_pool.py#L20).

@DimitrisMantas
Contributor Author

This is what I'm suggesting. The super call should stay as is, I think, and the local kwargs should be merged with the new kwargs, or passed individually, to timm.create_model.

@JulienMaille
Contributor

Have you tried something like this?

import timm
import torch.nn as nn


class TimmUniversalEncoder(nn.Module):
    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32, **kwargs):
        super().__init__()

        # Initialize default kwargs
        default_kwargs = dict(
            in_chans=in_channels,
            features_only=True,
            output_stride=output_stride,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
        )

        # Update with any user-provided kwargs (user values take precedence)
        default_kwargs.update(kwargs)

        # Pass the merged kwargs through to timm
        self.model = timm.create_model(name, **default_kwargs)

@DimitrisMantas
Contributor Author

Not yet, but it should work
