diff --git a/models/experimental.py b/models/experimental.py
index 548353c93be0..85a22d42fa54 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -110,7 +110,9 @@ def forward(self, x, augment=False):
         return y, None  # inference, train output
 
 
-def attempt_load(weights, map_location=None):
+def attempt_load(weights, map_location=None, inplace=True):
+    from models.yolo import Detect, Model
+
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
@@ -120,8 +122,8 @@ def attempt_load(weights, map_location=None):
 
     # Compatibility updates
     for m in model.modules():
-        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
-            m.inplace = True  # pytorch 1.7.0 compatibility
+        if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
+            m.inplace = inplace  # pytorch 1.7.0 compatibility
         elif type(m) is Conv:
             m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
 
diff --git a/models/yolo.py b/models/yolo.py
index cbff70fc83d4..520047ff3a99 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -26,7 +26,7 @@ class Detect(nn.Module):
     stride = None  # strides computed during build
     export = False  # onnx export
 
-    def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
+    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
         super(Detect, self).__init__()
         self.nc = nc  # number of classes
         self.no = nc + 5  # number of outputs per anchor
@@ -37,6 +37,7 @@ def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
         self.register_buffer('anchors', a)  # shape(nl,na,2)
         self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2))  # shape(nl,1,na,1,1,2)
         self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
+        self.inplace = inplace  # use in-place ops (e.g. slice assignment)
 
     def forward(self, x):
         # x = x.copy()  # for profiling
@@ -52,8 +53,13 @@ def forward(self, x):
                     self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
 
                 y = x[i].sigmoid()
-                y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
-                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                if self.inplace:
+                    y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
+                    xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
+                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
+                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                 z.append(y.view(bs, -1, self.no))
 
         return x if self.training else (torch.cat(z, 1), x)
@@ -85,12 +91,14 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, i
             self.yaml['anchors'] = round(anchors)  # override yaml value
         self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
         self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
+        self.inplace = self.yaml.get('inplace', True)
         # logger.info([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))])
 
         # Build strides, anchors
         m = self.model[-1]  # Detect()
         if isinstance(m, Detect):
             s = 256  # 2x min stride
+            m.inplace = self.inplace
             m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
             m.anchors /= m.stride.view(-1, 1, 1)
             check_anchor_order(m)
@@ -105,24 +113,23 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, i
 
     def forward(self, x, augment=False, profile=False):
         if augment:
-            img_size = x.shape[-2:]  # height, width
-            s = [1, 0.83, 0.67]  # scales
-            f = [None, 3, None]  # flips (2-ud, 3-lr)
-            y = []  # outputs
-            for si, fi in zip(s, f):
-                xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
-                yi = self.forward_once(xi)[0]  # forward
-                # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
-                yi[..., :4] /= si  # de-scale
-                if fi == 2:
-                    yi[..., 1] = img_size[0] - yi[..., 1]  # de-flip ud
-                elif fi == 3:
-                    yi[..., 0] = img_size[1] - yi[..., 0]  # de-flip lr
-                y.append(yi)
-            return torch.cat(y, 1), None  # augmented inference, train
+            return self.forward_augment(x)  # augmented inference, None
         else:
             return self.forward_once(x, profile)  # single-scale inference, train
 
+    def forward_augment(self, x):
+        img_size = x.shape[-2:]  # height, width
+        s = [1, 0.83, 0.67]  # scales
+        f = [None, 3, None]  # flips (2-ud, 3-lr)
+        y = []  # outputs
+        for si, fi in zip(s, f):
+            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
+            yi = self.forward_once(xi)[0]  # forward
+            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
+            yi = self._descale_pred(yi, fi, si, img_size)
+            y.append(yi)
+        return torch.cat(y, 1), None  # augmented inference, train
+
     def forward_once(self, x, profile=False):
         y, dt = [], []  # outputs
         for m in self.model:
@@ -146,6 +153,23 @@ def forward_once(self, x, profile=False):
             logger.info('%.1fms total' % sum(dt))
         return x
 
+    def _descale_pred(self, p, flips, scale, img_size):
+        # de-scale predictions following augmented inference (inverse operation)
+        if self.inplace:
+            p[..., :4] /= scale  # de-scale
+            if flips == 2:
+                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
+            elif flips == 3:
+                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
+        else:
+            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
+            if flips == 2:
+                y = img_size[0] - y  # de-flip ud
+            elif flips == 3:
+                x = img_size[1] - x  # de-flip lr
+            p = torch.cat((x, y, wh, p[..., 4:]), -1)
+        return p
+
     def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
         # https://arxiv.org/abs/1708.02002 section 3.3
         # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
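
For context, a minimal sketch of how the new inplace flag might be consumed downstream. The checkpoint path, input size, and use of torch.jit.trace are illustrative assumptions, not part of this diff; only attempt_load() and its inplace keyword come from the change above.

    import torch
    from models.experimental import attempt_load

    # inplace=False propagates to Detect and Model via the compatibility loop
    # in attempt_load(), so the detection head assembles its output with
    # torch.cat() instead of in-place slice assignment.
    # 'yolov5s.pt' is a hypothetical local checkpoint path.
    model = attempt_load('yolov5s.pt', map_location=torch.device('cpu'), inplace=False)
    model.eval()

    # Trace-based export paths (e.g. AWS Inferentia, per the PR linked in the
    # diff) generally cannot record in-place slice assignment on traced
    # tensors; the functional branch keeps the recorded graph side-effect-free.
    img = torch.zeros(1, 3, 640, 640)  # dummy input at the default 640x640 size
    with torch.no_grad():
        traced = torch.jit.trace(model, img, strict=False)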