一、简介
Relay Op Attrs指的是Op中属性的定义,以前这是属于NNVM的部分,后来TVM引入了Relay替换掉了NNVM,Op Attrs的定义被放到了Relay中,先看一下Attrs相关的类定义继承关系:

图1
其中蓝色的类就是Op Attrs实际定义的地方,所有的Attrs都应该定义在include/tvm/relay/attrs这个文件夹的文件中。
二、从Conv2DAttrs定义开始看
因为Conv2D的属性个数和类型相对比较多,所以感觉更具有代表性,下面就以Conv2DAttrs的定义来看Op Attrs的定义:
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
tvm::String auto_scheduler_rewritten_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1})).describe("...");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0})).describe("...");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1})).describe("...");
TVM_ATTR_FIELD(groups).set_default(1).describe("...");
TVM_ATTR_FIELD(channels).describe("...").set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size).describe("...").set_default(NullValue<Array<IndexExpr>>());
TVM_ATTR_FIELD(data_layout).set_default("NCHW").describe("...");
TVM_ATTR_FIELD(kernel_layout).set_default("OIHW").describe("...");
TVM_ATTR_FIELD(out_layout).set_default("").describe("...");
TVM_ATTR_FIELD(out_dtype).set_default(NullValue<DataType>()).describe("...");
}
};
这里面涉及了TVM_DECLARE_ATTRS和TVM_ATTR_FIELD两个宏,不方便理解,我不贴这两个宏的代码了,直接使用-E编译选项来把所有宏展开,代码整理后如下:
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
tvm::String data_layout;
tvm::String kernel_layout;
tvm::String out_layout;
tvm::String auto_scheduler_rewritten_layout;
DataType out_dtype;
static constexpr const char *_type_key = "relay.attrs.Conv2DAttrs";
static const constexpr bool _type_final = true;
static const constexpr int _type_child_slots = 0;
static_assert(!::tvm::BaseAttrsNode::_type_final, "ParentObj marked as final");
static uint32_t RuntimeTypeIndex() {
if (Conv2DAttrs::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
return Conv2DAttrs::_type_index;
}
return _GetOrAllocRuntimeTypeIndex();
}
static uint32_t _GetOrAllocRuntimeTypeIndex() {
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(
Conv2DAttrs::_type_key, Conv2DAttrs::_type_index,
::tvm::BaseAttrsNode::_GetOrAllocRuntimeTypeIndex(),
Conv2DAttrs::_type_child_slots,
Conv2DAttrs::_type_child_slots_can_overflow);
return tindex;
}
template <typename FVisit>
void __VisitAttrs__(FVisit &__fvisit__) {
__fvisit__("strides", &strides).set_default(Array<IndexExpr>({1, 1})).describe("...");
__fvisit__("padding", &padding).set_default(Array<IndexExpr>({0, 0})).describe("...");
__fvisit__("dilation", &dilation).set_default(Array<IndexExpr>({1, 1})).describe("...");
__fvisit__("groups", &groups).set_default(1).describe("...");
__fvisit__("channels", &channels).describe("...").set_default(NullValue<IndexExpr>());
__fvisit__("kernel_size", &kernel_size).describe("...").set_default(NullValue<Array<IndexExpr>>());
__fvisit__("data_layout", &data_layout).set_default("NCHW").describe("...");
__fvisit__("kernel_layout", &kernel_layout).set_default("OIHW").describe("...");
__fvisit__("out_layout", &out_layout).set_default("").describe("...");
__fvisit__("out_dtype", &out_dtype).set_default(NullValue<DataType>()).describe("...");
}
};
可以看到TVM_DECLARE_ATTRS和TVM_ATTR_FIELD这两个宏主要做了这两件事:
- 定义了Object类体系要求的静态变量和函数,这些静态变量和函数的更多细节可以参考《深入理解TVM:Object家族》
- 定义了__VisitAttrs__这个模板函数,里面通过调用__fvisit__.set_default().describe()来给定义的attribute设置默认值和描述信息,__fvisit__是模板函数传进来的参数,它的类型FVisit是模板参数类型,最终__VisitAttrs__这个模板函数会根据FVisit的不同被实例化成不同的函数
三、__VisitAttrs__在哪里被调用
根据前面图1,Conv2DAttrs的父类是AttrsNode,定义在include/tvm/ir/attrs.h,它是之前讲的《C/C++杂谈:CRTP》中的静态多态的一个典型应用,简化代码如下:
template <typename DerivedType>
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void VisitNonDefaultAttrs(AttrVisitor* v) {
::tvm::detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown)
{。。。}
private:
DerivedType* self() const {
return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
}
};
可以看到__VisitAttrs__在VisitAttrs和VisitNonDefaultAttrs中以CRTP的方式被调用,其实还被上面类中的InitByPackedArgs调用,这个函数代码比较多稍微复杂一些,上面没有列出来,后面讲Op Attrs的初始化的时候再详细说。
四、__VisitAttrs__的模板参数
__VisitAttrs__的模板参数是一系列的xxxVisitor类,定义在include/tvm/ir/attrs.h中,通过定义不同的xxxVisitor的实现,就可以使用相同的__VisitAttrs__接口来做不同的动作,目前已经定义过的模板参数类型有:
- AttrNormalVisitor,返回AttrNopEntry
- AttrsSEqualVisitor,返回AttrNopEntry
- AttrsSHashVisitor,返回AttrNopEntry
- AttrExistVisitor,返回AttrNopEntry
- AttrInitVisitor<FFind>,返回AttrInitEntry<T>
- AttrDocVisitor,返回AttrDocEntry
- AttrNonDefaultVisitor,返回AttrTriggerNonDefaultEntry<T>
上面列出的xxxVisitor系列类中,最复杂的是AttrInitVisitor<FFind>,返回AttrInitEntry<T>,它用于给Op Attrs赋值,并且检查大小范围是否合法,后面讲Op Attrs的初始化的时候再详细说。简单来说,xxxVisitor类提供类似下面的函数调用运算符重载,并且返回一个xxxEntry:
class xxxVisitor {
template <typename T>
xxxEntry operator()(const char* key, T* v) {
return xxxEntry(...);
}
};
xxxVisitor类提供的函数调用运算符重载功能使得__VisitAttrs__接口中可以进行__fvisit__调用,返回的xxxEntry提供了支持链式调用的set_default、describe、set_lower_bound、set_upper_bound等接口:
template <typename FVisit>
void __VisitAttrs__(FVisit &__fvisit__) {
__fvisit__("strides", &strides).set_default(Array<IndexExpr>({1, 1})).describe("...");
}
五、Op Attrs初始化
第三节中的InitByPackedArgs函数就是用来对Op Attrs做初始化的,它调用图1中的蓝色子类中提供的__VisitAttrs__模板函数接口,使用AttrInitVisitor<FFind>作为模板参数,来给Op Attrs赋值,下面是精简代码:
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
ICHECK_EQ(args.size() % 2, 0);
const int kLinearSearchBound = 16;
int hit_count = 0;
// linear search.
auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
for (int i = 0; i < args.size(); i += 2) {
ICHECK_EQ(args.type_codes[i], kTVMStr);
if (!std::strcmp(key, args.values[i].v_str)) {
*val = args[i + 1];
return true;
}
}
return false;
};
auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
self()->__VisitAttrs__(vis);
hit_count = vis.hit_count_;
}
这里的CreateInitVisitor返回一个具体的AttrInitVisitor<ffind>可调用对象,它通过下面代码完成对Op Attrs的赋值:
template <typename FFind> class AttrInitVisitor {
template <typename T>
AttrInitEntry<T> operator()(const char* key, T* value) {
TVMArgValue val;
AttrInitEntry<T> opt;
opt.type_key_ = type_key_;
opt.key_ = key;
opt.value_ = value;
if (ffind_(key, &val)) {
SetValue(value, val);
opt.value_missing_ = false;
++hit_count_;
} else {
opt.value_missing_ = true;
}
return std::move(opt);
}
};
而对赋值内容检查合法性等相关的工作由对返回的AttrInitEntry<T>通过进行链式调用完成,即__fvisit__("strides", &strides).set_default().describe()最终会调用到下面这些精简过的接口中:
template <typename T> struct AttrInitEntry {
const char* key_;
T* value_;
AttrInitEntry<T>& set_lower_bound(const T& begin) {
if (begin > *value_) {
throw AttrError(。。。);
}
return *this;
}
AttrInitEntry<T>& set_upper_bound(const T& end) {
if (*value_ > end) {
throw AttrError(。。。);
}
return *this;
}
AttrInitEntry<T>& set_default(const T& value) {
*value_ = value;
value_missing_ = false;
return *this;
}
AttrInitEntry<T>& describe(const char* str) { return *this; }
};
六、Summary and Reference
本文介绍了Relay Op Attrs的定义和赋值相关的主要技术细节,现在的Op Attrs的定义和之前NNVM中的定义方式长的差不多,但是所依赖的基类和辅助类的实现已经完全变了,现在的实现更清晰,扩展性更好。本文主要参考的是TVM的官方代码和官方文档,之前的文章已经列过多次了,这里就不再重新贴了。

5193

被折叠的 条评论
为什么被折叠?



