
Update KJT stride calculation logic to be based off of inverse_indices for VBE KJTs. #2949

Open
wants to merge 2 commits into
base: main
from

Conversation

jd7-tr
Contributor

@jd7-tr jd7-tr commented May 6, 2025

Summary:
Update the `_maybe_compute_stride_kjt` logic to calculate the stride from `inverse_indices` for VBE KJTs.

Currently, the stride of a VBE KJT with `stride_per_key_per_rank` is calculated as the maximum "stride per key". This differs from the batch size of the EBC output KeyedTensor, which is derived from `inverse_indices`. The mismatch causes issues in IR module serialization: debug doc.
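To make the mismatch concrete, here is a minimal sketch in plain Python. It is not the TorchRec implementation: both helper names are hypothetical, and it assumes that the "stride per key" is the sum of that key's per-rank strides and that `inverse_indices` holds one row per key, each row as long as the full batch.

```python
from typing import List


def stride_from_max_per_key(stride_per_key_per_rank: List[List[int]]) -> int:
    """Hypothetical helper: the current behavior described above.

    Assumes each inner list holds one key's strides across ranks, so the
    "stride per key" is the sum of that list.
    """
    return max((sum(per_rank) for per_rank in stride_per_key_per_rank), default=0)


def stride_from_inverse_indices(inverse_indices: List[List[int]]) -> int:
    """Hypothetical helper: the proposed behavior.

    Assumes one row per key, each row mapping every sample of the full
    batch to a deduplicated sample for that key.
    """
    return len(inverse_indices[0]) if inverse_indices else 0


# Two keys, one rank: key 0 keeps 2 deduplicated samples, key 1 keeps 1.
stride_per_key_per_rank = [[2], [1]]
# The full batch has 3 samples; each row maps them back to deduplicated ones.
inverse_indices = [[0, 1, 0], [0, 0, 0]]

old_stride = stride_from_max_per_key(stride_per_key_per_rank)
new_stride = stride_from_inverse_indices(inverse_indices)
```

Under these assumptions the two values differ (2 versus 3), which is the kind of disagreement with the EBC output batch size that the summary describes.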

Differential Revision: D74273083

@facebook-github-bot facebook-github-bot added the CLA Signed label May 6, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D74273083

jd7-tr added a commit to jd7-tr/torchrec that referenced this pull request May 6, 2025
jd7-tr added a commit to jd7-tr/torchrec that referenced this pull request May 7, 2025
@jd7-tr jd7-tr force-pushed the export-D74273083 branch from b9ebfbb to 3304438 Compare May 7, 2025 00:29
jd7-tr added a commit to jd7-tr/torchrec that referenced this pull request May 7, 2025
jd7-tr added a commit to jd7-tr/torchrec that referenced this pull request May 7, 2025
TroyGarden and others added 2 commits May 9, 2025 10:15
Summary:
Pull Request resolved: pytorch#2959

# context
* this diff is part of the "variable-batch KJT refactoring" project ([doc](https://fburl.com/gdoc/svfysfai))
* previously, the `stride_per_key_per_rank` variable was `List[List[int]] | None`, which can't be handled correctly in PT2 IR (torch.export)
* this change makes the KJT class variable `_stride_per_key_per_rank` a `torch.IntTensor | None` so that it is compatible with PT2 IR.

# equivalency
* checking whether `self._stride_per_key_per_rank` is `None`:
this check is used to identify the variable-batch case, and should behave the same after this diff
* using `self._stride_per_key_per_rank` as `List[List[int]]`:
most call sites get the list through the accessor `def stride_per_key_per_rank(self) -> List[List[int]]:`, which is modified to convert the `torch.IntTensor` to a list via `_stride_per_key_per_rank.tolist()`, so the results should be the same

NOTE: the `self._stride_per_key_per_rank` tensor should always be on CPU, since it is effectively the metadata of a KJT. Generic torch APIs like `.to(...)`, `record_stream()`, etc. should in general avoid altering this variable.
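The equivalency argument above can be illustrated with a small sketch. `StrideHolder` is a hypothetical stand-in, not the KJT class; returning an empty list when the field is `None` is an assumption made for this illustration.

```python
from typing import List, Optional

import torch


class StrideHolder:
    """Hypothetical stand-in for the KJT field discussed above."""

    def __init__(
        self, stride_per_key_per_rank: Optional[List[List[int]]] = None
    ) -> None:
        # Store the metadata as a CPU int32 tensor, or None when the
        # batch is not variable.
        self._stride_per_key_per_rank: Optional[torch.Tensor] = (
            torch.tensor(stride_per_key_per_rank, dtype=torch.int32, device="cpu")
            if stride_per_key_per_rank is not None
            else None
        )

    def variable_stride_per_key(self) -> bool:
        # The None check still identifies the variable-batch case.
        return self._stride_per_key_per_rank is not None

    def stride_per_key_per_rank(self) -> List[List[int]]:
        # Call sites keep receiving a nested Python list.
        if self._stride_per_key_per_rank is None:
            return []  # assumption: empty list for the non-variable case
        return self._stride_per_key_per_rank.tolist()


vbe = StrideHolder([[2, 1], [3, 0]])
regular = StrideHolder()
```

Here `vbe.stride_per_key_per_rank()` round-trips to the original nested list of Python ints, while the `None` check keeps distinguishing the two cases.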

Differential Revision: D74366343

Reviewed By: jd7-tr
…s for VBE KJTs. (pytorch#2949)

Summary:
Pull Request resolved: pytorch#2949

Update the `_maybe_compute_stride_kjt` logic to calculate stride based off of `inverse_indices` for VBE KJTs.

Currently, stride of VBE KJT with `stride_per_key_per_rank` is calculated as the max "stride per key". This is different from the batch size of the EBC output KeyedTensor which is based off of inverse_indices. This causes issues in IR module serialization: debug doc.

Differential Revision: D74273083
@jd7-tr jd7-tr force-pushed the export-D74273083 branch from 3304438 to bce1d55 Compare May 13, 2025 17:21
Labels
CLA Signed, fb-exported
3 participants