From 5c28b03c5ba4caf38859b1d9a25b30ef7c8e8550 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 21 Mar 2022 18:55:10 +0100 Subject: [PATCH] Update loss for FP16 `tobj` --- utils/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/loss.py b/utils/loss.py index b49cc7f66e66..a06330e034bc 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -125,7 +125,7 @@ def __call__(self, p, targets): # predictions, targets # Losses for i, pi in enumerate(p): # layer index, layer predictions b, a, gj, gi = indices[i] # image, anchor, gridy, gridx - tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj + tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj n = b.shape[0] # number of targets if n: