-
-
Notifications
You must be signed in to change notification settings - Fork 108
/
automaticmaskgeneration.py
38 lines (32 loc) · 1.21 KB
/
automaticmaskgeneration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
'''
Function:
SAMV2 examples: Automatic mask generation
Author:
Zhenchao Jin
'''
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ssseg.modules.models.segmentors.samv2.visualization import showanns
from ssseg.modules.models.segmentors.samv2 import SAMV2AutomaticMaskGenerator
# initialize environment
# The mask generator below is hard-coded to device='cuda', so fail early with a
# clear message rather than the opaque error torch.cuda.get_device_properties(0)
# raises on a machine without a usable GPU.
if not torch.cuda.is_available():
    raise RuntimeError('This example requires a CUDA-capable GPU (device="cuda").')
# Enable TF32 matmuls/convolutions only on GPUs with compute capability >= 8
# (Ampere and newer), which are the ones that support TF32.
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# read image; the context manager closes the underlying file handle once the
# pixels have been converted and copied into the numpy array
with Image.open('images/cars.jpg') as pil_image:
    image = np.array(pil_image.convert("RGB"))
# mask_generator could be SAMV2AutomaticMaskGenerator(use_default_samv2_t=True) or SAMV2AutomaticMaskGenerator(use_default_samv2_s=True) or SAMV2AutomaticMaskGenerator(use_default_samv2_bplus=True) or SAMV2AutomaticMaskGenerator(use_default_samv2_l=True)
# Run model construction and inference under bfloat16 autocast. A `with` block
# scopes autocast explicitly instead of entering the context manager and never
# exiting it.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    mask_generator = SAMV2AutomaticMaskGenerator(use_default_samv2_l=True, device='cuda', apply_postprocessing=False)
    # generate
    masks = mask_generator.generate(image)
# show results
print(len(masks))
if masks:
    # each generated mask record is a mapping; show which fields it carries
    print(masks[0].keys())
else:
    # avoid an IndexError on masks[0] when nothing was detected
    print('no masks were generated')
plt.figure(figsize=(20, 20))
plt.imshow(image)
if masks:
    # only overlay annotations when there is something to draw
    showanns(masks)
plt.axis('off')
plt.savefig('output.png')
# release the figure once it has been written to disk
plt.close()