```python
if weight is not None:
    # trans weight from class to sample, shape: N or [N,H,W] for 1d and 2d cases.
    if soft_label:
        # chajchaj:
        # weight's shape is C, where C is class num.
        # for 1d case: label's shape is [N,C], weight_gather's shape is N.
        # for 2d case: label's shape is [N,H,W,C], weight_gather's shape is [N,H,W].
        weight_gather = paddle.matmul(
            x=paddle.cast(label, weight.dtype),
            y=weight,
            transpose_x=False,
            transpose_y=True,
        )
        out_shape = list(out.shape)
        weight_gather_reshape = reshape(weight_gather, shape=out_shape)
        out = paddle.cast(out, weight_gather_reshape.dtype)
        out = _C_ops.multiply(out, weight_gather_reshape)
```
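The `matmul` in the excerpt turns per-class weights into per-sample weights. The pure-Python sketch below reproduces that step for the 1d case: it takes the dot product of each soft-label row `[C]` with the class weights `[C]` and returns one weight per sample, shape `[N]`. The helper name `gather_sample_weight` and the sample data are hypothetical and are not part of Paddle's API.

```python
def gather_sample_weight(label, weight):
    """Per-sample weight: row-wise dot product of soft labels [N, C] with class weights [C].

    Mirrors the matmul in the excerpt for the 1d case; the result has shape [N].
    """
    return [sum(y_c * w_c for y_c, w_c in zip(row, weight)) for row in label]


weight = [1.0, 2.0, 3.0]
one_hot = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
soft = [[0.5, 0.5, 0.0], [0.2, 0.3, 0.5]]

print(gather_sample_weight(one_hot, weight))  # [2.0, 3.0], i.e. weight[label index]
print(gather_sample_weight(soft, weight))     # label-weighted mix of class weights
```

With one-hot labels this reduces to picking `weight[label_j]` for each sample. With soft labels each sample gets a convex mix of the class weights.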
Document Links & Description
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/CrossEntropyLoss_cn.html
The current Paddle documentation states:

$$loss_j = loss_j^{\prime} \cdot \sum_{i} (weight[label_i] \cdot logits_i)$$

However, judging from the code excerpt above, the handling here should be:

$$loss_j = loss_j^{\prime} \cdot \sum_{i=0}^{C-1} label_i \cdot weight_i$$

$$loss_j^{\prime} = -\sum_{i=0}^{C-1} label_i \cdot (logits_i - \log(\sum_{k=0}^{C-1} \exp(logits_k)))$$

Here $loss_j^{\prime}$ is `out` in the code, whose shape is [N] or [N,H,W].

I don't quite understand this part. Following the PyTorch documentation, it should be:

$$loss_j = -\sum_{i=0}^{C-1}weight_i \cdot label_i \cdot (logits_i - \log(\sum_{k=0}^{C-1} \exp(logits_k)))$$
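The sketch below compares the two weighted formulas numerically. It is plain Python, and the function names `loss_as_in_code` and `loss_as_in_torch_doc` are hypothetical; it calls neither Paddle nor PyTorch. The first function applies $loss_j^{\prime} \cdot \sum_i label_i \cdot weight_i$ (the behavior read off the code). The second applies $-\sum_i weight_i \cdot label_i \cdot \log\mathrm{softmax}_i$ (the PyTorch-style formula).

```python
import math


def log_softmax(z):
    """Numerically stable log-softmax of one row of logits."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]


def loss_as_in_code(logits, label, weight):
    """loss_j = loss_j' * sum_i label_i * weight_i, one value per sample."""
    result = []
    for z, y in zip(logits, label):
        ls = log_softmax(z)
        base = -sum(y_i * l_i for y_i, l_i in zip(y, ls))         # loss_j'
        sample_w = sum(y_i * w_i for y_i, w_i in zip(y, weight))  # weight_gather
        result.append(base * sample_w)
    return result


def loss_as_in_torch_doc(logits, label, weight):
    """loss_j = -sum_i weight_i * label_i * log_softmax_i, one value per sample."""
    return [
        -sum(w_i * y_i * l_i for w_i, y_i, l_i in zip(weight, y, log_softmax(z)))
        for z, y in zip(logits, label)
    ]


logits = [[1.0, 2.0, 0.5]]
weight = [1.0, 2.0, 3.0]

# One-hot label: both formulas give the same value.
print(loss_as_in_code(logits, [[0.0, 1.0, 0.0]], weight),
      loss_as_in_torch_doc(logits, [[0.0, 1.0, 0.0]], weight))

# Soft label: the values differ (by about 0.25 for these inputs).
print(loss_as_in_code(logits, [[0.5, 0.5, 0.0]], weight),
      loss_as_in_torch_doc(logits, [[0.5, 0.5, 0.0]], weight))
```

For one-hot labels the two formulas coincide. The discrepancy only shows up with soft labels, where the code scales the whole unweighted loss by a single mixed weight instead of weighting each class term separately.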
Please give your suggestion
No response