Allow nested batch_arg_name in BatchSizeFinder/Tuner.scale_batch_size() #20560

Open
ibro45 opened this issue Jan 23, 2025 · 1 comment
Labels: feature (Is an improvement or enhancement), needs triage (Waiting to be triaged by maintainers)

Comments

ibro45 commented Jan 23, 2025

Description & Motivation

Hi,

Would it be possible to allow dot notation in `batch_arg_name`, for example `tuner.scale_batch_size(model, batch_arg_name="dataloaders.train.batch_size")`?

Other features might benefit from this as well. It should be simple to implement through `lightning_hasattr` and `lightning_getattr`.
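A minimal sketch of how a dotted `batch_arg_name` could be resolved, assuming the path is walked one segment at a time with `getattr` and only the final segment is assigned with `setattr`. The helpers `nested_getattr`, `nested_hasattr`, and `nested_setattr` are hypothetical and not part of Lightning's API; the real `lightning_hasattr`/`lightning_getattr` also look in `hparams` and the datamodule, which this sketch ignores.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any


def nested_getattr(obj: Any, path: str) -> Any:
    """Follow a dot-delimited path such as "dataloaders.train.batch_size"."""
    # reduce applies getattr once per path segment, left to right
    return reduce(getattr, path.split("."), obj)


def nested_hasattr(obj: Any, path: str) -> bool:
    """Return True only if every segment of the dotted path exists."""
    try:
        nested_getattr(obj, path)
    except AttributeError:
        return False
    return True


def nested_setattr(obj: Any, path: str, value: Any) -> None:
    """Set the last segment of the dotted path on its parent object."""
    *parents, leaf = path.split(".")
    # Walk to the parent of the leaf (obj itself if the path has no dot)
    parent = reduce(getattr, parents, obj)
    setattr(parent, leaf, value)


# Hypothetical stand-in for a model holding nested dataloader config
model = SimpleNamespace(
    dataloaders=SimpleNamespace(train=SimpleNamespace(batch_size=32))
)
print(nested_hasattr(model, "dataloaders.train.batch_size"))  # True
nested_setattr(model, "dataloaders.train.batch_size", 64)
print(nested_getattr(model, "dataloaders.train.batch_size"))  # 64
print(nested_hasattr(model, "dataloaders.val.batch_size"))  # False
```

Splitting the lookup into "walk to the parent, then set the leaf" means the same path logic serves `hasattr`, `getattr`, and `setattr`.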

Pitch

No response

Alternatives

No response

Additional context

No response

cc @lantiga @Borda

ibro45 added the feature and needs triage labels on Jan 23, 2025
ibro45 (author) commented Jan 24, 2025

As a workaround, I added the following overrides to my `Model(LightningModule)` class to enable that behavior, but I don't think this is ideal:

```python
from typing import Any

from lightning.pytorch import LightningModule  # or: from pytorch_lightning import LightningModule


class Model(LightningModule):
    # ... existing model code ...

    def __getattr__(self, name: str) -> Any:
        """
        Override __getattr__ to handle dot-delimited attribute access.

        Args:
            name (str): The attribute name, potentially dot-delimited.

        Returns:
            Any: The value of the nested attribute.

        Raises:
            AttributeError: If the attribute path does not exist.
        """
        if '.' in name:
            try:
                attrs = name.split('.')
                obj = self
                for attr in attrs:
                    obj = getattr(obj, attr)
                return obj
            except AttributeError as e:
                raise AttributeError(f"Attribute path '{name}' not found.") from e
        # If no dot, defer to the default behavior
        return super().__getattr__(name)

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Override __setattr__ to handle dot-delimited attribute setting.

        Args:
            name (str): The attribute name, potentially dot-delimited.
            value (Any): The value to set.

        Raises:
            AttributeError: If the attribute path does not exist.
        """
        if '.' in name:
            attrs = name.split('.')
            obj = self
            for attr in attrs[:-1]:
                try:
                    obj = getattr(obj, attr)
                except AttributeError as e:
                    raise AttributeError(f"Attribute path '{'.'.join(attrs[:-1])}' not found.") from e
            setattr(obj, attrs[-1], value)
        else:
            # Use the default behavior for single attributes
            super().__setattr__(name, value)
```
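To exercise the override logic without installing PyTorch, here is a sketch that copies the two methods onto a plain Python class. `NestedAttrs` and `PlainModel` are hypothetical names. Because `object` has no `__getattr__`, the non-dotted fallback raises `AttributeError` directly instead of deferring to `nn.Module.__getattr__` as the `LightningModule` version does.

```python
from types import SimpleNamespace
from typing import Any


class NestedAttrs:
    """Hypothetical mixin mirroring the dot-delimited __getattr__/__setattr__ overrides."""

    def __getattr__(self, name: str) -> Any:
        # Only called after normal attribute lookup has already failed
        if "." in name:
            obj: Any = self
            try:
                for attr in name.split("."):
                    obj = getattr(obj, attr)
            except AttributeError as e:
                raise AttributeError(f"Attribute path '{name}' not found.") from e
            return obj
        # object has no __getattr__ to defer to, so fail like normal lookup
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        if "." in name:
            *parents, leaf = name.split(".")
            obj: Any = self
            try:
                for attr in parents:
                    obj = getattr(obj, attr)
            except AttributeError as e:
                raise AttributeError(f"Attribute path '{'.'.join(parents)}' not found.") from e
            setattr(obj, leaf, value)
        else:
            super().__setattr__(name, value)


class PlainModel(NestedAttrs):
    def __init__(self) -> None:
        self.dataloaders = SimpleNamespace(train=SimpleNamespace(batch_size=32))


model = PlainModel()
# hasattr/getattr/setattr with a dotted string are routed through the overrides
print(hasattr(model, "dataloaders.train.batch_size"))  # True
setattr(model, "dataloaders.train.batch_size", 128)
print(model.dataloaders.train.batch_size)  # 128
print(hasattr(model, "dataloaders.val.batch_size"))  # False
```

Since `__getattr__` only runs after normal lookup fails, ordinary attribute access is unaffected; `__setattr__`, by contrast, runs on every assignment, so the `"." in name` check adds a small cost to each one.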
