深入理解TVM:详解Relay Op Attrs

低功耗蓝牙项目,需要一块懂省电的板

思澈 SF32LB52 芯片,BLE 协议栈深度优化,上手即开发

一、简介

Relay Op Attrs指的是Op中属性的定义,以前这是属于NNVM的部分,后来TVM引入了Relay替换掉了NNVM,Op Attrs的定义被放到了Relay中,先看一下Attrs相关的类定义继承关系:

图1

其中蓝色的类就是Op Attrs实际定义的地方,所有的Attrs都应该定义在include/tvm/relay/attrs这个文件夹的文件中。

二、从Conv2DAttrs定义开始看

因为Conv2D的属性个数和类型相对比较多,所以感觉更具有代表性,下面就以Conv2DAttrs的定义来看Op Attrs的定义:

struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
  Array<IndexExpr> strides;
  Array<IndexExpr> padding;
  Array<IndexExpr> dilation;
  int groups;
  IndexExpr channels;
  Array<IndexExpr> kernel_size;
  tvm::String data_layout;
  tvm::String kernel_layout;
  tvm::String out_layout;
  tvm::String auto_scheduler_rewritten_layout;
  DataType out_dtype;

  TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1})).describe("...");
    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0})).describe("...");
    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1})).describe("...");
    TVM_ATTR_FIELD(groups).set_default(1).describe("...");
    TVM_ATTR_FIELD(channels).describe("...").set_default(NullValue<IndexExpr>());
    TVM_ATTR_FIELD(kernel_size).describe("...").set_default(NullValue<Array<IndexExpr>>());
    TVM_ATTR_FIELD(data_layout).set_default("NCHW").describe("...");
    TVM_ATTR_FIELD(kernel_layout).set_default("OIHW").describe("...");
    TVM_ATTR_FIELD(out_layout).set_default("").describe("...");
    TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("...");
  }
};

这里面涉及了TVM_DECLARE_ATTRS和TVM_ATTR_FIELD两个宏,不方便理解,我不贴这两个宏的代码了,直接使用-E编译选项来把所有宏展开,代码整理后如下:

struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
  Array<IndexExpr> strides;
  Array<IndexExpr> padding;
  Array<IndexExpr> dilation;
  int groups;
  IndexExpr channels;
  Array<IndexExpr> kernel_size;
  tvm::String data_layout;
  tvm::String kernel_layout;
  tvm::String out_layout;
  tvm::String auto_scheduler_rewritten_layout;
  DataType out_dtype;

  static constexpr const char *_type_key = "relay.attrs.Conv2DAttrs";
  static const constexpr bool _type_final = true;
  static const constexpr int _type_child_slots = 0;
  static_assert(!::tvm::BaseAttrsNode::_type_final, "ParentObj marked as final");
  static uint32_t RuntimeTypeIndex() {
    if (Conv2DAttrs::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
      return Conv2DAttrs::_type_index;
    }
    return _GetOrAllocRuntimeTypeIndex();
  }
  static uint32_t _GetOrAllocRuntimeTypeIndex() {
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(
        Conv2DAttrs::_type_key, Conv2DAttrs::_type_index,
        ::tvm::BaseAttrsNode::_GetOrAllocRuntimeTypeIndex(),
        Conv2DAttrs::_type_child_slots,
        Conv2DAttrs::_type_child_slots_can_overflow);
    return tindex;
  }
  template <typename FVisit> 
  void __VisitAttrs__(FVisit &__fvisit__) {
    __fvisit__("strides", &strides).set_default(Array<IndexExpr>({1, 1})).describe("...");
    __fvisit__("padding", &padding).set_default(Array<IndexExpr>({0, 0})).describe("...");
    __fvisit__("dilation", &dilation).set_default(Array<IndexExpr>({1, 1})).describe("...");
    __fvisit__("groups", &groups).set_default(1).describe("...");
    __fvisit__("channels", &channels).describe("...").set_default(NullValue<IndexExpr>());
    __fvisit__("kernel_size", &kernel_size).describe("...").set_default(NullValue<Array<IndexExpr>>());
    __fvisit__("data_layout", &data_layout).set_default("NCHW").describe("...");
    __fvisit__("kernel_layout", &kernel_layout).set_default("OIHW").describe("...");
    __fvisit__("out_layout", &out_layout).set_default("").describe("...");
    __fvisit__("out_dtype", &out_dtype).set_default(NullValue<DataType>()).describe("...");
  }
};

可以看到TVM_DECLARE_ATTRS和TVM_ATTR_FIELD这两个宏主要做了这两件事:

  1. 定义了Object类体系要求的静态变量和函数,这些静态变量和函数的更多细节可以参考《深入理解TVM:Object家族
  2. 定义了__VisitAttrs__这个模板函数,里面通过调用__fvisit__.set_default().describe()来给定义的attribute设置默认值和描述信息,__fvisit__是模板函数传进来的参数,它的类型FVisit是模板参数类型,最终__VisitAttrs__这个模板函数会根据FVisit的不同被实例化成不同的函数

三、__VisitAttrs__在哪里被调用

根据前面图1,Conv2DAttrs的父类是AttrsNode,定义在include/tvm/ir/attrs.h,它是之前讲的《C/C++杂谈:CRTP》中的静态多态的一个典型应用,简化代码如下:

template <typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
  void VisitAttrs(AttrVisitor* v) {
    ::tvm::detail::AttrNormalVisitor vis(v);
    self()->__VisitAttrs__(vis);
  }
  void VisitNonDefaultAttrs(AttrVisitor* v) {
    ::tvm::detail::AttrNonDefaultVisitor vis(v);
    self()->__VisitAttrs__(vis);
  }
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown)
  {。。。}

private:
  DerivedType* self() const {
    return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
  }
};

可以看到__VisitAttrs__在VisitAttrs和VisitNonDefaultAttrs中以CRTP的方式被调用,其实还被上面类中的InitByPackedArgs调用,这个函数代码比较多稍微复杂一些,上面没有列出来,后面讲Op Attrs的初始化的时候再详细说。

四、__VisitAttrs__的模板参数

__VisitAttrs__的模板参数是一系列的xxxVisitor类,定义在include/tvm/ir/attrs.h中,通过定义不同的xxxVisitor的实现,就可以使用相同的__VisitAttrs__接口来做不同的动作,目前已经定义过的模板参数类型有:

  • AttrNormalVisitor,返回AttrNopEntry
  • AttrsSEqualVisitor,返回AttrNopEntry
  • AttrsSHashVisitor,返回AttrNopEntry
  • AttrExistVisitor,返回AttrNopEntry
  • AttrInitVisitor<FFind>,返回AttrInitEntry<T>
  • AttrDocVisitor,返回AttrDocEntry
  • AttrNonDefaultVisitor,返回AttrTriggerNonDefaultEntry<T>

上面列出的xxxVisitor系列类中,最复杂的是AttrInitVisitor<FFind>,返回AttrInitEntry<T>,它用于给Op Attrs赋值,并且检查大小范围是否合法,后面讲Op Attrs的初始化的时候再详细说。简单来说,xxxVisitor类提供类似下面的函数调用运算符重载,并且返回一个xxxEntry:

class xxxVisitor {
  template <typename T>
  xxxEntry operator()(const char* key, T* v) {
    return xxxEntry(...);
  }
};

xxxVisitor类提供的函数调用运算符重载功能使得__VisitAttrs__接口中可以进行__fvisit__调用,返回的xxxEntry提供了支持链式调用的set_default、describe、set_lower_bound、set_upper_bound等接口:

template <typename FVisit> 
void __VisitAttrs__(FVisit &__fvisit__) {
  __fvisit__("strides", &strides).set_default(Array<IndexExpr>({1, 1})).describe("...");
}

五、Op Attrs初始化

第三节中的InitByPackedArgs函数就是用来对Op Attrs做初始化的,它调用图1中的蓝色子类中提供的__VisitAttrs__模板函数接口,使用AttrInitVisitor<FFind>作为模板参数,来给Op Attrs赋值,下面是精简代码:

void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
  ICHECK_EQ(args.size() % 2, 0);
  const int kLinearSearchBound = 16;
  int hit_count = 0;
    
  // linear search.
  auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
    for (int i = 0; i < args.size(); i += 2) {
      ICHECK_EQ(args.type_codes[i], kTVMStr);
      if (!std::strcmp(key, args.values[i].v_str)) {
        *val = args[i + 1];
        return true;
      }
    }
    return false;
  };
    
  auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
  self()->__VisitAttrs__(vis);
  hit_count = vis.hit_count_;
}

这里的CreateInitVisitor返回一个具体的AttrInitVisitor<ffind>可调用对象,它通过下面代码完成对Op Attrs的赋值:

template <typename FFind> class AttrInitVisitor {
  template <typename T>
  AttrInitEntry<T> operator()(const char* key, T* value) {
    TVMArgValue val;
    AttrInitEntry<T> opt;
    opt.type_key_ = type_key_;
    opt.key_ = key;
    opt.value_ = value;
    if (ffind_(key, &val)) {
      SetValue(value, val);
      opt.value_missing_ = false;
      ++hit_count_;
    } else {
      opt.value_missing_ = true;
    }
    return std::move(opt);
  }
};

而对赋值内容检查合法性等相关的工作由对返回的AttrInitEntry<T>通过进行链式调用完成,即__fvisit__("strides", &strides).set_default().describe()最终会调用到下面这些精简过的接口中:

template <typename T> struct AttrInitEntry {
  const char* key_;
  T* value_;

  AttrInitEntry<T>& set_lower_bound(const T& begin) {
    if (begin > *value_) {
      throw AttrError(。。。);
    }
    return *this;
  }
  AttrInitEntry<T>& set_upper_bound(const T& end) {
    if (*value_ > end) {
      throw AttrError(。。。);
    }
    return *this;
  }
  AttrInitEntry<T>& set_default(const T& value) {
    *value_ = value;
    value_missing_ = false;
    return *this;
  }
  AttrInitEntry<T>& describe(const char* str) { return *this; }
};

六、Summary and Reference

本文介绍了Relay Op Attrs的定义和赋值相关的主要技术细节,现在的Op Attrs的定义和之前NNVM中的定义方式长的差不多,但是所依赖的基类和辅助类的实现已经完全变了,现在的实现更清晰,扩展性更好。本文主要参考的是TVM的官方代码和官方文档,之前的文章已经列过多次了,这里就不再重新贴了。

低功耗蓝牙项目,需要一块懂省电的板

思澈 SF32LB52 芯片,BLE 协议栈深度优化,上手即开发

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值