LIQE training loss does not converge. What could be the problem? #229

Open
liuchenxin666 opened this issue Dec 2, 2024 · 4 comments

Comments

@liuchenxin666

```
model = pyiqa.create_metric('liqe', device=device, pretrained=False)
model.to(device)

pred = model(img)
loss = loss_iqa(pred.squeeze(), mos.float().detach())
```

@chaofengc
Owner

chaofengc commented Dec 3, 2024

The model runs in inference mode by default, so gradients are disabled to save computation. To train, pass as_loss=True: pyiqa.create_metric('liqe', device=device, as_loss=True, pretrained=False)
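To illustrate why a disabled gradient context stalls training, here is a minimal sketch in plain PyTorch. The `net` and `forward_like_wrapper` names are hypothetical stand-ins, not pyiqa code; the sketch only assumes the wrapper runs its network inside `torch.set_grad_enabled(flag)`.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for an IQA network; not the real LIQE model.
net = nn.Linear(4, 1)
x = torch.randn(2, 4)


def forward_like_wrapper(inputs, as_loss):
    # Gradients are tracked only when the flag is True.
    with torch.set_grad_enabled(as_loss):
        return net(inputs)


out_inference = forward_like_wrapper(x, as_loss=False)
out_trainable = forward_like_wrapper(x, as_loss=True)

# With gradients disabled the output is detached from the parameters,
# so a loss built on it cannot update the network.
print(out_inference.requires_grad, out_inference.grad_fn)
print(out_trainable.requires_grad)
```

A loss computed from `out_inference` has no path back to the parameters of `net`, so an optimizer step on it leaves the weights unchanged.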

@liuchenxin666
Author

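But when as_loss=True is passed, the forward() function of InferenceModel does not receive the MOS values, so how does weight_reduce_loss() compute the loss?
Also, when as_loss=True, forward() no longer returns the predicted scores. Does that mean I can no longer compute SROCC and PLCC on the test or validation set?
So I don't quite understand what the as_loss=True argument is for. Could I simply modify InferenceModel's forward() so that it returns output, as below?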

    def forward(self, target, ref=None, **kwargs):
        device = self.device

        with torch.set_grad_enabled(self.as_loss):

            if self.metric_name == 'fid':
                output = self.net(target, ref, device=device, **kwargs)
            elif self.metric_name == 'inception_score':
                output = self.net(target, device=device, **kwargs)
            else:
                if not torch.is_tensor(target):
                    target = imread2tensor(target, rgb=True)
                    target = target.unsqueeze(0)
                    if self.metric_mode == 'FR':
                        assert ref is not None, 'Please specify reference image for Full Reference metric'
                        ref = imread2tensor(ref, rgb=True)
                        ref = ref.unsqueeze(0)
                        self.is_valid_input(ref)
                
                self.is_valid_input(target)

                if self.metric_mode == 'FR':
                    assert ref is not None, 'Please specify reference image for Full Reference metric'
                    output = self.net(target.to(device), ref.to(device), **kwargs)
                elif self.metric_mode == 'NR':
                    output = self.net(target.to(device), **kwargs)

        # if self.as_loss:
        #     if isinstance(output, tuple):
        #         output = output[0]
        #     return weight_reduce_loss(output, self.loss_weight, self.loss_reduction)
        # else:
        #     return output
        return output
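On the SROCC/PLCC part of the question: both correlations only need per-image predicted scores and the matching MOS values, however those scores are obtained. The sketch below computes them with the standard library only; the score lists are made-up illustration data, and in practice `scipy.stats.spearmanr` and `scipy.stats.pearsonr` do the same job.

```python
import math


def pearson(x, y):
    """Pearson linear correlation coefficient (PLCC)."""
    n = len(x)
    mx = sum(x) / n
    my = sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def average_ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(x, y):
    """Spearman rank correlation (SROCC): Pearson applied to the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


# Made-up example values, one entry per validation image.
pred_scores = [0.2, 0.5, 0.4, 0.9]
mos_values = [1.0, 3.0, 2.5, 4.5]

srocc = spearman(pred_scores, mos_values)
plcc = pearson(pred_scores, mos_values)
print(f"SROCC={srocc:.4f} PLCC={plcc:.4f}")
```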

@liuchenxin666
Author

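Also, my understanding is that weight_reduce_loss() just averages the predicted scores in output over the batch? If I have misunderstood anything, I would appreciate your guidance. Thank you very much!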

@chaofengc
Owner

chaofengc commented Dec 4, 2024

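Here, "loss" means using an already trained model as the loss function for another model that is being trained (for example, an image super-resolution model). If you want to train the IQA model itself, please refer to the training code in this repository.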
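As a rough sketch of that usage pattern, the example below uses small hypothetical modules in place of the real models: `scorer` stands in for a trained, frozen quality metric and `restorer` for the model being trained. It assumes a higher score means better quality and reduces the per-image scores with a batch mean; neither detail is confirmed by this thread for pyiqa itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical frozen quality scorer (stand-in for a trained IQA metric).
scorer = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
for p in scorer.parameters():
    p.requires_grad_(False)

# Hypothetical trainable model (stand-in for e.g. a super-resolution network).
restorer = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(restorer.parameters(), lr=0.01)

degraded = torch.rand(4, 3, 8, 8)   # random batch, illustration only
restored = restorer(degraded)
scores = scorer(restored)           # one score per image, shape (4, 1)

# No MOS is involved: the scores themselves are reduced to a scalar loss.
# Assumed convention: higher score is better, so minimize the negative mean.
loss = -scores.mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Gradients flow through the frozen scorer into the restorer, so only the restorer is updated. This is why no MOS labels appear in the loss path; supervising the IQA model against MOS is a separate training setup.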
