Extending PyTorch in Python

In this folder, I show how to add new functions to autograd.

For this purpose, we subclass torch.autograd.Function and implement two static methods, forward and backward:

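import torch

class NewOperator(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs):
        # Compute the result and stash whatever backward will need,
        # e.g. ctx.save_for_backward(inputs).
        ...
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        # Receives the gradient w.r.t. the outputs and must return one
        # gradient per forward input (None for inputs without a gradient).
        ...
        return grads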
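As a concrete instance of the skeleton above, here is a minimal sketch that reimplements the element-wise exponential with a hand-written backward pass. The class name Exp is hypothetical, chosen only for illustration. Because the derivative of exp(x) is exp(x) itself, forward saves its own output, and backward multiplies the incoming gradient by it (the chain rule). torch.autograd.gradcheck then compares this analytic gradient against finite differences.

```python
import torch


class Exp(torch.autograd.Function):
    """Hypothetical example: element-wise exp with a hand-written backward."""

    @staticmethod
    def forward(ctx, inputs):
        outputs = inputs.exp()
        # d/dx exp(x) = exp(x), so saving the output is enough for backward.
        ctx.save_for_backward(outputs)
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        (outputs,) = ctx.saved_tensors
        # Chain rule: upstream gradient times the local derivative.
        return grad_outputs * outputs


# gradcheck uses finite differences, so it expects double-precision inputs.
x = torch.randn(4, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(Exp.apply, (x,)))  # True
```

gradcheck returns True when the analytic and numerical gradients agree and raises an error otherwise, which makes it a cheap sanity check for any custom backward.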

Then we wrap the class in a plain function that calls its apply method:

def new_operator(inputs):
    # apply is called on the class itself, not on an instance
    return NewOperator.apply(inputs)
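The same wrapper pattern extends to several inputs, including non-tensor arguments. The sketch below uses the hypothetical names ScaledMul and scaled_mul for illustration; it computes x * y * scale. backward must return exactly one value per argument of forward, in the same order, so the plain float scale gets None. Non-tensor values can be kept directly on ctx, while tensors go through ctx.save_for_backward. apply accepts positional arguments only, which is one reason to hide it behind a wrapper with a friendlier signature.

```python
import torch


class ScaledMul(torch.autograd.Function):
    """Hypothetical example: computes x * y * scale for a float scale."""

    @staticmethod
    def forward(ctx, x, y, scale):
        ctx.save_for_backward(x, y)
        ctx.scale = scale  # non-tensor values can be stored directly on ctx
        return x * y * scale

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # One return value per forward argument, in order; the float
        # `scale` receives None because it has no gradient.
        return grad_output * y * ctx.scale, grad_output * x * ctx.scale, None


def scaled_mul(x, y, scale=1.0):
    # apply takes positional arguments only, so the wrapper forwards them in order.
    return ScaledMul.apply(x, y, scale)


x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([3.0, 4.0], requires_grad=True)
scaled_mul(x, y, 2.0).sum().backward()
print(x.grad)  # tensor([6., 8.])
print(y.grad)  # tensor([2., 4.])
```

Since the loss is sum(x * y * 2), its gradient with respect to x is 2y and with respect to y is 2x, which matches the printed values.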

Reference