飞桨系列课程:AI识虫比赛之使用K-means算法获取锚框(anchors)

本文介绍了一种使用K-means算法优化目标检测锚框(anchors)的方法,通过百度AI识虫比赛的数据集,展示了如何从真实框(ground truth boxes)中自动学习出适合的锚框尺寸,以提高目标检测模型的性能。

在基于锚框的目标检测领域中,锚框的设定一直是一个困难和重要的课题,现在以百度课程中的AI识虫比赛为案例,实现以K-means算法获取锚框(anchors)

本文最终结果仅供参考。

在百度AI Studio上有比赛对应的昆虫数据集,我也专门为这个算法建了一个项目,项目是基于Notebook建立的,可以直接运行,会更加的直观,感兴趣的伙伴点K-means算法计算识虫比赛锚框

导入必要的包

import numpy as np
import matplotlib.patches as patches
from matplotlib.image import imread
import math
import random
import os
import xml.etree.ElementTree as ET

解压数据以及一些数据处理的准备工作

# 解压数据脚本,第一次运行时打开注释,将文件解压到work目录下
!unzip -d /home/aistudio/work /home/aistudio/data/data19638/insects.zip

INSECT_NAMES = ['Boerner', 'Leconte', 'Linnaeus', 
                'acuminatus', 'armandi', 'coleoptera', 'linnaeus']

def get_insect_names():
    """
    return a dict, as following,
        {'Boerner': 0,
         'Leconte': 1,
         'Linnaeus': 2, 
         'acuminatus': 3,
         'armandi': 4,
         'coleoptera': 5,
         'linnaeus': 6
        }
    It can map the insect name into an integer label.
    """
    insect_category2id = {}
    for i, item in enumerate(INSECT_NAMES):
        insect_category2id[item] = i

    return insect_category2id
cname2cid = get_insect_names()

#从数据文件中拿到相关信息,读者只需关注真实框(gt_bboxs)的信息即可gt_bboxs, 也可以省略不看,对后续的影响不大


def get_annotations(cname2cid, datadir):
    filenames = os.listdir(os.path.join(datadir, 'annotations', 'xmls'))
    records = []
    ct = 0
    for fname in filenames:
        fid = fname.split('.')[0]
        fpath = os.path.join(datadir, 'annotations', 'xmls', fname)
        img_file = os.path.join(datadir, 'images', fid + '.jpeg')
        tree = ET.parse(fpath)  # 解析每一个 xml文件

        if tree.find('id') is None:
            im_id = np.array([ct])
        else:
            im_id = np.array([int(tree.find('id').text)])

        objs = tree.findall('object')  # 拿到所有obj的内容
        im_w = float(tree.find('size').find('width').text)
        im_h = float(tree.find('size').find('height').text)
        gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
        gt_class = np.zeros((len(objs), ), dtype=np.int32)
        is_crowd = np.zeros((len(objs), ), dtype=np.int32)
        difficult = np.zeros((len(objs), ), dtype=np.int32)
        for i, obj in enumerate(objs):  #具体拿每个obj的内容
            cname = obj.find('name').text
            gt_class[i] = cname2cid[cname]
            _difficult = int(obj.find('difficult').text)
            x1 = float(obj.find('bndbox').find('xmin').text)
            y1 = float(obj.find('bndbox').find('ymin').text)
            x2 = float(obj.find('bndbox').find('xmax').text)
            y2 = float(obj.find('bndbox').find('ymax').text)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(im_w - 1, x2)
            y2 = min(im_h - 1, y2)
            # 这里使用xywh格式来表示目标物体真实框
            gt_bbox[i] = [(x1+x2)/2.0 , (y1+y2)/2.0, x2-x1+1., y2-y1+1.]
            is_crowd[i] = 0
            difficult[i] = _difficult

        voc_rec = {
            'im_file': img_file,
            'im_id': im_id,
            'h': im_h,
            'w': im_w,
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_poly': [],
            'difficult': difficult
            }
        if len(objs) != 0:
            records.append(voc_rec)
        ct += 1
    return records
TRAINDIR = '/home/aistudio/work/insects/train'
TESTDIR = '/home/aistudio/work/insects/test'
VALIDDIR = '/home/aistudio/work/insects/val'
cname2cid = get_insect_names()
records = get_annotations(cname2cid, TRAINDIR)
valid_records = get_annotations(cname2cid, VALIDDIR)
# 为了充分利用已有数据,我们把验证集的数据也一并用上,以得到更优秀的锚框
sum_records = records + valid_records
"""
    上面的内容是从数据文件中获取数据的过程, 本案例读者只需要关注此处以及后面的内容即可
    gt_bbox存放的是每张图片存在的真实框的坐标(一张图片有多个真实框)
    是(n,4)的数组, n对应一张图有几个框, 4 对应坐标数据(x,y,w,h)
"""
# 获取训练集和验证集中的所以有真实框坐标信息并返回一个(N,4)的数组
def get_allbox(sum_records):
    gt_bboxes = sum_records[0]['gt_bbox']
    """
    	利用numpy的合并方法,获得所有真实框构成的数组gt_bboxes 
    	它是一个(12203,4)的数组
    """
    for i in range(len(sum_records)-1):
        gt_bbox = sum_records[i+1]['gt_bbox']
        gt_bboxes = np.vstack((gt_bboxes, gt_bbox))
    return  gt_bboxes

gt_bboxes = get_allbox(sum_records)
# 可以打印查看 得到gt_bboxes的形状为(12203,4)
print(gt_bboxes.shape)

K-means算法步骤与实现 K-means算法简述 K-means是一种简单且常用的无监督学习算法,它旨在将数据集划分成K个簇,使得相同簇之内的数据相似性高,不同簇之间的数据相似性低。

算法步骤:

  1. 初始化K个簇中心; 使用相似性度量(一般是欧氏距离),将每个样本分配给与其距离最近的簇中心;
  2. 计算每个簇中所有样本的均值,更新簇中心;
  3. 重复1、2步,直到均簇中心不再变化,或者达到了最大迭代次数。

对box进行K-means的步骤为:

  1. 随机选取K个box作为初始anchor;

  2. 使用IOU度量,将每个box分配给与其距离最近(即IOU最大)的anchor;
    注意要点:计算IOU时,默认两个框之间的x,y都相同,只需要使用(x,y,w,h)中的w,h来计算即可,如果加入x,y计算IOU来分类(分簇),这样的分簇并不正确,因为我们其实只关心(w,h)的分布 最终想要的也只是(w,h

  3. 计算每个簇中所有box宽和高的均值,更新anchor;

  4. 重复2、3步,直到anchor不再变化,或者达到了最大迭代次数。

代码实现:第一步

#  -----------------------------------------第一步:随机选取9个anchor(初始化K个簇中心)-----------------------------------------
def init_anchors(gt_bboxes,seed):
    '''
    gt_bboxes 是一个(N,4)的数组,N对应个数,4对应(x,y,w,h)
    这个值通过函数get_allbox(sum_records)获得
    '''
    gtbox_num = gt_bboxes.shape[0]
    index_list = range(gtbox_num)
    # 随机选取9个框,为便于程序调试和结果观察,设置为随机数种子,使得随机产生的9个数始终相同。
    # 而且再所有程序设计完成后,改变这个随机数种子,也就是改变9个ancho的初始值
    # 可以观察最终结果有所波动,不过波动范围不是很大
    random.seed(seed)
    random_num = random.sample(index_list,9)
    anchors = []  # 装了9个数组(框的坐标即初始化的anchors)的列表
    for i in random_num:
        anchor = gt_bboxes[i]
        anchors.append(anchor)
    return anchors

代码实现:第二步

#  -------------------------第二步:使用IOU度量,将每个box分配给与其距离最近的anchor;-----------------------------------
#  只关心w,h,默认两者的x,y相同,以此计算iou
def box_iou_wh(box1, box2):
    w1,h1 = box1[2:4]
    w2,h2 = box2[2:4]
    s1 = w1*h1
    s2 = w2*h2
    intersection = min(h1,h2) * min(w1,w2)
    if ((w1 < w2) and (h1 < h2)) or ((w1 > w2) and (h1 > h2)):
        union = max(w1,w2) * max(h1,h2)
    else:
        union = s1 + s2 - intersection
    iou = intersection / union
    return iou

# 开始分簇,求均值,更新anchors
def kmeans(anchors, boxes, anchors_num):
    loss = 0
    groups = []
    new_anchors = []
    # 创建9个聚类
    for i in range(anchors_num):
            groups.append([])
    # 遍历每个框
    for box in boxes:
        ious = []
        # 遍历每个初始聚类中心anchor,计算当前box与每个中心的iou,找出最大的IOU后将当前box归为对应的类   
        for anchor in anchors:
            iou = box_iou_wh(box, anchor)
            ious.append(iou)
        index_of_maxiou = ious.index(max(ious))
        groups[index_of_maxiou].append(box)

    # 求每个聚类中,框的w, h 的均值
    for group in groups:
        w_sum = 0
        h_sum = 0
        length = len(group)
        for box_in_group in group:
            w_sum += box_in_group[2]
            h_sum += box_in_group[3]      
        w_mean = w_sum / length
        h_mean = h_sum / length
        # 计算iou时并不关心xy, 所以这里xy设置为默认0
        anchor = np.array([0,0,w_mean,h_mean])
        new_anchors.append(anchor)
    return new_anchors

代码实现:第三步

#  -------------第三步:重复调用kmean函数,直到满足要求:Ⅰ循环次数结束,或者Ⅱ平均值不再变化(代表找到了该类的中心)--------------
def do_kmeans(anchors, boxes, anchors_num, cycle_num):
    cycle = 0
    new_anchors = kmeans(anchors, boxes, anchors_num)
    while True:
        final_anchors = new_anchors
        new_anchors = kmeans(new_anchors, boxes, anchors_num)
        # for anchor in new_anchors:
        # loss = final_anchors - 
        cycle += 1 
        # if cycle % 10 == 0:
        #     print('循环了%d次'%(cycle))
        flag = np.zeros((9))
        for i in range(len(final_anchors)):
            equal = (new_anchors[i] == final_anchors[i]).all()
            flag[i] = equal
        if flag.all():
            print('循环了',cycle,'次,终于找到了中心anchors')
            break        
        if cycle == cycle_num:
            print('循环次数使用完毕')
            break
    # 截取 w ,h
    final_anchors = [anchor[2:4].astype('int32') for anchor in final_anchors ] 
    #由小到大排序
    final_anchors = sorted(final_anchors, key=lambda anchor: anchor[0])
    #换成YOLOV3算法中需要的形式,即变成一个列表[w,h,w,h...w,h,w,h]
    true_final_anchors = []
    for anchor in final_anchors:
        true_final_anchors.append(anchor[0])
        true_final_anchors.append(anchor[1])
    return true_final_anchors

验证结果

#--------------------------------------------------验证结果----------------------------------------
def test(seed=1):
    TRAINDIR = '/home/aistudio/work/insects/train'
    TESTDIR = '/home/aistudio/work/insects/test'
    VALIDDIR = '/home/aistudio/work/insects/val'
    cname2cid = get_insect_names()
    records = get_annotations(cname2cid, TRAINDIR)
    # 最大化利用已知数据,因此验证集上的信息我们也要统计
    valid_records = get_annotations(cname2cid, VALIDDIR)
    # 把验证集和训练集的信息合在一起
    sum_records = records + valid_records
    # 获取所有的真实框坐标信息,返回一个(N,4)的数组
    gt_bboxes = get_allbox(sum_records)
    # 随机初始化anchors
    anchors = init_anchors(gt_bboxes,seed=seed)
    # 设置聚类个数K,这里指要生成的锚框大小个数
    anchors_num = 9
    # 设定迭代次数
    cycle_num = 10000
    # 进行kmeans算法迭代
    final_anchors = do_kmeans(anchors,gt_bboxes,anchors_num,cycle_num)
    print(final_anchors)
    return final_anchors
# -----------------------------------来一起看看结果------------------------------------
test(1)
# 改变随机数种子,也就是改变初始的anchors 看看结果会怎么样?
print('------------------改变初始anchors后--------------------------')
test(2)
print('------------------改变初始anchors后--------------------------')
test(100)
'''
    从以下结果我可以看到,改变初始化值之后,小值波动很小,但是大值波动还是蛮大的,波动的不仅是数值,大小比例也改变了
    为什么? 可能是大值数量较少,分布得比较开,所以随机的初值对于大值而言影响比较大
    解决办法?   那我就随机一百次,取最终的平均值! 一百次时间太长了,,还是整个20次吧
'''
循环了10次
循环了20次
循环了30次
循环了 37 次,终于找到了中心anchors
[42, 63, 58, 91, 60, 42, 83, 84, 86, 58, 89, 136, 134, 77, 134, 111, 143, 153]
------------------改变初始anchors后--------------------------
循环了10次
循环了20次
循环了30次
循环了40次
循环了50次
循环了 52 次,终于找到了中心anchors
[41, 61, 57, 87, 62, 42, 74, 132, 83, 84, 88, 59, 107, 135, 135, 86, 151, 137]
------------------改变初始anchors后--------------------------
循环了10次
循环了20次
循环了30次
循环了40次
循环了50次
循环了 55 次,终于找到了中心anchors
[41, 63, 58, 99, 62, 42, 68, 74, 91, 89, 92, 58, 93, 137, 137, 88, 143, 139]

def get_aver_anchors(num_random):
    # num_random 为随机次数,可以手动设定
    seed_list = range(num_random)
    # 存放所有anchor 方便求平均值
    all_anchors = []
    for seed in seed_list:
        print('种子号是%d'%seed)
        anchors = test(seed)  # 由于要打印的东西太多,所以我们把do_kmeans中打印循环次数的代码注释掉
        # 列表转换为数组方便计算平均值
        anchors = np.array(anchors) # anchors里有9个anchor, 9*2=18 共18个数据
        all_anchors.append(anchors)
    # 同样,把大列表转换成数组
    all_anchors = np.array(all_anchors)
    aver_anchors = np.mean(all_anchors, axis=0).astype('int32')
    return(aver_anchors)


aver_anchors = get_aver_anchors(20)
print('平均anchors:\n',aver_anchors)

种子号是0
循环了 29 次,终于找到了中心anchors
[40, 61, 57, 87, 61, 42, 74, 129, 84, 81, 87, 57, 107, 135, 135, 86, 150, 137]
种子号是1
循环了 37 次,终于找到了中心anchors
[42, 63, 58, 91, 60, 42, 83, 84, 86, 58, 89, 136, 134, 77, 134, 111, 143, 153]
种子号是2
循环了 52 次,终于找到了中心anchors
[41, 61, 57, 87, 62, 42, 74, 132, 83, 84, 88, 59, 107, 135, 135, 86, 151, 137]
种子号是3
循环了 44 次,终于找到了中心anchors
[41, 61, 57, 87, 61, 42, 74, 129, 84, 81, 88, 57, 107, 135, 135, 86, 150, 138]
种子号是4
循环了 31 次,终于找到了中心anchors
[40, 61, 58, 87, 61, 42, 74, 129, 84, 81, 87, 57, 107, 135, 136, 87, 151, 139]
种子号是5
循环了 45 次,终于找到了中心anchors
[42, 62, 61, 88, 65, 44, 75, 126, 87, 66, 105, 135, 129, 78, 141, 109, 150, 154]
种子号是6
循环了 35 次,终于找到了中心anchors
[40, 61, 58, 87, 61, 42, 74, 129, 84, 81, 87, 57, 107, 135, 135, 86, 150, 138]
种子号是7
循环了 29 次,终于找到了中心anchors
[40, 61, 58, 87, 61, 42, 74, 129, 84, 81, 87, 57, 107, 135, 136, 87, 151, 139]
种子号是8
循环了 111 次,终于找到了中心anchors
[41, 61, 57, 87, 61, 42, 74, 129, 84, 81, 88, 57, 107, 135, 135, 86, 150, 138]
种子号是9
循环了 48 次,终于找到了中心anchors
[40, 61, 57, 87, 61, 42, 74, 129, 84, 81, 87, 57, 107, 135, 135, 86, 150, 138]
种子号是10
循环了 53 次,终于找到了中心anchors
[42, 63, 58, 91, 60, 42, 83, 84, 86, 58, 89, 137, 134, 77, 134, 111, 143, 153]
种子号是11
循环了 27 次,终于找到了中心anchors
[42, 62, 58, 91, 62, 42, 82, 86, 86, 59, 89, 137, 133, 77, 135, 111, 143, 153]
种子号是12
循环了 65 次,终于找到了中心anchors
[42, 62, 58, 91, 62, 42, 82, 86, 87, 59, 89, 137, 134, 77, 134, 112, 143, 153]
种子号是13
循环了 48 次,终于找到了中心anchors
[42, 62, 58, 90, 62, 42, 83, 86, 86, 59, 89, 137, 134, 77, 134, 111, 143, 153]
种子号是14
循环了 27 次,终于找到了中心anchors
[46, 52, 56, 86, 73, 124, 82, 53, 85, 78, 103, 137, 132, 77, 138, 109, 150, 153]
种子号是15
循环了 48 次,终于找到了中心anchors
[41, 62, 58, 91, 62, 42, 83, 86, 87, 59, 89, 137, 134, 77, 134, 111, 143, 153]
种子号是16
循环了 32 次,终于找到了中心anchors
[41, 61, 57, 87, 61, 42, 74, 129, 84, 81, 88, 57, 107, 135, 135, 86, 150, 138]
种子号是17
循环了 34 次,终于找到了中心anchors
[41, 61, 57, 87, 61, 42, 74, 129, 84, 81, 88, 57, 107, 135, 135, 86, 150, 138]
种子号是18
循环了 37 次,终于找到了中心anchors
[42, 62, 58, 90, 62, 42, 83, 86, 86, 59, 89, 137, 134, 111, 134, 77, 143, 153]
种子号是19
循环了 40 次,终于找到了中心anchors
[40, 61, 57, 87, 61, 42, 74, 129, 84, 81, 87, 57, 107, 135, 135, 86, 150, 138]

平均anchors:
 [ 41  61  57  88  62  46  77 109  84  72  89  92 118 110 135  95 147 144]

参考资料资料一资料二

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值