sam自动生成mask代码解析

要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
#自动生成采样点对图像进行分割
mask_generator = SamAutomaticMaskGenerator(sam)

masks = mask_generator.generate(image)

print(len(masks))
print(masks[0].keys())
print(masks[0])

plt.figure(figsize=(16,16))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show() 

其中print masks这块输出为

42
dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])
{'segmentation': array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       ...,
       [ True,  True,  True, ..., False, False, False],
       [ True,  True,  True, ..., False, False, False],
       [False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}

例如生成的图片

masks = mask_generator.generate(image)

Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:

  • segmentation : np的二维数组,为二值的mask图片
  • area : mask的像素面积
  • bbox : mask的外接矩形框,为XYWH格式
  • predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
  • point_coords : 用于生成该mask的point输入
  • stability_score : mask质量的附加指标
  • crop_box : 用于以XYWH格式生成此遮罩的图像裁剪

 在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,SamAutomaticMaskGenerator可以自动在图像上切片运行,以提高较小对象的性能,可以通过后处理去除杂散像素和孔洞。以下是对更多遮罩进行采样的示例配置:

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,#控制采样点的间隔,值越小,采样点越密集(这个有争议,我在测试时值越小,输出的mask数量越少)
    pred_iou_thresh=0.86,#mask的iou阈值
    stability_score_thresh=0.92,#mask的稳定性阈值
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=50,  #最小mask面积,会使用opencv滤除掉小面积的区域
)
masks2 = mask_generator_2.generate(image)

print(len(masks2)) # 69

plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show() 

 

posted @ 2023-07-21 14:12  海_纳百川  阅读(1809)  评论(0编辑  收藏  举报
本站总访问量