【CV基础】语义分割任务计算类别权重

前言

 语义分割任务通常存在类别样本不平衡的问题(少数类别占据了绝大多数像素),常用的缓解办法是为各类别设置权重(一般用于加权损失函数)。本文记录类别权重的计算过程。

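类别权重计算的基本思路

 先统计每个有效类别在标注图(gt)中的像素数,据此估计类别频率,再让频率越低的类别获得越大的权重。下面的代码实现了两种常见做法:
 1. 中值频率平衡(median frequency balancing,Eigen & Fergus, arXiv:1411.4734):$f_c$ = 类别 c 的像素数 / 包含类别 c 的图像的总像素数,$w_c = \mathrm{median}(f) / f_c$;若直接用全部图像的像素总数作分母,则 $w_c = \mathrm{median}(n) / n_c$($n_c$ 为类别 c 的像素数)。
 2. ENet/ERFNet 的对数加权:$p_c$ = 类别 c 的像素数 / 所有类别的像素总数,$w_c = 1 / \ln(1.02 + p_c)$。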

code

# 20240620: calculate class weights with semantic segmentation gt images.
import os
import numpy as np
import cv2 as cv

# Label ids 0..35 are split into two groups: ids listed in ``void_classes`` are
# ignored, and every remaining id (in ascending order) is a class that receives
# a weight. Weight arrays are indexed in the order of ``valid_classes``.
void_classes = [2, 4, 10, 12, 16, 17, 19, 21, 25, 30, 32, 33, 35]
valid_classes = [label_id for label_id in range(36) if label_id not in void_classes]

# edgeai-torchvision
def calc_median_frequency(classes, present_num):
    """Return class weights by median frequency balancing.

    Reference: https://arxiv.org/pdf/1411.4734.pdf
       'a = median_freq / freq(c) where freq(c) is the number of pixels
        of class c divided by the total number of pixels in images where
        c is present, and median_freq is the median of these frequencies.'

    Args:
        classes: Per-class pixel counts (array-like).
        present_num: Per-class denominator, broadcastable against ``classes``
            (a scalar is accepted). Per the paper this is the total pixel
            count of the images containing each class. A per-class image
            count gives the same weights only when all images share one size.

    Returns:
        float64 ``np.ndarray`` of weights ``median_freq / freq(c)``.
        Classes whose frequency is zero or undefined (zero count or zero
        denominator, i.e. never present) get weight 0. They are also left
        out of the median, rather than producing inf/nan weights.
    """
    classes = np.asarray(classes, dtype=np.float64)
    present_num = np.asarray(present_num, dtype=np.float64)
    classes, present_num = np.broadcast_arrays(classes, present_num)

    # Guarded division: a zero denominator leaves freq at 0 instead of nan.
    class_freq = np.divide(classes, present_num,
                           out=np.zeros(classes.shape),
                           where=present_num > 0)

    present = class_freq > 0
    weights = np.zeros(class_freq.shape)
    if present.any():
        # The median is over the classes that actually occur, as in the paper.
        # Counting absent classes as 0 would bias it (possibly down to 0).
        median_freq = np.median(class_freq[present])
        weights[present] = median_freq / class_freq[present]
    return weights

# edgeai-torchvision
def calc_log_frequency(classes, value=1.02):
    """Return ENet/ERFNet-style logarithmic class weights.

    For each class:
        prob = count / sum(counts)
        weight = 1 / ln(value + prob)

    With ``value > 1`` the weight is bounded above by ``1 / ln(value)``
    (about 50.5 for the default 1.02), so very rare classes do not get
    unbounded weights.

    Args:
        classes: Per-class pixel counts (array-like).
        value: Constant added inside the logarithm. It must be > 1 for all
            weights to be positive and finite.

    Returns:
        float64 ``np.ndarray`` of weights, one per class.

    Raises:
        ValueError: If the counts sum to zero, so that class probabilities
            are undefined.
    """
    classes = np.asarray(classes, dtype=np.float64)
    total = classes.sum()
    if total <= 0:
        raise ValueError("calc_log_frequency: class counts sum to zero")
    # Normalize by the sum of all counts (a probability), not by the max count.
    class_freq = classes / total
    return 1 / np.log(value + class_freq)

def calculate_class_weight_present(path, *, normalize=False):
    """Compute median-frequency-balancing class weights from GT label masks.

    Follows https://arxiv.org/pdf/1411.4734.pdf (same method as
    edgeai-torchvision). For every class ``c`` in ``valid_classes``:

        freq(c)   = pixels of class c / total pixels of the images containing c
        weight(c) = median(freq over classes that occur) / freq(c)

    Classes that never occur get weight 0. When every mask has the same size,
    the weights equal those obtained by dividing by the number of images that
    contain each class.

    Args:
        path: Directory holding a ``gt`` sub-directory. Every file there whose
            name ends in ``.png`` or ``.jpg`` is read with
            ``cv.IMREAD_GRAYSCALE``, and its pixel values are treated as class
            ids. Files that cannot be read are skipped.
        normalize: If True, rescale the weights so they sum to 1.

    Returns:
        float64 ``np.ndarray`` of weights aligned with ``valid_classes``. The
        array is also printed.
    """
    gtpath = os.path.join(path, 'gt')
    class_ids = np.asarray(valid_classes, dtype=np.intp)
    # bincount length must cover the largest class id so indexing by class_ids works.
    num_bins = int(class_ids.max()) + 1 if class_ids.size else 0
    # int64 counters: float32 stops representing integers exactly above 2**24,
    # which a dataset-wide pixel count easily exceeds.
    class_counts = np.zeros(len(class_ids), dtype=np.int64)
    # Total pixels of the images in which each class occurs (the freq denominator).
    present_pixels = np.zeros(len(class_ids), dtype=np.int64)

    for filename in sorted(os.listdir(gtpath)):
        # NOTE(review): JPEG is lossy and may alter label values near region
        # borders; PNG masks are safer. Accepted here for compatibility.
        if not filename.endswith(('.png', '.jpg')):
            continue
        gtimg = cv.imread(os.path.join(gtpath, filename), cv.IMREAD_GRAYSCALE)
        if gtimg is None:
            continue
        # One histogram pass per image instead of one full-image comparison per class.
        counts = np.bincount(gtimg.ravel(), minlength=num_bins)[class_ids]
        class_counts += counts
        present_pixels[counts > 0] += gtimg.size

    present = class_counts > 0
    class_freq = np.zeros(len(class_ids))
    class_freq[present] = class_counts[present] / present_pixels[present]

    class_weights = np.zeros(len(class_ids))
    if present.any():
        # Median over the classes that occur only. Treating absent classes as
        # freq 0 would drag the median (and thus every weight) down to 0 once
        # more than half of the classes are missing.
        medval = np.median(class_freq[present])
        class_weights[present] = medval / class_freq[present]

    # Optional: normalize the weights so they sum to 1 (application-dependent).
    if normalize and class_weights.sum() > 0:
        class_weights = class_weights / class_weights.sum()
    print(class_weights)
    return class_weights



def calculate_class_weight_all(path, *, normalize=False):
    """Compute median-frequency class weights using dataset-wide pixel counts.

    Here freq(c) is the pixel count of class ``c`` over all masks divided by
    one common total. That total cancels in the ratio, so

        weight(c) = median(count over classes that occur) / count(c)

    Classes that never occur get weight 0.

    Args:
        path: Directory holding a ``gt`` sub-directory. Every file there whose
            name ends in ``.png`` or ``.jpg`` is read with
            ``cv.IMREAD_GRAYSCALE``, and its pixel values are treated as class
            ids. Files that cannot be read are skipped.
        normalize: If True, rescale the weights so they sum to 1.

    Returns:
        float64 ``np.ndarray`` of weights aligned with ``valid_classes``. The
        array is also printed.
    """
    gtpath = os.path.join(path, 'gt')
    class_ids = np.asarray(valid_classes, dtype=np.intp)
    # bincount length must cover the largest class id so indexing by class_ids works.
    num_bins = int(class_ids.max()) + 1 if class_ids.size else 0
    class_counts = np.zeros(len(class_ids), dtype=np.int64)

    for filename in sorted(os.listdir(gtpath)):
        # NOTE(review): JPEG is lossy and may alter label values near region
        # borders; PNG masks are safer. Accepted here for compatibility.
        if not filename.endswith(('.png', '.jpg')):
            continue
        gtimg = cv.imread(os.path.join(gtpath, filename), cv.IMREAD_GRAYSCALE)
        if gtimg is None:
            continue
        # One histogram pass per image instead of one full-image comparison per class.
        class_counts += np.bincount(gtimg.ravel(), minlength=num_bins)[class_ids]

    present = class_counts > 0
    class_weights = np.zeros(len(class_ids))
    if present.any():
        # Median over the classes that occur only. Zero counts of absent
        # classes would otherwise pull the median (and every weight) to 0.
        medval = np.median(class_counts[present])
        class_weights[present] = medval / class_counts[present]

    # Optional: normalize the weights so they sum to 1 (application-dependent).
    if normalize and class_weights.sum() > 0:
        class_weights = class_weights / class_weights.sum()
    print(class_weights)
    return class_weights


if __name__ == "__main__":
    import argparse

    # Command-line entry point. With no arguments this reproduces the original
    # behavior: the "present" weights for a ``gt`` folder next to this script.
    parser = argparse.ArgumentParser(
        description="Compute semantic-segmentation class weights from <root>/gt masks.")
    parser.add_argument(
        "root", nargs="?",
        default=os.path.dirname(os.path.realpath(__file__)),
        help="dataset root containing a 'gt' folder (default: this script's directory)")
    parser.add_argument(
        "--all", action="store_true",
        help="use calculate_class_weight_all instead of calculate_class_weight_present")
    args = parser.parse_args()
    path = args.root
    if args.all:
        calculate_class_weight_all(path)
    else:
        calculate_class_weight_present(path)
View Code

 

参考

posted on 2024-12-09 16:32  鹅要长大  阅读(8)  评论(0编辑  收藏  举报

导航