
Training models with torch.Tensor input #2736

Open
j-bac opened this issue Apr 22, 2024 · 4 comments
j-bac commented Apr 22, 2024

Is your feature request related to a problem? Please describe.
It is not currently straightforward to pass external dataloaders when training a model. In particular, loading torch.Tensor data and feeding it directly to a model as input doesn't seem possible, because scvi.data._utils._check_nonnegative_integers does not handle torch.Tensor.

It would be very useful to be able to feed a custom dataloader, dictionary or AnnData as direct input to model.train() without having to copy torch.Tensor back to numpy or pandas. Maybe this can be implemented using model.train(data_module=data_module) ?
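For illustration, the validation could be made tensor-agnostic along these lines. This is only a hypothetical sketch of what such a check might look like; the real `scvi.data._utils._check_nonnegative_integers` is an internal function and may be implemented differently:

```python
import numpy as np

# Hypothetical tensor-agnostic variant of a nonnegative-integer check.
# np.asarray accepts anything exposing the array protocol; torch CPU
# tensors (without requires_grad) convert without copying.
def check_nonnegative_integers(data) -> bool:
    arr = np.asarray(data)
    if arr.size == 0:
        return True
    # Nonnegative, and equal to its integer-truncated self
    return bool((arr >= 0).all() and np.array_equal(arr, arr.astype(np.int64)))

print(check_nonnegative_integers(np.array([0, 3, 7])))   # True
print(check_nonnegative_integers(np.array([-1, 2])))     # False
print(check_nonnegative_integers(np.array([0.5, 2.0])))  # False
```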

Describe the solution you'd like

import scipy.sparse
import torch
import scanpy as sc
import scvi

counts = torch.randint(0, 10, (500, 10))

# AnnData does not allow torch.Tensor in the .X field, so use an empty
# sparse placeholder and keep the tensor in a layer
adata = sc.AnnData(scipy.sparse.csr_matrix(counts.shape),
                   layers={"counts": counts})

scvi.model.SCVI.setup_anndata(adata, layer="counts")
model = scvi.model.SCVI(adata)
model.train()
canergen (Member) commented:

We cover the enhancement to use custom dataloaders in the recent version of scvi-tools.
However, it is not yet clear which minimal checks (integer values, gene names) we still want to perform.
About your example: @Intron7, is this idea of keeping AnnData in torch recommended? What analysis capabilities are possible in this scenario? I thought this was meant to be done in rapids_singlecell. Does rapids copy back and forth between CPU and GPU, or is the full data kept on the GPU between processing steps?

j-bac (Author) commented Apr 22, 2024

Thanks! Is there already a link to an example usage of this new version?
AFAIK rapids_singlecell keeps matrices on the GPU without copying back and forth (it uses CuPy rather than torch): https://rapids-singlecell.readthedocs.io/en/latest/Usage_Principles.html

Intron7 (Member) commented Apr 22, 2024

We are still discussing how this would work. At the moment, whenever I use rsc I have to transfer the data back to the CPU and then use scvi. rapids-singlecell really wants .X and .layers on the GPU, so everything has to be in memory. I would really like us to use DLPack for this: DLPack allows zero-copy conversion from CuPy and JAX to torch.
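For context, DLPack's zero-copy exchange can be sketched on the CPU with NumPy, which implements the same protocol that `torch.from_dlpack` uses for CuPy and JAX arrays (requires NumPy >= 1.22):

```python
import numpy as np

# DLPack zero-copy exchange: the consumer wraps the producer's buffer
# instead of copying it. np.from_dlpack consumes any object exposing
# __dlpack__; torch.from_dlpack works the same way for CuPy/JAX arrays.
a = np.arange(6, dtype=np.float32)
b = np.from_dlpack(a)  # shares memory with `a`; no copy is made

a[0] = 42.0
print(b[0])  # 42.0 -- the change is visible because the buffer is shared
```

The same pattern on the GPU would let rapids-singlecell hand CuPy-backed layers to torch without leaving device memory.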

martinkim0 (Contributor) commented:

Hi @j-bac, thanks for the suggestion. We will be releasing a tutorial with our next release (v1.2) that covers a basic use case with a custom dataloader. I'll note that we don't support inference methods yet (e.g. get_latent_representation), but it's something we're working on.

@martinkim0 martinkim0 added the P1 label Jul 12, 2024
@martinkim0 martinkim0 added this to the scvi-tools 1.2 milestone Jul 12, 2024
@martinkim0 martinkim0 self-assigned this Jul 12, 2024
@canergen canergen assigned ori-kron-wis and unassigned martinkim0 Jul 26, 2024