您的位置:首页 > 其它

语义分割中类别不平衡的解决方法

2020-07-15 04:21 1296 查看

图片分割经常会遇到class unbalance的情况,如果你的要求是每个类别的accuracy 都很高,那么在训练的时候做class balancing 很重要,如果你的要求只要求图片总体的pixel accuracy好,那么class balancing 此时就不是很重要,因为占比小的class, accuray 虽然小,但是对总体的Pixel accuracy影响也较小。

ps:这里我只说了一种方式(借鉴别人的,自己又补充了一点),还有其他方式处理,例如采用不同的损失函数等

对于一个多类别图片数据库,每个类别都会有一个class frequency(该类别像素数目除以数据库总像素数目),求出所有class frequency 的median 值(是1/n吗,n为类别数,我的理解是这样),除以上述所求的class frequency。

这样可以保证占比小的class, 权重大于1, 占比大的class, 权重小于1, 达到balancing的效果.

如对我自己的数据有两类分别为0,1, 一共55张500500训练图片,统计55张图片中0,1像素的个数:

count1 227611

count0 13522389

freq1 = 227611/(50050055) = 0.0166

freq0 = 13522389/(500500*55) = 0.9834

median = 1/2=0.5(2分类)

weight1 = 0.5/0.0166=30.12

weight0 = 0.5/0.9834=0.508

由于有人想要知道怎么计算所有标签中每个类别的像素的总数,在这里我在这个github上找了一个pytorch代码,分享如下

import os
from tqdm import tqdm
import numpy as np
from mypath import Path##这里是作者自己创建的一个文件,用来生成路径的

def calculate_weigths_labels(dataset, dataloader, num_classes):
# Create an instance from the data loader
z = np.zeros((num_classes,))
# Initialize tqdm
tqdm_batch = tqdm(dataloader)
print('Calculating classes weights')
for sample in tqdm_batch:
y = sample['label']##这里是作者创建的一个dataloader,这里的sample['label']返回的是标签图像的lable mask
y = y.detach().cpu().numpy()
mask = (y >= 0) & (y < num_classes)
labels = y[mask].astype(np.uint8)
count_l = np.bincount(labels, minlength=num_classes)##统计每幅图像中不同类别像素的个数
z += count_l
tqdm_batch.close()
total_frequency = np.sum(z)
class_weights = []
for frequency in z:
class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))##这里是计算每个类别像素的权重
class_weights.append(class_weight)
ret = np.array(class_weights)
classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy')##生成权重文件
np.save(classes_weights_path, ret)##把各类别像素权重保存到一个文件中

return ret

从上面这个bincount函数可以看出来,github这个作者统计每个像素的权重的方式和我上面所讲的有所不同

class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))
这里的frequency是某个类别像素的总数目,total_frequency是总类别像素的数目,frequency / total_frequency应该就是上面所说的class frequency,但github作者使用1 / (np.log(1.02 + (frequency / total_frequency))这样一种方式计算权重的。

**我有点迷,我感觉github上的那个大佬应该是对的,或者说处理权重的方式不一样?希望有大佬可以帮忙解答

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: