Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

运行MMD时矩阵不一致问题的报错。 #400

Open
sunsenheping opened this issue Jul 18, 2023 · 5 comments
Open

运行MMD时矩阵不一致问题的报错。 #400

sunsenheping opened this issue Jul 18, 2023 · 5 comments

Comments

@sunsenheping
Copy link

在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码:
b71752db2290bfd5703138368d41c05
我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示:
66e83a0c316846b11decc46f064f3c4
可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢

@jindongwang
Copy link
Owner

@lw0517 有时间来看一下。

@lw0517
Copy link
Collaborator

lw0517 commented Jul 18, 2023

在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码: b71752db2290bfd5703138368d41c05 我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示: 66e83a0c316846b11decc46f064f3c4 可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢

建议的做法是将三维的转化成二维的进行MMD距离计算。如果用torch.addbmm,最好参考一下https://pytorch.org/docs/1.10/generated/torch.addmm.html?highlight=addmm#torch.addmm, 可以看出来维度是不一致的。

@sunsenheping
Copy link
Author

在用自己的数据集运行DG库中的MMD算法时,出现了矩阵形状不一样导致无法相加的问题。下图是github中MMD计算的代码: b71752db2290bfd5703138368d41c05 我的数据,x1和x2的形状均为(32, 600, 64)。考虑到batch问题,我在计算时将addmm换为了addbmm,但是我发现,无论是用二维数据使用addmm还是三维数据使用addbmm,在计算时存在矩阵形状不一样导致无法相加的问题,各个张量的形状如下图所示: 66e83a0c316846b11decc46f064f3c4 可以看到,x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的。麻烦大佬们看一下,是哪里有了问题?感谢

建议的做法是将三维的转化成二维的进行MMD距离计算。如果用torch.addbmm,最好参考一下https://pytorch.org/docs/1.10/generated/torch.addmm.html?highlight=addmm#torch.addmm, 可以看出来维度是不一致的。

问题是,我用二维矩阵也试过,按照你们MMD的代码,还是存在形状不一致,addmm无法进行相加的问题的。你也可以推导一下。
image

@sunsenheping
Copy link
Author

如果x2_norm.transpose(-2, -1)的形状与matmul(x1, x2.transpose(-2, -1))的形状不一致,二者是没办法做和相加的,那么你们现在的MMD代码也就无法跑通呀。

@jindongwang
Copy link
Owner

我们是按照batch来算的,一个batch内大家形状是一样的,不明白为什么会出现形状不一样的问题。

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants