shuffle stacked loss #3

Open
cfifty opened this issue Jun 11, 2020 · 4 comments

Comments

cfifty commented Jun 11, 2020

Consider replacing

tf.random.shuffle(loss)

with

loss = tf.gather(loss, tf.random.shuffle(tf.range(tf.shape(loss)[0])))

luzai commented Aug 16, 2020

Hi @cfifty, may I ask why not replace it with loss=tf.random.shuffle(loss) instead?

cfifty commented Aug 16, 2020

  1. In non-eager (graph) mode, tf.random.shuffle(loss) is never executed, because nothing consumes its output, so the loss list is not shuffled.

  2. If you use loss=tf.random.shuffle(loss), the backward pass of tf.random.shuffle is not defined. You therefore can't compute gradients through this operation, and an error is thrown. See https://stackoverflow.com/questions/55701407/how-to-shuffle-tensor-in-tensorflow-errorno-gradient-defined-for-operation-ra for additional context.

luzai commented Aug 16, 2020

Thank you very much for your detailed explanation!

  1. The loss list is not shuffled when using tf.random.shuffle(loss). I think the reason is that tf.random.shuffle is not an in-place operation, so the input argument loss is left unchanged.
  2. It seems loss=tf.random.shuffle(loss) does not throw an error with tf 1.15.3. Maybe the gradient for this operation is registered in newer versions.
    Overall, I think loss = tf.gather(loss, tf.random.shuffle(tf.range(tf.shape(loss)[0]))) is the better choice for compatibility.

luzai commented Aug 19, 2020

I am sorry, I made a mistake: the gradient is still not defined for loss=tf.random.shuffle(loss) in tf 1.15.3.
image

We should consider using loss = tf.gather(loss, tf.random.shuffle(tf.range(tf.shape(loss)[0]))) instead.
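
For reference, here is a minimal sketch of the suggested replacement. It assumes TensorFlow 2.x with eager execution, which is not the setup discussed in this thread (tf 1.15.3). The loss values and the weights vector are hypothetical placeholders, not taken from the repository. The idea is to shuffle integer indices, which need no gradient, and reorder the losses with tf.gather, which has one.

```python
import tensorflow as tf

# Hypothetical stand-in for the stacked per-task losses discussed above.
loss = tf.constant([1.0, 2.0, 3.0, 4.0])
# Hypothetical position-dependent weights, so that the order matters.
weights = tf.constant([1.0, 10.0, 100.0, 1000.0])

with tf.GradientTape() as tape:
    tape.watch(loss)
    # Shuffle integer indices; no gradient is required for them.
    idx = tf.random.shuffle(tf.range(tf.shape(loss)[0]))
    # Reorder the losses with tf.gather, which does define a gradient.
    shuffled = tf.gather(loss, idx)
    total = tf.reduce_sum(shuffled * weights)

# The gradient of a gather may come back as tf.IndexedSlices; densify it.
grad = tf.convert_to_tensor(tape.gradient(total, loss))

print("indices: ", idx.numpy())
print("shuffled:", shuffled.numpy())
print("gradient:", grad.numpy())
```

Because the result is assigned to a new tensor, the original loss is left unchanged, and the gradient that reaches each entry of loss is the weight at the position it was shuffled to.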
