ST-GCN的学习之路(二)源码解读 (Pytorch版)
引言
上一篇我们阅读了st-gcn的论文,了解了st-gcn的整体思想。这一篇博客我准备就官方推出的Pytorch源码进行详细的分析(会具体到每一句,每一个原理),如果有不足和错误之处希望各位多多指出,欢迎交流,共同进步。(由于博主目前还是一名大三学生,由于学业的事也不能经常更新博客和回应提问,请各位海涵)
论文原文:Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition
ST-GCN(Pytorch)官方源码:https://github.com/yongqyu/st-gcn-pytorch
代码分析
核心代码分析 net网络
核心代码共分3个文件,在net文件夹下,分别为graph.py, tgcn.py, st-gcn.py。其中graph.py中包含邻接矩阵的建立和结点分组策略(下面会详细介绍结点分组策略的含义)、st-gcn.py包含整个网络部分的结构和前向传播方法、tgcn.py主要是空间域卷积的结构和前向传播方法。
graph.py
首先我们先来看下graph.py,类Graph的构造函数使用了self.get_edge、self.hop_dis、self.get_adjacency,在这个模块主要分了3类:
- 邻接矩阵的建立
- 归一化以及快速图卷积的与处理
- 权值的分组
class Graph():
def __init__(self,
layout='openpose',
strategy='uniform',
max_hop=1,
dilation=1):
self.max_hop = max_hop
self.dilation = dilation
self.get_edge(layout) # 确定图中结点间边的关系
self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop)# 获得邻接矩阵
self.get_adjacency(strategy)
...
self.get_edge
这里采用的是OpenPose的节点进行举例,需要指出的是作者的节点连接顺序与本来OP中提供的输出格式的连接顺序是不同的,具体的体现在(2,8)(5,11)点的连接,这样的连接对结果没有影响,但是也不能简单地认为将OP中的节点pair改为st-gcn中的顺序就匹配了,因为不能忘记OP中的PAF的训练是按照(1,8)(1,11)进行训练的。
def get_edge(self, layout):
if layout == 'openpose':
self.num_node = 18
self_link = [(i, i) for i in range(self.num_node)]
neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,
11),
(10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),
(0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]
self.edge = self_link + neighbor_link
self.center = 1
elif layout == 'ntu-rgb+d':
self.num_node = 25
...!
从源码可以看出来center点是neck(1)点。注意如果两点都邻接不可到中心点即距离都是inf,那么算作远心点。
self.get_hop_distance
def get_hop_distance(num_node, edge, max_hop=1):
A = np.zeros((num_node, num_node))
for i, j in edge: #构建邻接矩阵
A[j, i] = 1
A[i, j] = 1
# compute hop steps
hop_dis = np.zeros((num_node, num_node)) + np.inf
transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
arrive_mat = (np.stack(transfer_mat) > 0) # transfer_mat是list类型,需要将list堆叠成一个数组才能进行>操作
for d in range(max_hop, -1, -1):
hop_dis[arrive_mat[d]] = d
return hop_dis
这一段代码中获得了带自环的邻接矩阵(是18乘18的方阵),非连接处是inf
self. get_adjacency
def get_adjacency(self, strategy):
valid_hop = range(0, self.max_hop + 1, self.dilation) # 合法的距离值:0或1
adjacency = np.zeros((self.num_node, self.num_node))
for hop in valid_hop:
adjacency[self.hop_dis == hop] = 1 # 将0|1的位置置1,inf抛弃
normalize_adjacency = normalize_digraph(adjacency)#图卷积的预处理
...
elif strategy == 'spatial': # 如果按论文的第三种划分方式
A = []
for hop in valid_hop:
a_root = np.zeros((self.num_node, self.num_node))
a_close = np.zeros((self.num_node, self.num_node))
a_further = np.zeros((self.num_node, self.num_node))
for i in range(self.num_node):
for j in range(self.num_node):
if self.hop_dis[j, i] == hop: # 如果结点j和结点i是邻结点
# 比较结点i和结点j分别到中心点的距离,中心点默认为为openpose输出的1结点
if self.hop_dis[j, self.center] == self.hop_dis[ i, self.center]:
a_root[j, i] = normalize_adjacency[j, i]
elif self.hop_dis[j, self. center] > self.hop_dis[i, self.center]:
a_close[j, i] = normalize_adjacency[j, i]
else:
a_further[j, i] = normalize_adjacency[j, i]
if hop == 0:
A.append(a_root) # A的第一维第1个矩阵:self distance matrix 对角阵
else:
A.append(a_root + a_close) # A的第一维第2个矩阵:列对结点到中心点的距离比行对应点到中心点的距离近或者相等(都为inf)
A.append(a_further) # A的第一维第3个矩阵:列对应结点到中心点的距离比行对应点到中心点的距离远
A = np.stack(A)
self.A = A
# 输出A的shape(3,18,18)
...
# 图卷积的预处理
def normalize_digraph(A):
Dl = np.sum(A, 0) #计算邻接矩阵的度
num_node = A.shape[0]
Dn = np.zeros((num_node, num_node))
for i in range(num_node):
if Dl[i] > 0:
Dn[i, i] = Dl[i]**(-1) #由每个点的度组成的对角矩阵
AD = np.dot(A, Dn)
return AD
这段代码将会输出一个(3,18,18)的权值分组A矩阵。那么这个矩阵是怎么来的呢?这就要追溯到论文里提到的三种划分方法了:

- Uni-labeling,全部 B ( v t i ) B(v_{ti}) B(vt

源码解读 (Pytorch版)&spm=1001.2101.3001.5002&articleId=115030327&d=1&t=3&u=7d5c160a7a91477690007816f6fe9b93)
1985

被折叠的 条评论
为什么被折叠?



