Implement all the arithmetic Scatter and ResourceScatter operators #121
@@ -304,3 +304,57 @@ const DML_BUFFER_BINDING* DmlKernel::GetPersistentResourceBinding() const {
}

}  // namespace tensorflow

namespace dml {
DML_SCALAR_UNION ScalarTensor(double value, DML_TENSOR_DATA_TYPE data_type) {
nit: rename the helper function to ScalarUnion, since it doesn't return a tensor.
scalar.UInt64 = static_cast<uint64_t>(value);
break;

case DML_TENSOR_DATA_TYPE_FLOAT32:
Add DML_TENSOR_DATA_TYPE_FLOAT16?
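The conversion pattern under review can be sketched as follows. This is a minimal, self-contained illustration using simplified stand-ins for `DML_SCALAR_UNION` and `DML_TENSOR_DATA_TYPE` (the real types live in DirectML.h); the `MakeScalarUnion` name follows the reviewer's suggestion and is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-ins for the DirectML types; the real definitions are in
// DirectML.h. Used here only to show the double-to-union conversion pattern.
enum TensorDataType { kFloat32, kFloat64, kUInt32, kUInt64, kInt32 };

union ScalarUnion {
  float Float32;
  double Float64;
  uint32_t UInt32;
  uint64_t UInt64;
  int32_t Int32;
};

// Converts a double into the union member that matches the tensor's data
// type. Named per the review suggestion: it returns a union, not a tensor.
ScalarUnion MakeScalarUnion(double value, TensorDataType type) {
  ScalarUnion scalar = {};
  switch (type) {
    case kFloat32:
      scalar.Float32 = static_cast<float>(value);
      break;
    case kFloat64:
      scalar.Float64 = value;
      break;
    case kUInt32:
      scalar.UInt32 = static_cast<uint32_t>(value);
      break;
    case kUInt64:
      scalar.UInt64 = static_cast<uint64_t>(value);
      break;
    case kInt32:
      scalar.Int32 = static_cast<int32_t>(value);
      break;
  }
  return scalar;
}
```

A FLOAT16 case, as the reviewer asks, would need an explicit half-precision conversion since C++ has no built-in `float16` type.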
//
// TODO: ScatterElements (DML_SCATTER_ELEMENTS_OPERATOR_DESC)
//
inline Expression ScatterElements(
This reminds me that we should transition to the copy of DirectMLX.h in the DirectML github repo. I created bug 31071585 for this. In the meantime, make sure to port this change over to GitHub.
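For context on what the TODO'd operator does, here is a hedged reference sketch of ScatterElements semantics along axis 0 for a 2-D row-major tensor, in the ONNX style that DirectML's scatter descs describe: the output starts as a copy of the data, then `output[indices[i][j]][j] = updates[i][j]`. The helper name and flattened-buffer layout are illustrative, not the DirectMLX implementation.

```cpp
#include <cstddef>
#include <vector>

// Reference semantics for ScatterElements along axis 0 on a 2-D tensor
// stored row-major in a flat buffer. `data` is rows x cols; `indices` and
// `updates` are update_rows x cols. Hypothetical illustration only.
std::vector<float> ScatterElementsAxis0(const std::vector<float>& data,
                                        const std::vector<int>& indices,
                                        const std::vector<float>& updates,
                                        std::size_t cols,
                                        std::size_t update_rows) {
  std::vector<float> output = data;  // output begins as a copy of data
  for (std::size_t i = 0; i < update_rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      // Scatter the update into the row named by indices[i][j], same column.
      output[indices[i * cols + j] * cols + j] = updates[i * cols + j];
    }
  }
  return output;
}
```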
Merges some of the recent changes from the directml branch:

* Use compute queue for AMD devices (#102)
* Register List Kernels for DML (#95)
* Update DirectMLX to latest (#104)
* Remove extra rows from test email (#106)
* Fix DML's Select kernel for int64 (#113)
* Fix list kernels and tensor array ops registration (#114)
* Simplify CI scripts (#112)
* Fix StridedSlice's input size coalescing (#115)
* Disable int64 image test (#116)
* Fix network share copy path (#117)
* Pipeline should continue if a test job fails (#118)
* Switch network share path to use build number instead of build ID
* Add missing HostMemory int32 registrations for _Arg and _RetVal (#122)
* Implement all the arithmetic Scatter and ResourceScatter operators (#121)
* Register emulated kernel implementations for RandomStandardNormal and TruncatedNormal (#120)
The arithmetic Scatter and ResourceScatter operators are used heavily in the TensorFlow Python tests, so their absence creates many false-positive failures. Implementing these ops eliminates about 50% of the failures and should improve performance on some of the models we track.
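The semantics the PR implements can be summarized with a small reference loop. A minimal sketch for a 1-D variable, mirroring `tf.scatter_add`: the other arithmetic variants (sub, mul, div, min, max) differ only in the operator applied. The helper name is illustrative; this is not the DML kernel itself.

```cpp
#include <cstddef>
#include <vector>

// Reference semantics of an arithmetic scatter on a 1-D variable:
// for each i, params[indices[i]] += updates[i]. Duplicate indices
// accumulate, which is what distinguishes scatter_add from plain assignment.
void ScatterAdd(std::vector<float>& params,
                const std::vector<int>& indices,
                const std::vector<float>& updates) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    params[indices[i]] += updates[i];
  }
}
```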