PaddlePaddle Inference Source Code Analysis (3)
This section introduces the Operator (Op) definition and registration mechanism.
Overview
The core method of an Op is Run. Run needs two kinds of resources: the data resource Scope and the compute resource Place. The framework maintains a global DeviceContextPool that records the mapping between Place and DeviceContext: each Place corresponds to exactly one DeviceContext. A DeviceContext holds the compute resources of the current device; for a GPU these include cudnn_handle, cublas_handle, stream, and so on, and all computation (data copies and CUDA kernels) must run on the stream bound to that DeviceContext.
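As a minimal sketch of how these resources are obtained (the helper function RunOnPlace is a hypothetical name invented here; the two pool calls are the same ones that appear in OperatorWithKernel::RunImpl later in this article):

```cpp
#include "paddle/fluid/platform/device_context.h"

// Hypothetical helper: fetch the one DeviceContext bound to a Place.
void RunOnPlace(const paddle::platform::Place& place) {
  // DeviceContextPool is a process-wide singleton; each Place maps to
  // exactly one DeviceContext.
  auto& pool = paddle::platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  // On GPU, dev_ctx owns cudnn_handle / cublas_handle / stream; every data
  // copy and CUDA kernel for this Place must run on that stream.
  (void)dev_ctx;
}
```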
The Fluid framework is designed to run on a variety of devices and third-party libraries, and some Ops have implementations that differ across devices or libraries. For this, Fluid introduces the OpKernel mechanism: one Op can have multiple OpKernels. Such Ops inherit from OperatorWithKernel; a representative is conv, whose OpKernels include GemmConvKernel, CUDNNConvOpKernel, and ConvMKLDNNOpKernel, each available for both the double and float data types. Representative Ops that need no OpKernel include WhileOp.
PS: For an OP that inherits directly from OperatorBase, only the OP registration macro is invoked; it stores the constructor and related functions in an OpInfo, which goes into the OpInfoMap. The corresponding OP objects are then created from the model file and placed into a vector, and at execution time the OPs in the vector are run in order. For an OP that inherits from OperatorWithKernel (which itself inherits from OperatorBase), kernel classes based on framework::OpKernel<T> are added in addition to the OP registration macro, e.g. CPURangeKernel and CUDARangeKernel. Each kernel class is registered through the corresponding macro, for example:
```cpp
// Register the CPU kernels
REGISTER_OP_CPU_KERNEL(range, ops::CPURangeKernel<int>,
                       ops::CPURangeKernel<float>,
                       ops::CPURangeKernel<double>,
                       ops::CPURangeKernel<int64_t>);
// Register the GPU kernels
REGISTER_OP_CUDA_KERNEL(range, ops::CUDARangeKernel<int>,
                        ops::CUDARangeKernel<int64_t>,
                        ops::CUDARangeKernel<float>,
                        ops::CUDARangeKernel<double>);
```
The kernels registered this way are stored in an unordered_map whose key is the op name and whose value is an OpKernelMap; an OpKernelMap maps OpKernelType to OpKernelFunc, where OpKernelType carries the place, the data type, and other information. If the OP being executed inherits from OperatorWithKernel, OperatorWithKernel::RunImpl is invoked, which selects the matching kernel class and executes it.
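A simplified, self-contained sketch of the shape of these registry structures (names and fields here are abbreviated stand-ins; the real definitions live in framework/operator.h):

```cpp
// Simplified sketch of the kernel registry layout.
#include <functional>
#include <string>
#include <unordered_map>

struct OpKernelType {   // place + data type (the real one also has layout etc.)
  int place;            // stand-in for platform::Place
  int data_type;        // stand-in for proto::VarType::Type
  bool operator==(const OpKernelType& o) const {
    return place == o.place && data_type == o.data_type;
  }
};
struct OpKernelTypeHash {
  size_t operator()(const OpKernelType& t) const {
    return std::hash<int>()(t.place) ^ (std::hash<int>()(t.data_type) << 1);
  }
};
using OpKernelFunc = std::function<void(/* const ExecutionContext& */)>;
using OpKernelMap =
    std::unordered_map<OpKernelType, OpKernelFunc, OpKernelTypeHash>;
// Top level: op name ("range", "conv2d", ...) -> that op's kernels.
using AllOpKernels = std::unordered_map<std::string, OpKernelMap>;
```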
Below we use conv2d as an example to walk through the registration and usage flow.
I. Registration
1. Registering a new OP
OPs are registered through a macro, generally of the form:
```cpp
REGISTER_OPERATOR(op_type,
                  OperatorBase,
                  op_maker_and_checker_maker,
                  op_grad_opmaker,
                  op_infer_var_shape,
                  op_infer_var_type)
```
For example:
```cpp
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
                  ops::ConvOpInferVarType,
                  ops::Conv2DGradMaker<paddle::framework::OpDesc>,
                  ops::Conv2DGradMaker<paddle::imperative::OpBase>);
```
For an ordinary OP, the first three arguments are required. In practice the arguments need not be written in a fixed order: the registrar registers each one according to template specialization.
- op_type: the name of the op.
- OperatorBase: the class of the OP itself.
- op_maker_and_checker_maker: the op's maker together with the checker for the op's attributes.
- op_grad_opmaker: creates the backward OP for the current Op. If an Op has a backward pass, op_grad_opmaker is mandatory, because backward obtains the backward Op's maker from the forward Op. The default op_grad_opmaker is DefaultGradOpMaker (grad_op_desc_maker.h). It takes all inputs and outputs of the forward Op as inputs of the backward Op, takes the gradients of the forward Op's inputs as outputs of the backward Op, and copies over the forward Op's attributes. The downside of DefaultGradOpMaker is that every forward input and output becomes an input of the backward OP even when it is not needed, which blocks memory optimization from eliminating unused variables.
- The framework provides no default op_infer_var_shape. So, when the shape is guaranteed to be correct, an OP may skip inferring the output shape, i.e. op_infer_var_shape can be omitted; but a wrong shape will break the shapes of all subsequent OPs. If the Op inherits from OperatorWithKernel, it can override OperatorWithKernel's InferShape method instead of registering op_infer_var_shape, which is what most Ops with kernels do.
- It is recommended that every OP register op_infer_var_type. InferVarType infers the type and dtype of the output Vars from the type and dtype of the input Vars.
For Ops inheriting from OperatorWithKernel, each OpKernel must additionally be registered; a minimal end-to-end sketch is shown below.
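As a hedged sketch of what this looks like end to end (MyOp, MyOpMaker, and MyOpKernel are hypothetical names invented for illustration; the base classes and registration macros are Paddle's real ones):

```cpp
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class MyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  // Overriding InferShape replaces a separately registered op_infer_var_shape.
  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }
};

class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) the input tensor.");
    AddOutput("Out", "(Tensor) the output tensor.");
    AddComment("A hypothetical identity-like op used for illustration.");
  }
};

template <typename T>
class MyOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // ... read inputs from ctx, compute, write outputs ...
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker);
REGISTER_OP_CPU_KERNEL(my_op, ops::MyOpKernel<float>,
                       ops::MyOpKernel<double>);
```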
2. REGISTER_OPERATOR in detail
```cpp
/* The variadic arguments should be class types derived
 * from one of the following classes:
 *   OpProtoAndCheckerMaker
 *   GradOpDescMakerBase
 *   VarTypeInference
 *   InferShapeBase
 */
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }
```
This macro, defined in paddle/fluid/framework/op_registry.h, registers a new OP. Step one checks that op_type does not already exist; step two executes the concrete registration logic; step three defines the TouchOpRegistrar_##op_type function, which keeps the registrar alive at link time (detailed at the end of this section).
- STATIC_ASSERT_GLOBAL_NAMESPACE (fluid/extension/include/ext_op_meta_info.h) verifies that the macro is expanded in the global namespace: it declares a struct with a name derived from op_type, then static_asserts that this type is identical to the type of the same name looked up in the global scope (::). If the macro were expanded inside a namespace the two would differ and compilation would fail; declaring the struct also turns a duplicate registration of the same op_type in one translation unit into a redefinition error. A tiny demonstration follows the macro definition below.
```cpp
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)
```
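Why this works: unqualified name lookup inside a namespace finds the namespace-local struct, which is a different type from the one in the global scope. A tiny self-contained demonstration:

```cpp
#include <type_traits>

#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)                        \
  struct __test_global_namespace_##uniq_name##__ {};                          \
  static_assert(std::is_same<::__test_global_namespace_##uniq_name##__,       \
                             __test_global_namespace_##uniq_name##__>::value, \
                msg)

STATIC_ASSERT_GLOBAL_NAMESPACE(ok, "fine: we are in the global namespace");

namespace demo {
// Uncommenting the next line fails to compile: the struct declared here is
// demo::__test_global_namespace_bad__, not ::__test_global_namespace_bad__.
// STATIC_ASSERT_GLOBAL_NAMESPACE(bad, "must be called in global namespace");
}  // namespace demo
```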
- OperatorRegistrar (framework/op_registry.h) holds the actual registration logic. The main idea: add each registered component to an OpInfo, then store that info in the singleton OpInfoMap. Below we walk through the logic of OperatorRegistrarRecursive.
```cpp
template <typename... ARGS>
struct OperatorRegistrar : public Registrar {
  explicit OperatorRegistrar(const char* op_type) {
    PADDLE_ENFORCE_EQ(
        OpInfoMap::Instance().Has(op_type), false,
        platform::errors::AlreadyExists(
            "Operator '%s' is registered more than once.", op_type));
    static_assert(sizeof...(ARGS) != 0,
                  "OperatorRegistrar should be invoked at least by OpClass");
    OpInfo info;
    details::OperatorRegistrarRecursive<0, false, ARGS...>(op_type, &info);
    OpInfoMap::Instance().Insert(op_type, info);
  }
};
```
- OperatorRegistrarRecursive (framework/details/op_registry.h). A template class with two specializations, one for at_end=false and one for at_end=true, processed recursively. The initial call has I=0, at_end=false, and the full ARGS pack. std::tuple_element extracts ARGS[I], and the matching OpInfoFiller specialization stores ARGS[I] into info. It then recurses with I+1: as long as I+1 != sizeof...(ARGS), the at_end=false specialization runs again on the same ARGS. The pack is traversed element by element this way until every registered component has been handled, at which point the at_end=true specialization terminates the recursion (a self-contained sketch of the pattern follows the code below).
```cpp
template <size_t I, bool at_end, typename... ARGS>
class OperatorRegistrarRecursive;

template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, false, ARGS...> {
 public:
  using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {
    OpInfoFiller<T> fill;
    fill(op_type, info);
    constexpr auto size = sizeof...(ARGS);
    OperatorRegistrarRecursive<I + 1, I + 1 == size, ARGS...> reg(op_type,
                                                                  info);
    (void)(reg);
  }
};

template <size_t I, typename... ARGS>
class OperatorRegistrarRecursive<I, true, ARGS...> {
 public:
  OperatorRegistrarRecursive(const char* op_type, OpInfo* info) {}
};
```
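The pattern is easier to see in isolation. A self-contained sketch of the same recursion, with simplified stand-in names rather than Paddle's:

```cpp
// Standalone illustration of the recursive-template walk over ARGS.
#include <cstdio>
#include <tuple>
#include <typeinfo>

template <typename T>
struct Filler {  // stand-in for OpInfoFiller<T>
  void operator()(const char* op_type) const {
    std::printf("filling %s with %s\n", op_type, typeid(T).name());
  }
};

template <size_t I, bool at_end, typename... ARGS>
struct Recursive;

template <size_t I, typename... ARGS>
struct Recursive<I, false, ARGS...> {  // process ARGS[I], then recurse
  explicit Recursive(const char* op_type) {
    using T = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
    Filler<T>{}(op_type);
    constexpr auto size = sizeof...(ARGS);
    Recursive<I + 1, I + 1 == size, ARGS...> next(op_type);
    (void)next;
  }
};

template <size_t I, typename... ARGS>
struct Recursive<I, true, ARGS...> {  // end of the pack: stop
  explicit Recursive(const char*) {}
};

struct A {};
struct B {};
struct C {};

int main() {
  Recursive<0, false, A, B, C> r("my_op");  // fills A, then B, then C
}
```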
- For example, for ops::ConvOp, an OperatorBase-derived type, the corresponding OpInfoFiller<T, kOperator> is invoked:
```cpp
template <typename T>
struct OpInfoFiller<T, kOperator> {
  void operator()(const char* op_type, OpInfo* info) const {
    PADDLE_ENFORCE_EQ(info->creator_, nullptr,
                      platform::errors::AlreadyExists(
                          "OpCreator of %s has been registered", op_type));
    // Store the OperatorBase constructor in the info->creator_ function ptr.
    info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
                        const VariableNameMap& outputs,
                        const AttributeMap& attrs) {
      return new T(type, inputs, outputs, attrs);
    };
    // If T derives from OperatorWithKernel, there is extra work to do.
    if (std::is_base_of<OperatorWithKernel, T>::value) {
      PADDLE_ENFORCE_EQ(
          info->infer_shape_, nullptr,
          platform::errors::AlreadyExists(
              "Duplicate InferShapeFN of %s has been registered", op_type));
      OperatorWithKernel* op = dynamic_cast<OperatorWithKernel*>(
          info->creator_(std::string{}, VariableNameMap{}, VariableNameMap{},
                         AttributeMap{}));
      PADDLE_ENFORCE_NOT_NULL(op, platform::errors::InvalidArgument(
                                      "%s should have kernels", op_type));
      info->infer_shape_ = [op](InferShapeContext* ctx) {
        op->InferShape(ctx);
      };
    }
  }
};
```
- The type deduction performed by OpInfoFiller works as follows:
- First, there is an enum whose values correspond to the ARGS types:
```cpp
enum OpInfoFillType {
  kOperator = 0,
  kOpProtoAndCheckerMaker = 1,
  kGradOpDescMaker = 2,
  kVarTypeInference = 3,
  kShapeInference = 4,
  kInplaceOpInference = 5,
  kNoNeedBufferVarsInference = 6,
  kGradOpBaseMaker = 7,
  kUnknown = -1
};
```
- OpInfoFiller uses OpInfoFillTypeID to deduce the fill type of T:
```cpp
template <typename T, OpInfoFillType = OpInfoFillTypeID<T>::ID()>
struct OpInfoFiller;
```
- The deduction likewise relies on template specialization: ID() reads the kType of the matching specialization, and the lookup again proceeds recursively:
```cpp
template <typename T>
struct OpInfoFillTypeID {
  static constexpr OpInfoFillType ID() {
    return internal::OpInfoFillTypeGetter<T>::kType;
  }
};
```
```cpp
template <typename T>
using OpInfoFillTypeGetter =
    OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber,
                             kOpRegistryClassNumber == 0,
                             IsMatchedBaseType<T, 0>()>;
```
Here kOpRegistryClassNumber is the number of entries in OpRegistryClasses:
```cpp
using OpRegistryClasses = std::tuple<                                // NOLINT
    TypePair<OperatorBase, kOperator>,                               // NOLINT
    TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>,       // NOLINT
    TypePair<GradOpDescMakerBase, kGradOpDescMaker>,                 // NOLINT
    TypePair<imperative::GradOpBaseMakerBase, kGradOpBaseMaker>,     // NOLINT
    TypePair<VarTypeInference, kVarTypeInference>,                   // NOLINT
    TypePair<InferShapeBase, kShapeInference>,                       // NOLINT
    TypePair<InplaceOpInference, kInplaceOpInference>,               // NOLINT
    TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference>  // NOLINT
    >;

static constexpr int kOpRegistryClassNumber =
    std::tuple_size<OpRegistryClasses>::value;
```
IsMatchedBaseType determines whether T is derived from the type at position kPos:
```cpp
template <typename T, int kPos>
static inline constexpr bool IsMatchedBaseType() {
  return IsMatchedBaseTypeImpl<
      T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue;
}
```
If kPos exceeds the list length or is otherwise invalid, it returns false directly:
```cpp
template <typename T, int kPos>
struct IsMatchedBaseTypeImpl<T, kPos, false> {
  static constexpr bool kValue = false;
};
```
Otherwise, it checks whether T is derived from the corresponding type in OpRegistryClasses:
```cpp
// TypePair: pairs each base type T with its OpInfoFillType enum value.
template <typename T, OpInfoFillType kType>
struct TypePair {
  using Type = T;
  static constexpr OpInfoFillType kFillType = kType;
};

// Compare the incoming T against the type stored in OpRegistryClasses.
template <typename T, int kPos, bool kIsBounded /* = true*/>
struct IsMatchedBaseTypeImpl {
  using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
  static constexpr bool kValue =
      std::is_base_of<typename PairType::Type, T>::value;
};
```
- OpInfoFillTypeGetterImpl is again recursive: if IsMatchedBaseType returns false, it continues with kStart+1 until a match is found and kType, the corresponding enum value, is returned (a condensed standalone sketch follows the code below):
```cpp
// No match: advance to kStart + 1 and keep comparing.
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
  static constexpr OpInfoFillType kType =
      OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
                               IsMatchedBaseType<T, kStart + 1>()>::kType;
};

// Match: return kType, the enum value paired with T in OpRegistryClasses.
template <typename T, int kStart, int kEnd>
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
  using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
  static constexpr OpInfoFillType kType = PairType::kFillType;
};
```
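Condensed into a self-contained sketch (simplified stand-ins for Paddle's names, C++17 for brevity), the whole deduction chain looks like this:

```cpp
// Standalone sketch of the compile-time type -> enum deduction chain.
#include <tuple>
#include <type_traits>

enum FillType { kOperator = 0, kMaker = 1, kUnknown = -1 };

struct OperatorBase {};
struct MakerBase {};

template <typename T, FillType kType>
struct TypePair {
  using Type = T;
  static constexpr FillType kFillType = kType;
};

using Classes =
    std::tuple<TypePair<OperatorBase, kOperator>, TypePair<MakerBase, kMaker>>;
constexpr int kNum = std::tuple_size<Classes>::value;

// Does T derive from the type at position kPos? Out of range -> false.
template <typename T, int kPos>
constexpr bool Matched() {
  if constexpr (kPos >= 0 && kPos < kNum) {
    using Pair = typename std::tuple_element<kPos, Classes>::type;
    return std::is_base_of<typename Pair::Type, T>::value;
  } else {
    return false;
  }
}

// Primary: no match at kStart and not at the end -> try kStart + 1.
template <typename T, int kStart, bool kAtEnd, bool kMatched>
struct Getter {
  static constexpr FillType kType =
      Getter<T, kStart + 1, kStart + 1 == kNum,
             Matched<T, kStart + 1>()>::kType;
};

// Ran off the end without a match -> kUnknown.
template <typename T, int kStart, bool kMatched>
struct Getter<T, kStart, true, kMatched> {
  static constexpr FillType kType = kUnknown;
};

// Matched at kStart -> return the paired enum value.
template <typename T, int kStart>
struct Getter<T, kStart, false, true> {
  using Pair = typename std::tuple_element<kStart, Classes>::type;
  static constexpr FillType kType = Pair::kFillType;
};

template <typename T>
constexpr FillType FillTypeOf =
    Getter<T, 0, kNum == 0, Matched<T, 0>()>::kType;

struct ConvOp : OperatorBase {};
struct ConvMaker : MakerBase {};
static_assert(FillTypeOf<ConvOp> == kOperator, "deduced from base class");
static_assert(FillTypeOf<ConvMaker> == kMaker, "deduced from base class");
static_assert(FillTypeOf<int> == kUnknown, "matches nothing");
```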
- Step three of the macro defines a function that touches the static OperatorRegistrar variable created in step two. This is necessary because the global variables declared at registration time are never referenced by the code that links against the packaged framework, so the linker would strip them from the binary. Defining a function that performs an empty call keeps the variable alive when the framework is built and packaged (a sketch of the consumer-side USE_OP pattern follows the code below).
```cpp
// Define a function that touches the variable just created.
int TouchOpRegistrar_##op_type() {        \
  __op_registrar_##op_type##__.Touch();   \
  return 0;                               \
}

// Touch is in fact an empty function.
class Registrar {
 public:
  // In our design, various kinds of classes, e.g., operators and kernels,
  // have their corresponding registry and registrar. The action of
  // registration is in the constructor of a global registrar variable, which
  // are not used in the code that calls package framework, and would
  // be removed from the generated binary file by the linker. To avoid such
  // removal, we add Touch to all registrar classes and make USE_OP macros to
  // call this method. So, as long as the callee code calls USE_OP, the global
  // registrar variable won't be removed by the linker.
  void Touch() {}
};
```
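On the consumer side, a USE_OP-style macro forces the reference. The sketch below is derived from the Registrar comment above; MY_USE_OP is a made-up name, and the real macros (USE_OP and friends in op_registry.h) differ in detail:

```cpp
// Hedged sketch of the consumer-side pattern described above.
#define MY_USE_OP(op_type)                                        \
  extern int TouchOpRegistrar_##op_type();                        \
  static int use_op_itself_##op_type##_ __attribute__((unused)) = \
      TouchOpRegistrar_##op_type()
// (__attribute__((unused)) is GCC/Clang-specific; Paddle wraps this in its
// own UNUSED macro.)

// In a file that must guarantee conv2d is linked in:
// MY_USE_OP(conv2d);
```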
II. Creating the OPs
As described in PaddlePaddle Inference Source Code Analysis (2), PrepareExecutor creates the OP objects and places them into the Executor. We pick up from the preparation phase.
- NaiveExecutor::Prepare calls CreateOps:
```cpp
void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
                            int block_id, bool with_feed_fetch_ops) {
  if (!scope) {
    scope_ = new framework::Scope;
  } else {
    scope_ = scope;
  }
  VLOG(3) << "NaiveExecutor init with scope " << scope;
  CreateOps(program_desc, block_id, with_feed_fetch_ops);
}
```
- CreateOps reads every OP's OpDesc (framework/block_desc.cc) from the model-file information stored in the ProgramDesc, then uses each OpDesc to create the OP object:
```cpp
void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,
                              bool with_feed_fetch_ops) {
  for (const auto &op_desc : desc.Block(block_id).AllOps()) {
    if (!with_feed_fetch_ops &&
        (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) {
      LOG(INFO) << "---  skip [" << op_desc->Input("X")[0] << "], "
                << op_desc->Type() << " -> " << op_desc->Output("Out")[0];
      continue;
    }
    ops_.emplace_back(OpRegistry::CreateOp(*op_desc));
  }
}
```
- The OpRegistry::CreateOp overloads are all static functions. The actual logic: fetch the OpInfo for op_type from the OpInfoMap (a globally shared singleton), then call the info's Creator (the lambda created during registration in section I.2, which wraps the constructor) to build the OperatorBase object (a miniature version of the whole register-then-create flow follows the code below).
```cpp
// Entry point: unpack the OpDesc, then make the real call.
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
  return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
                  op_desc.GetAttrMap());
}

// The actual implementation.
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
    const std::string& type, const VariableNameMap& inputs,
    const VariableNameMap& outputs, AttributeMap attrs, bool attr_check) {
  auto& info = OpInfoMap::Instance().Get(type);
  if (attr_check && info.Checker() != nullptr) {
    info.Checker()->Check(&attrs);
  }
  auto op = info.Creator()(type, inputs, outputs, attrs);
  return std::unique_ptr<OperatorBase>(op);
}
```
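Putting the two halves together, registration followed by creation reduces to a name-to-factory map. A minimal standalone sketch (names simplified, not Paddle's):

```cpp
// Minimal standalone sketch of the register-then-create flow.
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

struct OperatorBase {
  virtual ~OperatorBase() = default;
};

using Creator = std::function<OperatorBase*()>;

std::unordered_map<std::string, Creator>& InfoMap() {
  static std::unordered_map<std::string, Creator> m;  // singleton map
  return m;
}

struct ConvOp : OperatorBase {};

// Registration side: a global registrar inserts the factory at load time.
struct Registrar {
  Registrar() {
    InfoMap()["conv2d"] = [] { return new ConvOp; };
  }
};
static Registrar conv2d_registrar;

// Creation side: what OpRegistry::CreateOp boils down to.
std::unique_ptr<OperatorBase> CreateOp(const std::string& type) {
  return std::unique_ptr<OperatorBase>(InfoMap().at(type)());
}

int main() {
  auto op = CreateOp("conv2d");  // looked up by name, built by the factory
}
```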
III. Invoking the OPs
1. Invoking an OP
(1) OperatorBase::Run
Its parameters are the scope and the place; it then dispatches to the virtual RunImpl.
(2) RunImpl for plain OperatorBase OPs
A plain OP implements RunImpl itself and computes directly, e.g. OPs that print error messages. A sketch of the pattern follows.
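As a hedged sketch of this pattern (MyPrintOp is a hypothetical name invented here; the base class and the override signature are Paddle's):

```cpp
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

// Hypothetical kernel-less OP: it inherits OperatorBase directly and does
// its work in RunImpl (the pattern used by ops like while/print).
class MyPrintOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 private:
  // No OpKernel involved: RunImpl performs the computation itself.
  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    auto* var = scope.FindVar(Input("X"));
    // ... inspect/print var directly; no kernel selection needed ...
    (void)var;
  }
};

}  // namespace operators
}  // namespace paddle
```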
2. Invoking an OperatorWithKernel
(1) OperatorBase::Run
Same as above.
(2) OperatorWithKernel::RunImpl
Acquires the required resources, saving the DeviceContext, parameter weights, and so on; selects the matching kernel according to the place and the operator's own attributes; infers the parameter shapes; and finally runs the kernel's actual computation via its Compute function (a simplified sketch of the kernel lookup follows the code below):
```cpp
void OperatorWithKernel::RunImpl(const Scope& scope,
                                 const platform::Place& place,
                                 RuntimeContext* runtime_ctx) const {
  // Fetch the DeviceContext.
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);
  ...
  // Select the kernel for the target device.
  if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
    ChooseKernel(*runtime_ctx, scope, place);
  }

  // Do data transform.
  std::vector<std::string> transfered_inplace_vars;
  Scope* transfer_scope = nullptr;
  {
    platform::RecordEvent record_event("prepare_data",
                                       platform::EventRole::kInnerOp);
    if (need_prepare_data_) {
      transfer_scope = PrepareData(scope, *kernel_type_,
                                   &transfered_inplace_vars, runtime_ctx);
    }
  }
  // exec_scope is the scope that the kernel actually executes on.
  const Scope& exec_scope =
      (transfer_scope == nullptr ? scope : *transfer_scope);
  if (!(kernel_type_->place_ == dev_ctx->GetPlace())) {
    dev_ctx = pool.Get(kernel_type_->place_);
  }

  // Infer the input shapes.
  if (!all_kernels_must_compute_runtime_shape_) {
    platform::RecordEvent record_event("infer_shape",
                                       platform::EventRole::kInnerOp);
    RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
    this->InferShape(&infer_shape_ctx);
  }

  // Run the kernel's actual computation.
  // Imperative mode only passes inputs and gets outputs.
  {
    platform::RecordEvent record_event("compute",
                                       platform::EventRole::kInnerOp);
    (*kernel_func_)(
        ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
  }

  if (!transfered_inplace_vars.empty()) {
    // There are inplace variables that have been transferred.
    TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
  }

  // See [ Why need handle complex gradient to real gradient? ]
  // Only handle the case where the current kernel data type is complex.
  if (framework::IsComplexType(kernel_type_->data_type_)) {
    HandleComplexGradToRealGrad(scope, runtime_ctx);
  }

  if (FLAGS_enable_unused_var_check) {
    // Skip ops that use mkldnn because they have a different memory reuse
    // strategy. Use the attr here because some GradMakers (like
    // ActivationGradOpMaker) add an input when use_mkldnn=true.
    if (!(HasAttr("use_mkldnn") && Attr<bool>("use_mkldnn"))) {
      CheckUnusedVar(*this, scope);
    }
  }

  /* For profiling/benchmark only */
  ...

  // To solve issue #15032, after a discussion with @Luotao: for cpu
  // inference, do not cache the transfer scope; in this case delete the
  // transfer scope after the run to avoid a memory leak.
  if (transfer_scope && !run_by_executor_ && !enable_cache_transfer_scope_) {
    scope.DeleteScope(transfer_scope);
  }
}
```
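A standalone sketch of what ChooseKernel plus the (*kernel_func_)(...) call boil down to, reusing the simplified registry shape from the introduction (the real lookup also matches layout and library, and has fallback logic):

```cpp
// Simplified kernel selection: name -> OpKernelMap, then place/dtype match.
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <tuple>

struct OpKernelType {
  std::string place;  // stand-in for platform::Place
  std::string dtype;  // stand-in for proto::VarType::Type
  bool operator<(const OpKernelType& o) const {
    return std::tie(place, dtype) < std::tie(o.place, o.dtype);
  }
};
using OpKernelFunc = std::function<void()>;  // real one takes ExecutionContext
using OpKernelMap = std::map<OpKernelType, OpKernelFunc>;

void RunKernel(const std::map<std::string, OpKernelMap>& all_kernels,
               const std::string& op_type, const OpKernelType& expected) {
  const OpKernelMap& kernels = all_kernels.at(op_type);  // by op name
  auto it = kernels.find(expected);                      // by place + dtype
  if (it == kernels.end()) {
    throw std::runtime_error("no kernel registered for " + op_type);
  }
  it->second();  // the real code calls the kernel's Compute(ExecutionContext)
}
```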