I would like to implement the MIN2Net model in PyTorch. Is this correct? #4

Open
Buddies-as-you-know opened this issue Nov 22, 2022 · 0 comments

Buddies-as-you-know commented Nov 22, 2022

I implemented only the MI-EEG classification part in PyTorch. Does it match the original model?

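import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2D_Norm_Constrained(nn.Conv2d):
    # Conv2d whose weights are rescaled so their L2 norm (over norm_dim) stays at or below max_norm_val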
    def __init__(self, max_norm_val, norm_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_norm_val = max_norm_val
        self.norm_dim = norm_dim

    def get_constrained_weights(self, epsilon=1e-8):
        norm = self.weight.norm(2, dim=self.norm_dim, keepdim=True)
        return self.weight * (torch.clamp(norm, 0, self.max_norm_val) / (norm + epsilon))

    def forward(self, input):
        return F.conv2d(input, self.get_constrained_weights(), self.bias, self.stride, self.padding, self.dilation, self.groups)

class ConstrainedLinear(nn.Linear):
    def forward(self, input):
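        # clamp() limits each weight element to [-1.0, 0.5]; it does not constrain the weight norm
        return F.linear(input, self.weight.clamp(min=-1.0, max=0.5), self.bias)

class MinNet(nn.Module): # with the default input_shape, expects batches of shape (N, 1, 400, 20)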
  def __init__(self, input_shape=(1,400,20)):
    super().__init__()
    self.D, self.T, self.C = input_shape
    self.subsampling_size = 100
    self.pool_size_1 = (1,self.T//self.subsampling_size)
    self.en_conv = nn.Sequential(
                    Conv2D_Norm_Constrained(in_channels=1, out_channels=16, kernel_size=(1, 64), padding="same", max_norm_val=2.0, norm_dim=(0, 1, 2)),
                    nn.ELU(),
                    nn.BatchNorm2d(16,eps=1e-05, momentum=0.1),
                    nn.AvgPool2d(self.pool_size_1),  # pool_size_1 is already a (1, T // subsampling_size) tuple
                    nn.Flatten(),
                    ConstrainedLinear(32000,64),
                    nn.ELU(),
                    ConstrainedLinear(64,3)
                )
  def forward(self,x):
      x = self.en_conv(x)
      return x
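
One way to sanity-check the hard-coded `32000` passed to the first `ConstrainedLinear` is to work out the flattened size by hand. The sketch below is only an illustration under stated assumptions: batches are shaped `(N, D, T, C)`, the `"same"`-padded convolution leaves height and width unchanged, and the pooling kernel is `(1, T // subsampling_size)` with PyTorch's default stride equal to the kernel size. The helper name `flattened_features` is hypothetical and not part of MIN2Net.

```python
def flattened_features(input_shape=(1, 400, 20), out_channels=16, subsampling_size=100):
    """Feature count that nn.Flatten() would produce in the posted MinNet.

    Assumes 'same' padding (spatial size unchanged by the convolution) and
    average pooling over the last axis only, with stride equal to the kernel.
    """
    _, t, c = input_shape
    pool_w = t // subsampling_size       # pooling kernel width: 400 // 100 = 4
    pooled_w = c // pool_w               # pooled last axis (floored): 20 // 4 = 5
    return out_channels * t * pooled_w   # channels * height * pooled width


print(flattened_features())  # 32000
```

The arithmetic shows that, as posted, the pooling acts on the last axis (size 20) and not on the 400-sample axis. If the original model reads `(1, 400, 20)` as a channels-last shape, its pooling would act on the time axis instead, so the axis order is worth checking against the original code.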