Extending PyTorch in Python

In this folder, I show how to add a new differentiable operation to PyTorch's autograd by writing its forward and backward passes in Python.

To do so, we define a subclass of `torch.autograd.Function` with two static methods, `forward` and `backward`:

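```python
import torch


class NewOperator(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs):
        # Compute the result; call ctx.save_for_backward(...) to keep
        # any tensors that backward will need.
        ...
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        # Receive dL/d(outputs) and return dL/d(inputs):
        # one gradient per argument of forward (excluding ctx).
        ...
        return grads
```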
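As a concrete instance of this template, here is a minimal sketch of an element-wise square operator. The name `Square` and the choice of operation are hypothetical, picked only for illustration. `forward` stores its input with `ctx.save_for_backward` so that `backward` can apply the chain rule, d(x²)/dx = 2x, to the incoming gradient.

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical example operator: y = x ** 2, element-wise."""

    @staticmethod
    def forward(ctx, x):
        # Save the input; backward needs it to compute dy/dx = 2x.
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the tensor saved in forward.
        (x,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * dy/dx = grad_output * 2x.
        return grad_output * 2 * x


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Square.apply(x)
y.sum().backward()
print(x.grad)  # tensor([2., 4., 6.])
```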

Then, we wrap it in a plain function that calls the class method `apply` (on the class itself, not on an instance):

```python
def new_operator(inputs):
    # apply runs forward and records the node so autograd can call backward.
    return NewOperator.apply(inputs)
```
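A minimal end-to-end sketch of this pattern, using a hypothetical two-input operator `Mul` (z = x * y) and a hypothetical wrapper `mul`, is shown below. It checks the hand-written `backward` against finite differences with `torch.autograd.gradcheck`, and illustrates that `backward` must return one gradient per input of `forward`, in the same order.

```python
import torch


class Mul(torch.autograd.Function):
    """Hypothetical example operator: z = x * y, element-wise."""

    @staticmethod
    def forward(ctx, x, y):
        # Both inputs are needed to compute the gradients.
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # One gradient per forward input: dz/dx = y, dz/dy = x.
        return grad_output * y, grad_output * x


def mul(x, y):
    # Call apply on the class itself, not on an instance.
    return Mul.apply(x, y)


# gradcheck compares the analytic backward with numerical gradients;
# it expects double-precision inputs with requires_grad=True.
x = torch.randn(4, dtype=torch.double, requires_grad=True)
y = torch.randn(4, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(mul, (x, y))
print(ok)  # True
```

Running `gradcheck` on every new operator is a cheap way to catch sign errors or missing terms in `backward` before using the operator in training.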

Reference