Scatter version of Cross-Entropy Loss #9687
leonardcaquot94 started this conversation in Ideas
Context
I’m working on a node selection task where I compute a logit for each node in a batch and train on the resulting selection probabilities with a cross-entropy loss. Since I select one node per graph and the number of nodes per graph varies, the cross-entropy must be computed per graph, using the batch tensor to group nodes.
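Written out, the softmax is taken over the nodes of one graph at a time. Here z_i is the logit of node i, b_i its graph index from the batch tensor, and p_i its target selection probability. These symbols are introduced for illustration, and the p_i of each graph are assumed to sum to 1:

$$
q_i = \frac{\exp(z_i)}{\sum_{j \,:\, b_j = b_i} \exp(z_j)}, \qquad
\mathcal{L}_g = -\sum_{i \,:\, b_i = g} p_i \log q_i .
$$

With a one-hot target on the selected node s_g of graph g, this reduces to \mathcal{L}_g = -\log q_{s_g}.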
I propose adding a scatter-based cross-entropy function, similar to scatter_softmax in torch_geometric.utils, to compute one loss per graph.
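As a rough illustration of what such a function could look like, here is a minimal pure-PyTorch sketch. It is not the author's implementation, and the name scatter_cross_entropy, its signature, the float per-node targets, and the meaning of reduce (aggregating over the nodes of each graph so the output stays one value per graph) are all assumptions. It relies on Tensor.scatter_reduce (available in recent PyTorch releases) for a per-graph maximum and on Tensor.index_add for per-graph sums, so it needs neither torch_scatter nor torch_geometric.

```python
from typing import Optional

import torch


def scatter_cross_entropy(
    logits: torch.Tensor,  # [num_nodes] one score per node
    target: torch.Tensor,  # [num_nodes] float target selection probabilities
    batch: torch.Tensor,  # [num_nodes] int64 graph index of each node
    num_graphs: Optional[int] = None,
    reduce: str = "sum",
) -> torch.Tensor:
    """Hypothetical sketch: cross-entropy of a softmax taken per graph.

    reduce="sum" returns the usual cross-entropy of each graph, "mean" divides
    it by the graph's node count, and "none" returns the per-node terms.
    """
    if num_graphs is None:
        num_graphs = int(batch.max()) + 1
    zeros = logits.new_zeros(num_graphs)

    # Per-graph maximum, detached: subtracting it does not change the softmax,
    # it only keeps exp() from overflowing.
    graph_max = logits.new_full((num_graphs,), float("-inf")).scatter_reduce(
        0, batch, logits.detach(), reduce="amax", include_self=True
    )
    shifted = logits - graph_max[batch]

    # log q_i = shifted_i - log(sum of exp(shifted_j) over nodes j in the same graph)
    log_norm = zeros.index_add(0, batch, shifted.exp()).log()
    log_probs = shifted - log_norm[batch]

    per_node = -target * log_probs
    if reduce == "none":
        return per_node
    per_graph = zeros.index_add(0, batch, per_node)
    if reduce == "sum":
        return per_graph
    if reduce == "mean":
        counts = zeros.index_add(0, batch, torch.ones_like(per_node))
        return per_graph / counts.clamp(min=1)  # clamp guards empty graphs
    raise ValueError(f"unsupported reduce: {reduce!r}")


# Example: two graphs with 3 and 2 nodes; nodes 1 and 4 are the selected nodes.
logits = torch.tensor([0.5, 2.0, -1.0, 0.3, 0.1], requires_grad=True)
batch = torch.tensor([0, 0, 0, 1, 1])
target = torch.tensor([0.0, 1.0, 0.0, 0.0, 1.0])
loss_per_graph = scatter_cross_entropy(logits, target, batch)  # shape [2]
loss = loss_per_graph.mean()  # optional outer reduction to a single scalar
loss.backward()
```

When every graph has the same number of nodes, reduce="sum" should agree with torch.nn.functional.cross_entropy(..., reduction="none") applied to the logits reshaped to [num_graphs, nodes_per_graph], which is a convenient sanity check for any scatter-based version.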
Implementation
Reduction
The reduce parameter mirrors the behavior of the reduction parameter in torch.nn.functional.cross_entropy and is used to compute a per-graph loss. An additional reduction can be applied outside this function to obtain a single scalar value if needed.
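One possible reading, assuming reduce acts over the nodes of each graph so that the result stays one loss per graph: with b_i the graph index of node i, p_i its target probability, q_i its softmax probability within its graph, n_g the node count of graph g, and G the number of graphs (notation introduced here for illustration), the options and an outer reduction could be

$$
\mathcal{L}_g^{\text{sum}} = -\sum_{i \,:\, b_i = g} p_i \log q_i, \qquad
\mathcal{L}_g^{\text{mean}} = \frac{1}{n_g}\,\mathcal{L}_g^{\text{sum}}, \qquad
\mathcal{L} = \frac{1}{G} \sum_{g=0}^{G-1} \mathcal{L}_g ,
$$

where reduce="none" would keep the per-node terms -p_i \log q_i, and the last expression is the scalar reduction left to the caller.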