使用方法
参考tensorflow/lite/micro/examples/xxx 目录下的使用方法, 以hello_world为例,
文件hello_world_test.cc
1. 创建MicroErrorReporter object
tflite::MicroErrorReporter micro_error_reporter;
2. 有tflite model文件得到 tflite::Modle 结构体
const tflite::Model* model = ::tflite::GetModel(g_model);
3. 创建OpsResolver对象
这里是创建包含所有算子的对象
tflite::AllOpsResolver resolver;
如果是根据实际使用的算子创建OpsResolver, 使用
static tflite::MicroMutableOpResolver<5> micro_op_resolver;
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddMaxPool2D();
micro_op_resolver.AddSoftmax();
4. 提供一段连续的内存,用于model的内存(placement new)
uint8_t tensor_arena[tensor_arena_size];
5. 创建MicroInterpreter对象
MicroInterpreter(const Model* model, const MicroOpResolver& op_resolver,
uint8_t* tensor_arena, size_t tensor_arena_size,
ErrorReporter* error_reporter,
tflite::Profiler* profiler = nullptr);
构造函数的参数类型是基类,而创建对象的参数是派生类的引用或指针,实现了多态性
tflite::MicroInterpreter interpreter(
model, resolver, tensor_arena, tensor_arena_size, µ_error_reporter);
6. 为所有的Tensor/ScratchBuffer分配内存,为所有支撑的变量分配内存等
interpreter.AllocateTensors();
//因内存的大小是个超参数常量,可以通过使用情况调整下,如参考arena_used_bytes()
interpreter.arena_used_bytes();
7. 得到输入对应的Tensor
TfLiteTensor* input = interpreter.input(0);
检测input tensor的一些属性
TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
// The value of each element gives the length of the corresponding tensor.
// We should expect two single element tensors (one is contained within the
// other).
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[1]);
// The input is a 32 bit floating point value
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, input->type);
8. Provide an input value赋值input tensor
input->data.f[0] = 0.;
9. 进行推断
interpreter.Invoke();
10. 得到推断结果
// Obtain a pointer to the output tensor and make sure it has the
// properties we expect. It should be the same as the input tensor.
TfLiteTensor* output = interpreter.output(0);
//检测推断结果的属性
TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[1]);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteFloat32, output->type);
// Obtain the output value from the tensor
float value = output->data.f[0];
按上面的步骤,分析下代码实现, 尽量明白一个 AIEngine reference的实现。
ErrorReporter

ErrorReport class 是用来输出log, 各个平台如PC/ 串口打印log的具体实现不同,使用虚函数继承是合理的,接口类Class ErrorReporter:lite/core/api/error_report.h
虚函数是virtual int Report(const char*, va_list args), 普通成员函数调用虚函数也得到了跨平台的接口。
对外的接口是个宏: 其中... 表示所有输入
#define TF_LITE_REPORT_ERROR(reporter, ...) \
do { &nbs

本文深入解析了TensorFlow Lite Micro中的AI引擎实现,涵盖了模型加载、内存分配、OpResolver、MicroInterpreter的工作原理,以及关键组件如ErrorReporter、Model、OpResolver和MicroInterpreter的内部细节。

1150

被折叠的 条评论
为什么被折叠?



