Skip to content

Commit

Permalink
Fix AutoAnchor MPS bug (ultralytics#9188)
Browse files Browse the repository at this point in the history
Resolves ultralytics#8862

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 5ade9e0 commit f9a4aea
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions utils/autoanchor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import yaml
from tqdm import tqdm

from utils import TryExcept
from utils.general import LOGGER, colorstr

PREFIX = colorstr('AutoAnchor: ')
Expand All @@ -25,6 +26,7 @@ def check_anchor_order(m):
m.anchors[:] = m.anchors.flip(0)


@TryExcept(f'{PREFIX}ERROR:')
def check_anchors(dataset, model, thr=4.0, imgsz=640):
# Check anchor fit to data, recompute if necessary
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
Expand All @@ -49,10 +51,7 @@ def metric(k): # compute metric
else:
LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
na = m.anchors.numel() // 2 # number of anchors
try:
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
except Exception as e:
LOGGER.info(f'{PREFIX}ERROR: {e}')
anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
new_bpr = metric(anchors)[0]
if new_bpr > bpr: # replace anchors
anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
Expand Down Expand Up @@ -124,7 +123,7 @@ def print_results(k, verbose=True):
i = (wh0 < 3.0).any(1).sum()
if i:
LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
# wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1

# Kmeans init
Expand Down Expand Up @@ -167,4 +166,4 @@ def print_results(k, verbose=True):
if verbose:
print_results(k, verbose)

return print_results(k)
return print_results(k).astype(np.float32)

0 comments on commit f9a4aea

Please sign in to comment.