sam自动生成mask代码解析
要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)
import sys sys.path.append("..") from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor sam_checkpoint = "sam_vit_b_01ec64.pth" model_type = "vit_b" device = "cuda" sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) #自动生成采样点对图像进行分割 mask_generator = SamAutomaticMaskGenerator(sam) masks = mask_generator.generate(image) print(len(masks)) print(masks[0].keys()) print(masks[0]) plt.figure(figsize=(16,16)) plt.imshow(image) show_anns(masks) plt.axis('off') plt.show()
其中print masks这块输出为
42 dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box']) {'segmentation': array([[False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], [False, False, False, ..., False, False, False], ..., [ True, True, True, ..., False, False, False], [ True, True, True, ..., False, False, False], [False, False, False, ..., False, False, False]]), 'area': 18821, 'bbox': [0, 113, 207, 152], 'predicted_iou': 0.9937220215797424, 'point_coords': [[93.75, 146.015625]], 'stability_score': 0.9622295498847961, 'crop_box': [0, 0, 400, 267]}
例如生成的图片
masks = mask_generator.generate(image)
Mask generation返回该图像所有的masks信息,每一个mask都是一个字典对象,mask的keys如下:
- segmentation : np的二维数组,为二值的mask图片
- area : mask的像素面积
- bbox : mask的外接矩形框,为XYWH格式
- predicted_iou : 该mask的质量(模型预测出的与真实框的iou)
- point_coords : 用于生成该mask的point输入
- stability_score : mask质量的附加指标
- crop_box : 用于以XYWH格式生成此遮罩的图像裁剪
在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,SamAutomaticMaskGenerator可以自动在图像上切片运行,以提高较小对象的性能,可以通过后处理去除杂散像素和孔洞。以下是对更多遮罩进行采样的示例配置:
mask_generator_2 = SamAutomaticMaskGenerator( model=sam, points_per_side=32,#控制采样点的间隔,值越小,采样点越密集(这个有争议,我在测试时值越小,输出的mask数量越少) pred_iou_thresh=0.86,#mask的iou阈值 stability_score_thresh=0.92,#mask的稳定性阈值 crop_n_layers=1, crop_n_points_downscale_factor=2, min_mask_region_area=50, #最小mask面积,会使用opencv滤除掉小面积的区域 ) masks2 = mask_generator_2.generate(image) print(len(masks2)) # 69 plt.figure(figsize=(20,20)) plt.imshow(image) show_anns(masks2) plt.axis('off') plt.show()
本文来自博客园,作者:海_纳百川,转载请注明原文链接:https://www.cnblogs.com/chentiao/p/17571225.html,如有侵权联系删除