openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理

pointpillar相关的其它文章链接如下:

  1. 【论文阅读】CVPR 2019| PointPillars: 基于点云的快速编码目标检测框架(Fast Encoders for Object Detection from Point Clouds)
  2. OpenPCDet v0.5版本的安装与测试
  3. openpcdet之pointpillar代码阅读——第一篇:数据增强与数据处理
  4. openpcdet之pointpillar代码阅读——第二篇:网络结构
  5. openpcdet之pointpillar代码阅读——第三篇:损失函数的计算

1. 数据增强

数据增强部分,相对比较清晰,整体流程如下所示。后续openpcdet也出了一些新的数据增强方法,不过目前本人暂时还没有使用。

在这里插入图片描述

数据增强部分代码在:pcdet/datasets/augmentor/data_augmentor.py

1.1 gt数据采集——gt_sampling

该模块思路很简单,就是为了丰富训练数据,也就是将其它帧gt的点云以及box放入待训练帧中的空余位置。下面是这部分的配置文件,官方这部分训练了3种类型。

            - NAME: gt_sampling
              USE_ROAD_PLANE: True
              DB_INFO_PATH:
                  - kitti_dbinfos_train.pkl
              PREPARE: {
   
   
                 filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
                 filter_by_difficulty: [-1],
              }

              SAMPLE_GROUPS: ['Car:15','Pedestrian:15', 'Cyclist:15']
              NUM_POINT_FEATURES: 4
              DATABASE_WITH_FAKELIDAR: False
              REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
              LIMIT_WHOLE_SCENE: False

首先对采样的gt进行最小点过滤。代码注释如下:

class DataAugmentor(object):
    def __init__(self, root_path, augmentor_configs, class_names, logger=None):
        self.root_path = root_path
        self.class_names = class_names
        self.logger = logger
        
        self.data_augmentor_queue = []
        # 读取数据增强部分配置文件
        aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
            else augmentor_configs.AUG_CONFIG_LIST
        #逐个读取数据增强部分
        for cur_cfg in aug_config_list:
            if not isinstance(augmentor_configs, list):
            	#不用数据增强的列表DISABLE_AUG_LIST
                if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
                    continue
            #使用partial,所以此刻只是把数据增强方法加入队列(data_dict=0)
            # 执行数据增加的函数,并加入至data_augmentor_queue
            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
            self.data_augmentor_queue.append(cur_augmentor)
    
    #gt数据采集部分
    def gt_sampling(self, config=None):
        db_sampler = database_sampler.DataBaseSampler(
            root_path=self.root_path,
            sampler_cfg=config,
            class_names=self.class_names,
            logger=self.logger
        )
        return db_sampler

其中DataBaseSampler的代码如下:

class DataBaseSampler(object):
    def __init__(self, root_path, sampler_cfg, class_names, logger=None):
        self.root_path = root_path
        self.class_names = class_names
        self.sampler_cfg = sampler_cfg
        self.logger = logger
        self.db_infos = {
   
   }
        #按照类别分类
        for class_name in class_names:
            self.db_infos[class_name] = []

        # use_shared_memory = false
        self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)
        
        for db_info_path in sampler_cfg.DB_INFO_PATH:
            db_info_path = self.root_path.resolve() / db_info_path
            #按照类别加入数据各自的db数据
            with open(str(db_info_path), 'rb') as f:
                infos = pickle.load(f)
                [self.db_infos[cur_class].extend(infos[cur_class]) for cur_class in class_names]
        #执行最小点过滤和困难点过滤,我这里只用了filter_by_min_points过滤
        for func_name, val in sampler_cfg.PREPARE.items():
            self.db_infos = getattr(self, func_name)(self.db_infos, val)
        
        self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None

        self.sample_groups = {
   
   } #sample_num、pointer和indices
        self.sample_class_num = {
   
   } #sample_num
        self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False) #False

        for x in sampler_cfg.SAMPLE_GROUPS:
            class_name, sample_num = x.split(':')
            if class_name not in class_names:
                continue
            self.sample_class_num[class_name] = sample_num
            self.sample_groups[class_name] = {
   
   
                'sample_num': sample_num,
                'pointer': len(self.db_infos[class_name]),
                'indices': np.arange(len(self.db_infos[class_name]))
            }
 #最小点过滤函数 
  def filter_by_min_points(self, db_infos, min_gt_points_list):
        for name_num in min_gt_points_list:
            #对每个类别单独过滤
            name, min_
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

非晚非晚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值