ST-GCN的学习之路(二)源码解读 (Pytorch版)

引言

上一篇我们阅读了st-gcn的论文,了解了st-gcn的整体思想。这一篇博客我准备就官方推出的Pytorch源码进行详细的分析(会具体到每一句,每一个原理),如果有不足和错误之处希望各位多多指出,欢迎交流,共同进步。(由于博主目前还是一名大三学生,由于学业的事也不能经常更新博客和回应提问,请各位海涵)

论文原文:Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition
ST-GCN(Pytorch)官方源码:https://github.com/yongqyu/st-gcn-pytorch

代码分析

核心代码分析 net网络

核心代码共分3个文件,在net文件夹下,分别为graph.py, tgcn.py, st-gcn.py。其中graph.py中包含邻接矩阵的建立和结点分组策略(下面会详细介绍结点分组策略的含义)、st-gcn.py包含整个网络部分的结构和前向传播方法、tgcn.py主要是空间域卷积的结构和前向传播方法。

graph.py

首先我们先来看下graph.py,类Graph的构造函数使用了self.get_edge、self.hop_dis、self.get_adjacency,在这个模块主要分了3类:

  1. 邻接矩阵的建立
  2. 归一化以及快速图卷积的与处理
  3. 权值的分组
class Graph():
    def __init__(self,
                 layout='openpose',
                 strategy='uniform',
                 max_hop=1,
                 dilation=1):
        self.max_hop = max_hop
        self.dilation = dilation

        self.get_edge(layout) # 确定图中结点间边的关系
        self.hop_dis = get_hop_distance(self.num_node, self.edge, max_hop=max_hop)# 获得邻接矩阵
        self.get_adjacency(strategy)
        ...
self.get_edge

这里采用的是OpenPose的节点进行举例,需要指出的是作者的节点连接顺序与本来OP中提供的输出格式的连接顺序是不同的,具体的体现在(2,8)(5,11)点的连接,这样的连接对结果没有影响,但是也不能简单地认为将OP中的节点pair改为st-gcn中的顺序就匹配了,因为不能忘记OP中的PAF的训练是按照(1,8)(1,11)进行训练的。

  def get_edge(self, layout):
        if layout == 'openpose':
            self.num_node = 18
            self_link = [(i, i) for i in range(self.num_node)]
            neighbor_link = [(4, 3), (3, 2), (7, 6), (6, 5), (13, 12), (12,
                                                                        11),
                             (10, 9), (9, 8), (11, 5), (8, 2), (5, 1), (2, 1),
                             (0, 1), (15, 0), (14, 0), (17, 15), (16, 14)]
            self.edge = self_link + neighbor_link
            self.center = 1
        elif layout == 'ntu-rgb+d':
            self.num_node = 25
            ...!

openpose关节对应图从源码可以看出来center点是neck(1)点。注意如果两点都邻接不可到中心点即距离都是inf,那么算作远心点。

self.get_hop_distance
def get_hop_distance(num_node, edge, max_hop=1):
    A = np.zeros((num_node, num_node))
    for i, j in edge: #构建邻接矩阵
        A[j, i] = 1
        A[i, j] = 1

    # compute hop steps
    hop_dis = np.zeros((num_node, num_node)) + np.inf
    transfer_mat = [np.linalg.matrix_power(A, d) for d in range(max_hop + 1)]
  
    arrive_mat = (np.stack(transfer_mat) > 0) # transfer_mat是list类型,需要将list堆叠成一个数组才能进行>操作
    for d in range(max_hop, -1, -1):
        hop_dis[arrive_mat[d]] = d
    return hop_dis

这一段代码中获得了带自环的邻接矩阵(是18乘18的方阵),非连接处是inf

self. get_adjacency
def get_adjacency(self, strategy):
    valid_hop = range(0, self.max_hop + 1, self.dilation) # 合法的距离值:0或1
            adjacency = np.zeros((self.num_node, self.num_node))
            for hop in valid_hop:
                adjacency[self.hop_dis == hop] = 1 # 将0|1的位置置1,inf抛弃
            normalize_adjacency = normalize_digraph(adjacency)#图卷积的预处理
    ...
    elif strategy == 'spatial': # 如果按论文的第三种划分方式
    A = []
    for hop in valid_hop: 
        a_root = np.zeros((self.num_node, self.num_node))
        a_close = np.zeros((self.num_node, self.num_node))
        a_further = np.zeros((self.num_node, self.num_node))
        for i in range(self.num_node):
            for j in range(self.num_node):
                if self.hop_dis[j, i] == hop:  # 如果结点j和结点i是邻结点
                	# 比较结点i和结点j分别到中心点的距离,中心点默认为为openpose输出的1结点
                    if self.hop_dis[j, self.center] == self.hop_dis[ i, self.center]:
                        a_root[j, i] = normalize_adjacency[j, i]
                    elif self.hop_dis[j, self. center] > self.hop_dis[i, self.center]:
                        a_close[j, i] = normalize_adjacency[j, i]
                    else:
                        a_further[j, i] = normalize_adjacency[j, i]
        if hop == 0:
            A.append(a_root) # A的第一维第1个矩阵:self distance matrix 对角阵
        else:
            A.append(a_root + a_close) # A的第一维第2个矩阵:列对结点到中心点的距离比行对应点到中心点的距离近或者相等(都为inf)
            A.append(a_further) # A的第一维第3个矩阵:列对应结点到中心点的距离比行对应点到中心点的距离远
    A = np.stack(A)
    self.A = A
    # 输出A的shape(3,18,18)
    ...
# 图卷积的预处理           
def normalize_digraph(A):
    Dl = np.sum(A, 0)  #计算邻接矩阵的度
    num_node = A.shape[0]
    Dn = np.zeros((num_node, num_node))
    for i in range(num_node):
        if Dl[i] > 0:
            Dn[i, i] = Dl[i]**(-1) #由每个点的度组成的对角矩阵
    AD = np.dot(A, Dn)
    return AD 

这段代码将会输出一个(3,18,18)的权值分组A矩阵。那么这个矩阵是怎么来的呢?这就要追溯到论文里提到的三种划分方法了:
在这里插入图片描述

  1. Uni-labeling,全部 B ( v t i ) B(v_{ti}) B(vt
评论 22
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值