apply_to_collection doesn't work for cached properties #279

Open
jackdent opened this issue Jul 3, 2024 · 3 comments
Labels
enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

jackdent commented Jul 3, 2024

Motivation

When running apply_to_collection on a dataclass, cached properties are not modified. This can cause subtle issues: for example, suppose I initialize a dataclass on CPU in a data worker and then move it onto GPU for a model batch. All of the dataclass fields that contain Tensors get moved correctly, but any values already cached by a cached_property continue to reside on the original device.

Steps to reproduce

import dataclasses
from functools import cached_property

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor


@dataclasses.dataclass
class Data:
    a: Tensor

    @cached_property
    def b(self):
        print("*" * 10)
        print("Computing and cache prop b")
        print("*" * 10)
        return self.a * 2


print("*" * 10)
print("Start with data on GPU")
print("*" * 10)

data = Data(a=torch.tensor([1, 2, 3], device="cuda"))
print(f"{data.a=}")
print(f"{data.a.device=}")

print(f"{data.b=}")
print(f"{data.b=}")  # do this a second time to make sure we're caching it
print(f"{data.b.device=}")

print("*" * 10)
print("Move Data to CPU")
print("*" * 10)

new_data = apply_to_collection(data, Tensor, lambda x: x.to("cpu"))
print(f"{new_data.a=}")
print(f"{new_data.a.device=}")

print(f"{new_data.b=}")
print(f"{new_data.b=}")  # do this a second time to make sure we're caching it
print(f"{new_data.b.device=}")

Yields the following output:

**********
Start with data on GPU
**********
data.a=tensor([1, 2, 3], device='cuda:0')
data.a.device=device(type='cuda', index=0)
**********
Computing and cache prop b
**********
data.b=tensor([2, 4, 6], device='cuda:0')
data.b=tensor([2, 4, 6], device='cuda:0')
data.b.device=device(type='cuda', index=0)
**********
Move Data to CPU
**********
new_data.a=tensor([1, 2, 3])
new_data.a.device=device(type='cpu')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b=tensor([2, 4, 6], device='cuda:0')
new_data.b.device=device(type='cuda', index=0)
jackdent added the enhancement and help wanted labels Jul 3, 2024
Author

jackdent commented Jul 3, 2024

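The Lightning apply_to_collection logic is defined here and relies on dataclasses.fields, which doesn't include cached properties. A cached_property stores its computed value in the instance's __dict__ rather than in a declared field, so the function is never applied to it.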
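
To illustrate the point above, here is a minimal sketch using only the standard library. The Point class and its doubled property are hypothetical names chosen for this example; they stand in for the Data class from the reproduction and do not require torch.

```python
import dataclasses
from functools import cached_property


@dataclasses.dataclass
class Point:
    x: int

    @cached_property
    def doubled(self) -> int:
        return self.x * 2


p = Point(x=3)

# dataclasses.fields only reports declared fields, so "doubled" is absent.
field_names = [f.name for f in dataclasses.fields(p)]

# Before first access the cache is empty; afterwards the computed value
# is stored in the instance __dict__ under the property's name.
before = dict(vars(p))
_ = p.doubled
after = dict(vars(p))

print(field_names)
print(before)
print(after)
```

Anything that iterates over dataclasses.fields will therefore see x but never the cached doubled value.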

Member

Borda commented Jul 13, 2024

@awaelchli, do you have any experience with this one?

Borda changed the title from "apply_to_collection doesn't work for cached properties" to "apply_to_collection doesn't work for cached properties" Jul 13, 2024
@awaelchli
Contributor

Hey @jackdent
This is a rare use case and I won't have the bandwidth to look into it. We would be grateful for a contribution here if you're interested. The fix is probably to just reset the cache when running apply_to_collection.
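
A rough sketch of what "reset the cache" could look like, under some assumptions: reset_cached_properties is a hypothetical helper, not part of lightning_utilities; plain lists stand in for Tensors; and copy.deepcopy stands in for whatever copying apply_to_collection performs on the dataclass.

```python
import copy
import dataclasses
from functools import cached_property


def reset_cached_properties(obj):
    """Hypothetical helper: drop cached_property values stored on obj."""
    for klass in type(obj).__mro__:
        for name, attr in vars(klass).items():
            if isinstance(attr, cached_property):
                # The cached value lives in the instance __dict__ under the
                # property's name; removing it forces recomputation.
                obj.__dict__.pop(name, None)


@dataclasses.dataclass
class Data:
    a: list

    @cached_property
    def b(self):
        return [x * 2 for x in self.a]


data = Data(a=[1, 2, 3])
_ = data.b                       # populate the cache on the original

new_data = copy.deepcopy(data)   # the cached value is copied along
new_data.a = [10, 20, 30]        # stand-in for applying a function to field a
stale = new_data.b               # still the value computed from the old a

reset_cached_properties(new_data)
fresh = new_data.b               # recomputed from the new a
```

Recomputing on next access keeps the cached value consistent with the transformed fields, at the cost of repeating the computation. The original object's cache is left untouched.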
