
Calling :backward() more than rho times should error #400

Open
achalddave opened this issue Mar 4, 2017 · 0 comments


@achalddave
Contributor

Calling backward() more than rho times on an nn.Recurrent module leads to undocumented behavior, when it should instead be explicitly guarded against and raise an error. The following code:

  1. sets rho = 1
  2. calls forward() 4 times
  3. calls backward() 3 times

After 4 calls to forward(), sharedClones holds only 2 clones of the recurrent module (steps 3 and 4). The first 2 calls to backward() operate on these clones, but the third call creates a new clone of recurrentModule (step 2) and calls backward() on it, without any error.

```lua
require 'nn'
require 'rnn'
local _ = require 'moses'

local model = nn.Recurrent(nn.Identity(), nn.Identity(), nil, nil, 1 --[[rho]])

local input = torch.rand(1)

model:training()
print('Step 1')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2

print('Step 2')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2

print('Step 3')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2, 3

print('Step 4')
model:forward(input)
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 3, 4

print('Step 1 backward')
model:backward(input, input) -- Calls backward on module 4.
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 3, 4

print('Step 2 backward')
model:backward(input, input) -- Calls backward on module 3.
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 3, 4

print('Step 3 backward')
model:backward(input, input) -- Creates new module 2, and calls backward?!
print('Valid sharedClones:', _.keys(model.sharedClones)) -- 2, 3, 4
```
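One way to picture the requested guard is the pure-Lua sketch below. SequenceGuard and its methods are hypothetical names used only for this illustration; they are not part of the rnn package. The sketch assumes the rule from the issue title: at most rho calls to backward() per sequence (and never more than the number of forward() calls), with forget() starting a new sequence. It only counts calls, so it runs without Torch.

```lua
-- Hypothetical sketch of the guard proposed in this issue. SequenceGuard is
-- NOT part of rnn; it just counts forward/backward calls.
local SequenceGuard = {}
SequenceGuard.__index = SequenceGuard

-- rho: maximum number of time steps to backpropagate through.
function SequenceGuard.new(rho)
  assert(type(rho) == 'number' and rho >= 1, 'rho must be a number >= 1')
  return setmetatable({rho = rho, forwardSteps = 0, backwardCalls = 0}, SequenceGuard)
end

-- Called once per forward(): records that another time step was produced.
function SequenceGuard:forward()
  self.forwardSteps = self.forwardSteps + 1
end

-- Called once per backward(): raises an error instead of silently creating
-- a clone for a step older than rho. Returns the time step this call would
-- backpropagate through (most recent step first).
function SequenceGuard:backward()
  local limit = math.min(self.rho, self.forwardSteps)
  if self.backwardCalls >= limit then
    error(string.format(
      'backward() called %d times, but only %d allowed (rho = %d, forward steps = %d)',
      self.backwardCalls + 1, limit, self.rho, self.forwardSteps), 2)
  end
  self.backwardCalls = self.backwardCalls + 1
  return self.forwardSteps - self.backwardCalls + 1
end

-- Mirrors forget(): start a new sequence.
function SequenceGuard:forget()
  self.forwardSteps = 0
  self.backwardCalls = 0
end

-- Replaying the calls from the repro with rho = 1:
local guard = SequenceGuard.new(1)
for i = 1, 4 do guard:forward() end
print(guard:backward())             -- 4: the most recent step
print(pcall(guard.backward, guard)) -- false, followed by the error message
```

A real fix would presumably run a check like this at the start of nn.Recurrent's backward pass, before any new clone is requested. A more permissive variant could instead allow one backward() call per clone still held in sharedClones; the sketch follows the stricter rule in the title.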