文章目录
pointpillar相关的其它文章链接如下:
1. 数据增强
数据增强部分相对比较清晰,整体流程如下所示。后续OpenPCDet也推出了一些新的数据增强方法,不过本人目前还没有使用。

数据增强部分代码在:pcdet/datasets/augmentor/data_augmentor.py
1.1 gt数据采集——gt_sampling
该模块的思路很简单:为了丰富训练数据,将其它帧中gt目标的点云及其box放入待训练帧的空余位置。下面是这部分的配置,官方配置对Car、Pedestrian、Cyclist三个类别进行采样。
- NAME: gt_sampling
  USE_ROAD_PLANE: True
  DB_INFO_PATH:
      - kitti_dbinfos_train.pkl
  PREPARE: {
     filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
     filter_by_difficulty: [-1],
  }
  SAMPLE_GROUPS: ['Car:15','Pedestrian:15', 'Cyclist:15']
  NUM_POINT_FEATURES: 4
  DATABASE_WITH_FAKELIDAR: False
  REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
  LIMIT_WHOLE_SCENE: False
首先对采样的gt进行最小点过滤。代码注释如下:
class DataAugmentor(object):
    """Build the queue of data-augmentation steps named in the configuration.

    Each enabled config entry is resolved to the method of this class whose
    name equals the entry's ``NAME``; that method is called once with
    ``config=<entry>`` and whatever it returns is appended to
    ``data_augmentor_queue``.
    """

    def __init__(self, root_path, augmentor_configs, class_names, logger=None):
        """Resolve every enabled augmentation entry and queue its result.

        Args:
            root_path: Dataset root, forwarded to augmentors that need it.
            augmentor_configs: Either a plain ``list`` of augmentation entries
                or a config object exposing ``AUG_CONFIG_LIST`` (and,
                optionally, ``DISABLE_AUG_LIST``). Each entry must expose
                ``NAME``.
            class_names: Class names, forwarded to augmentors that need them.
            logger: Optional logger stored on the instance.

        Raises:
            AttributeError: If an entry's ``NAME`` is not a method of this
                class.
        """
        self.root_path = root_path
        self.class_names = class_names
        self.logger = logger
        self.data_augmentor_queue = []

        # A bare list carries no disable list; a config object may name
        # augmentations to skip. Decide this once instead of per entry.
        if isinstance(augmentor_configs, list):
            aug_config_list = augmentor_configs
            disabled_names = ()
        else:
            aug_config_list = augmentor_configs.AUG_CONFIG_LIST
            # Tolerate configs that omit DISABLE_AUG_LIST entirely.
            disabled_names = getattr(augmentor_configs, 'DISABLE_AUG_LIST', ())

        for cur_cfg in aug_config_list:
            if cur_cfg.NAME in disabled_names:
                continue
            # Look up the builder method by name and call it right away; the
            # object it returns is what gets queued.
            cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
            self.data_augmentor_queue.append(cur_augmentor)

    def gt_sampling(self, config=None):
        """Create the ground-truth database sampler for this augmentor.

        Args:
            config: The ``gt_sampling`` configuration entry, passed to the
                sampler as ``sampler_cfg``.

        Returns:
            A ``database_sampler.DataBaseSampler`` built from this instance's
            ``root_path``, ``class_names`` and ``logger``.
        """
        db_sampler = database_sampler.DataBaseSampler(
            root_path=self.root_path,
            sampler_cfg=config,
            class_names=self.class_names,
            logger=self.logger
        )
        return db_sampler
其中DataBaseSampler的代码如下:
class DataBaseSampler(object):
def __init__(self, root_path, sampler_cfg, class_names, logger=None):
    """Load the ground-truth database infos and set up per-class sampling state.

    Args:
        root_path: Dataset root; must support ``resolve()`` and the ``/``
            operator (e.g. a ``pathlib.Path``).
        sampler_cfg: Sampler configuration. This method reads
            ``DB_INFO_PATH``, ``PREPARE`` and ``SAMPLE_GROUPS`` as attributes
            and the optional ``USE_SHARED_MEMORY`` / ``LIMIT_WHOLE_SCENE``
            through ``.get`` (both default to ``False``).
        class_names: Class names whose database entries are kept.
        logger: Optional logger stored on the instance.

    Raises:
        KeyError: If a loaded info file has no entry for one of
            ``class_names``.
    """
    self.root_path = root_path
    self.class_names = class_names
    self.sampler_cfg = sampler_cfg
    self.logger = logger

    # One list of database entries per requested class.
    self.db_infos = {class_name: [] for class_name in class_names}

    self.use_shared_memory = sampler_cfg.get('USE_SHARED_MEMORY', False)

    for db_info_path in sampler_cfg.DB_INFO_PATH:
        db_info_path = self.root_path.resolve() / db_info_path
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # DB_INFO_PATH at database files you generated or trust.
        with open(str(db_info_path), 'rb') as f:
            infos = pickle.load(f)
        # Merge this file's entries into the per-class lists.
        for cur_class in class_names:
            self.db_infos[cur_class].extend(infos[cur_class])

    # Each PREPARE key names a filter method of this class; apply them in
    # config order, feeding each one the result of the previous one.
    for func_name, val in sampler_cfg.PREPARE.items():
        self.db_infos = getattr(self, func_name)(self.db_infos, val)

    self.gt_database_data_key = self.load_db_to_shared_memory() if self.use_shared_memory else None

    self.sample_groups = {}  # per class: sample_num, pointer and indices
    self.sample_class_num = {}  # per class: sample_num
    self.limit_whole_scene = sampler_cfg.get('LIMIT_WHOLE_SCENE', False)

    for x in sampler_cfg.SAMPLE_GROUPS:
        # Entries look like 'Car:15'.
        class_name, sample_num = x.split(':')
        if class_name not in class_names:
            continue
        # sample_num is kept as the raw string taken from the config.
        # NOTE(review): consumers presumably convert it with int() - confirm.
        self.sample_class_num[class_name] = sample_num
        num_entries = len(self.db_infos[class_name])
        self.sample_groups[class_name] = {
            'sample_num': sample_num,
            # The pointer starts at the end of the index array.
            # NOTE(review): presumably so the first draw triggers a
            # reshuffle - confirm in the sampling code.
            'pointer': num_entries,
            'indices': np.arange(num_entries)
        }
#最小点过滤函数
def filter_by_min_points(self, db_infos, min_gt_points_list):
for name_num in min_gt_points_list:
#对每个类别单独过滤
name, min_


4329

被折叠的 条评论
为什么被折叠?



