Add `get_dtype` and `get_device_type` methods for `torch_tensor` #251
Conversation
This all looks good code-wise, @jwallwork23.
You are right that it was a tricky review!!
My overarching thought is perhaps we need an example for basic tensor manipulation.
I think that would be a separate task from this PR however. Thoughts?
Now that you explicitly set the device to create the tensor in some overloads I have a question.
What happens if we call this with tensors that are on different devices?
I presume it fails with a meaningful error message from the C++, but does it provide a useful traceback to where the error originated in the Fortran? I recall libtorch sometimes gives an error report but no code location, making it hard to work out where your Fortran is going wrong.
```cpp
const torch_device_t get_ftorch_device(torch::DeviceType device_type) {
  switch (device_type) {
  case torch::kCPU:
    return torch_kCPU;
  case torch::kCUDA:
    return torch_kCUDA;
  default:
    std::cerr << "[ERROR]: device type " << device_type
              << " not implemented in FTorch" << std::endl;
    exit(EXIT_FAILURE);
  }
}
```
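The commit list below also mentions an analogous `get_ftorch_dtype` helper. A minimal standalone sketch of that same mapping pattern is given here; the enum values and `ScalarType` type are stand-ins so the sketch compiles without libtorch or FTorch, and the real definitions (FTorch's dtype enum and libtorch's `c10::ScalarType`) differ:

```cpp
#include <cstdlib>
#include <iostream>

// Stand-ins for the real definitions (assumptions, not FTorch's actual enums).
enum torch_data_t { torch_kFloat32, torch_kFloat64, torch_kInt32 };
enum class ScalarType { Float, Double, Int };

// Map a libtorch scalar type onto the FTorch enum, mirroring the
// switch/exit pattern of get_ftorch_device above.
const torch_data_t get_ftorch_dtype(ScalarType dtype) {
  switch (dtype) {
  case ScalarType::Float:
    return torch_kFloat32;
  case ScalarType::Double:
    return torch_kFloat64;
  case ScalarType::Int:
    return torch_kInt32;
  default:
    std::cerr << "[ERROR]: data type not implemented in FTorch" << std::endl;
    exit(EXIT_FAILURE);
  }
}
```

As with `get_ftorch_device`, unhandled types fail loudly at the C++ boundary rather than propagating an invalid enum value into the Fortran layer.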
Note: This will need extending in #209
Do you think it would be worth adding a test for the other functions running on a CUDA device (get rank etc)?
On thinking about this: if they do not map over from CPU to other devices, that is likely an issue with the backends rather than with our code, so such tests would only serve as a warning that there were problems with the underlying dependencies.
Such tests would be duplicating those in the CPU tests and then we'd need to do the same for XPU, etc. I think we can lean on the underlying implementation for this, but can add tests if preferred.
Nah, I think I agree, and am keen to avoid the verbosity until proven otherwise.
Just wondered what your thoughts were. :)
Apologies!
That's a good idea. I'm increasingly thinking we need to rethink the ordering of the examples. (See #258 (comment).) Such an example should be near the start.
Will check, thanks for pointing out this case.
Cool, opened #261.
@jatkinson1000 hm, the error isn't so helpful (in fact there isn't even one). Making the modifications in the last commit on `248_get-dtype-devicetype_GPU-test` - which attempts to assign a tensor on a CUDA device to a tensor on the CPU - I get the output

That is, it doesn't raise an error at all. So I guess we should build in errors for when you try to apply operator overloads to tensors on different devices.
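One way such a guard could look (a hedged sketch only: `check_same_device`, the `Tensor` struct, and the enum here are hypothetical stand-ins, not FTorch's actual representation) is to compare the operands' device types at the top of each operator overload and fail loudly on a mismatch instead of silently doing nothing:

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-ins; FTorch's real enum and tensor type differ.
enum torch_device_t { torch_kCPU, torch_kCUDA };

struct Tensor {
  torch_device_t device;
  // ... data, shape, etc. elided ...
};

// Guard applied before any work: refuse to mix devices instead of
// failing silently somewhere deep inside the backend.
void check_same_device(const Tensor &a, const Tensor &b) {
  if (a.device != b.device) {
    throw std::runtime_error("[ERROR]: operands are on different devices (" +
                             std::to_string(a.device) + " vs " +
                             std::to_string(b.device) + ")");
  }
}

Tensor operator+(const Tensor &a, const Tensor &b) {
  check_same_device(a, b); // meaningful error before any computation
  Tensor result{a.device};
  // ... element-wise addition elided ...
  return result;
}
```

In FTorch itself the equivalent check would sit in the C++ wrapper, where the device types of both tensors are already available via the new getter.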
Opened #269. Will merge and follow up there.
* Add dtype and device_type attrs for torch_tensor; implement getters
* Rename get_<rank/shape> as torch_tensor_get_<rank/shape> for consistency
* Make torch_tensor_get_device_index a class method
* Add unit test for torch_tensor_get_device_type on CPU
* Add unit test for torch_tensor_get_device_type on CUDA device
* Add unit test for torch_tensor_get_dtype
* Make use of getters for device type and index
* Alias methods to be less verbose
* Implement get_device_type on C++ side; introduce get_ftorch_device
* Implement get_dtype on C++ side; introduce get_ftorch_dtype
* Drop dtype/device type attributes
Closes #248.
This PR adds functions for getting the data type and device type of a tensor, including unit tests. It also improves the consistency of the existing functions so that they are called `torch_tensor_get_X` but are mapped to methods of the `torch_tensor` class as just `get_X`. This makes it clearer what we are getting the rank/shape of when we call them as functions, and reduces unnecessary verbosity when calling them as methods.

Using the new utility for getting the device type, the operator overloads involving scalars are corrected so that CPU isn't assumed.