diff --git a/gluefactory/models/utils/losses.py b/gluefactory/models/utils/losses.py index cca17636..06c7958b 100644 --- a/gluefactory/models/utils/losses.py +++ b/gluefactory/models/utils/losses.py @@ -69,5 +69,5 @@ def nll_loss(self, log_assignment, data): weights[:, :m, :n] = positive weights[:, :m, -1] = neg0 - weights[:, -1, :m] = neg1 + weights[:, -1, :n] = neg1 return weights