Fix PReLU Broadcasting Bug for Multiple Parameters #565

Open
wants to merge 1 commit into master

Conversation

hishambarakat16

#################Summary#################
Fixed a bug in the PReLU module in jittor/nn.py where broadcasting the weight parameter caused errors when num_parameters was greater than 1. The previous implementation hard-coded the broadcast dimensions to [0, 2, 3], which assumes a 4D (NCHW) input, so the weights were not broadcast correctly for inputs of other ranks and a runtime error was raised.
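
A minimal reproduction of the failure (a sketch against the pre-patch nn.PReLU; the input shape here is hypothetical):

import jittor as jt
from jittor import nn

# Pre-patch, PReLU broadcasts its weight with hard-coded dims [0, 2, 3],
# which assumes a 4D (NCHW) input; a 2D input therefore raises a runtime error.
prelu = nn.PReLU(num_parameters=5)
x = jt.randn(4, 5)   # (batch, channels) -- no spatial dims
y = prelu(x)         # fails before this patch; succeeds after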

#################Changes Made#################
Modified the execute method of the PReLU class to correctly broadcast the weight parameter when num_parameters is greater than 1, deriving the broadcast shape from the input's rank instead of assuming a 4D input.

#################Code Changes#################
#################Original Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU"
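        # NOTE: the broadcast dims [0, 2, 3] below are hard-coded for a 4D
        # (NCHW) input, so any other input rank raises a runtime error here.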
        return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

#################Updated Code:#################

def __init__(self, num_parameters=1, init_=0.25):
    self.num_parameters = num_parameters
    self.weight = init.constant((num_parameters,), "float32", init_)

def execute(self, x):
    if self.num_parameters != 1:
        assert self.num_parameters == x.shape[1], f"num_parameters ({self.num_parameters}) does not match input channels ({x.shape[1]}) in PReLU"
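        # Build the broadcast target [N, C, 1, ..., 1] from the input's actual
        # rank, so the per-channel weights align with axis 1 of any input.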
        weight_broadcasted = self.weight.broadcast([x.shape[0], self.num_parameters, *([1] * (len(x.shape) - 2))])
        return jt.maximum(0, x) + weight_broadcasted * jt.minimum(0, x)
    else:
        return jt.maximum(0, x) + self.weight * jt.minimum(0, x)
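
For concreteness, the broadcast target computed above works out as follows (an illustration with hypothetical shapes, not part of the patch):

# Hypothetical 4D (NCHW) input
x_shape = (2, 3, 4, 4)
num_parameters = 3

target = [x_shape[0], num_parameters, *([1] * (len(x_shape) - 2))]
# target == [2, 3, 1, 1]: one learnable slope per channel, intended to be
# replicated across the remaining dimensions when multiplied with minimum(0, x)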

#################Testing#################
Tested the updated PReLU function with various configurations to ensure proper functionality:

import jittor as jt
from jittor import nn

# Create input data with the specified shape
def create_input_data(shape):
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return jt.array(list(range(-num_elements // 2, num_elements // 2)), dtype=jt.float32).reshape(shape)

# Test the PReLU activation function
def test_prelu(num_parameters, input_shape):
    prelu_layer = nn.PReLU(num_parameters=num_parameters)
    input_data = create_input_data(input_shape)
    print(f"Testing PReLU with num_parameters={num_parameters} and input_shape={input_shape}")
    print(f"Input Data:\n{input_data.numpy()}")
    output_data = prelu_layer(input_data)
    print(f"Output Data (PReLU):\n{output_data.numpy()}\n")

if __name__ == "__main__":
    test_configs = [
        (1, (5,)),      # Single parameter
        (5, (5, 5)),    # Five parameters matching the number of channels
        (3, (3, 3)),    # Three parameters matching the number of channels
    ]
    for num_parameters, input_shape in test_configs:
        test_prelu(num_parameters, input_shape)
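
As an extra sanity check (a sketch, not part of the PR's test script), the patched layer can be compared against the closed-form definition max(0, x) + a * min(0, x):

import numpy as np
import jittor as jt
from jittor import nn

def reference_prelu(x, weight):
    # Align the (C,) weight with axis 1 of x; a scalar weight broadcasts as-is
    if x.ndim > 1 and weight.size > 1:
        weight = weight.reshape((1, -1) + (1,) * (x.ndim - 2))
    return np.maximum(0.0, x) + weight * np.minimum(0.0, x)

layer = nn.PReLU(num_parameters=5)
x = jt.randn(4, 5)
np.testing.assert_allclose(
    layer(x).numpy(),
    reference_prelu(x.numpy(), layer.weight.numpy()),
    rtol=1e-6,
)
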
#################Test Results:#################

Testing PReLU with num_parameters=1 and input_shape=(5,)
Input Data:
[-3. -2. -1.  0.  1.]
Output Data (PReLU):
[-0.75 -0.5  -0.25  0.    1.  ]

Testing PReLU with num_parameters=5 and input_shape=(5, 5)
Input Data:
[[-13. -12. -11. -10.  -9.]
 [ -8.  -7.  -6.  -5.  -4.]
 [ -3.  -2.  -1.   0.   1.]
 [  2.   3.   4.   5.   6.]
 [  7.   8.   9.  10.  11.]]
Output Data (PReLU):
[[-3.25 -3.   -2.75 -2.5  -2.25]
 [-2.   -1.75 -1.5  -1.25 -1.  ]
 [-0.75 -0.5  -0.25  0.    1.  ]
 [ 2.    3.    4.    5.    6.  ]
 [ 7.    8.    9.   10.   11.  ]]

Testing PReLU with num_parameters=3 and input_shape=(3, 3)
Input Data:
[[-5. -4. -3.]
 [-2. -1.  0.]
 [ 1.  2.  3.]]
Output Data (PReLU):
[[-1.25 -1.   -0.75]
 [-0.5  -0.25  0.  ]
 [ 1.    2.    3.  ]]

#################Conclusion#################
This fix ensures that the PReLU activation function can handle multiple parameters correctly by properly broadcasting the weight parameter to match the input tensor dimensions.