fixed the gamma computation error #79
Conversation
Can one of the admins verify this patch?
Is the only code modification from:
to:
I'm looking at the function signature of grad:
The change is very subtle. What are the components of the result? It looks good, but it would be nice to clarify that PyTorch operation.
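For reference, a minimal sketch of how `torch.autograd.grad` is typically called; the keyword defaults shown here reflect recent PyTorch documentation and may differ slightly between versions:

```python
import torch
from torch.autograd import grad

# Core call pattern of torch.autograd.grad:
#   grad(outputs, inputs, grad_outputs=None, retain_graph=None,
#        create_graph=False, allow_unused=False)
# It returns a tuple containing one gradient tensor per element of `inputs`.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
(dy_dx,) = grad(y, x)  # dy/dx = 2x -> tensor(6.)
print(dy_dx)
```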
You can test it out with the following code:

```python
import torch
from torch.autograd import grad

'''
z = (xy)^2
x = 3, y = 2
first-order derivative: [24, 36]
d2z/dx2 = 8
d2z/dxdy = 24
d2z/dy2 = 18
'''
inputs = torch.tensor([3.0, 2.0], requires_grad=True)
z = (inputs[0] * inputs[1]) ** 2
first_order_grad = grad(z, inputs, create_graph=True)
second_order_grad_original, = grad(first_order_grad[0], inputs,
                                   torch.ones_like(first_order_grad[0]),
                                   retain_graph=True)  # does not give the expected answer
second_order_grad_x, = grad(first_order_grad[0][0], inputs, retain_graph=True)  # [8, 24]
second_order_grad_y, = grad(first_order_grad[0][1], inputs)  # [24, 18]
```

The old code calculates the sum of gradients with respect to the parameters, i.e. (d2z/dx2 + d2z/dxdy, d2z/dy2 + d2z/dxdy).
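To make the difference concrete, here is a minimal self-contained sketch of what the two approaches return for this example; the printed values follow directly from the derivatives listed in the comment block above:

```python
import torch
from torch.autograd import grad

inputs = torch.tensor([3.0, 2.0], requires_grad=True)
z = (inputs[0] * inputs[1]) ** 2
(g,) = grad(z, inputs, create_graph=True)  # [24., 36.]

# Old approach: one backward pass through the whole gradient vector.
# This yields row sums of the Hessian, not its individual entries.
(old,) = grad(g, inputs, torch.ones_like(g), retain_graph=True)
print(old)  # tensor([32., 42.]) == [8 + 24, 24 + 18]

# Fixed approach: differentiate each gradient component separately.
(hx,) = grad(g[0], inputs, retain_graph=True)
(hy,) = grad(g[1], inputs)
print(hx)   # tensor([ 8., 24.])
print(hy)   # tensor([24., 18.])
```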
That's an excellent simplified example.
Now I understand a lot better what's going on. Could you add a comment explaining why you are accessing the arrays with index 2?
I don't know of a good succinct comment to explain it, but at least point out that S0 corresponds to index 2 of the inputs tensor and that's why you're indexing with 2. Maybe a separate cell in the notebook demonstrating how this works would also help.
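Purely as a hypothetical illustration (the input packing, variable names, and pricing expression below are assumptions for the sketch, not the notebook's actual code): if S0 sits at index 2 of the inputs tensor, gamma is obtained by differentiating only the index-2 component of the first-order gradient, matching the per-component differentiation shown in the example above.

```python
import torch
from torch.autograd import grad

# Hypothetical input packing: [K, r, S0, sigma]; index 2 holds the spot S0.
inputs = torch.tensor([100.0, 0.01, 105.0, 0.2], requires_grad=True)

# Stand-in pricing function (NOT the notebook's model), just to exercise autograd.
price = (inputs[2] ** 3) * torch.exp(-inputs[1])

(g,) = grad(price, inputs, create_graph=True)  # g[2] = dPrice/dS0 (delta)
(h,) = grad(g[2], inputs)                      # second pass from the S0 entry only
gamma = h[2]                                   # d2Price/dS0^2
print(gamma)  # 6 * S0 * exp(-r) for this toy price
```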
I added the example to the notebook.
Looks great! Thanks.
The old code computes the sum of the second-order derivatives, which is wrong. This PR fixes it.