OneFlow原始碼解析:Op、Kernel與直譯器

語言: CN / TW / HK

撰文|鄭建華

更新|趙露陽

1

Op與Kernel的註冊

繼續追蹤執行流程會發現,ReluFunctor在構造UserOpExpr時會用到UserOpRegistryMgr管理的Op與Kernel。Op表示運算元的描述資訊,Kernel在不同裝置上實現計算。

註冊資訊儲存在私有的map變數中。UserOpRegistryMgr的標頭檔案

hhttp://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/user_op_registry_manager.h)中定義了3個巨集,REGISTER_USER_OP、REGISTER_USER_OP_GRAD、REGISTER_USER_KERNEL分別用於註冊op、grad_op、kernel。

1.1 ReluOp的註冊

REGISTER_USER_OP負責UserOp的註冊。通過檢索程式碼可以找到這個巨集的使用場景。ReluOp相關的原始碼在這3個檔案中:

  • class定義:
  • build/oneflow/core/framework/op_generated.h
  • 註冊op、op的部分實現:
  • build/oneflow/core/framework/op_generated.cpp
  • 主要實現:
  • oneflow/oneflow/user/ops/relu_op.cpp

REGISTER_USER_OP巨集在op_generated.cpp中展開後代碼如下:

static UserOpRegisterTrigger<OpRegistry> g_register_trigger715 = ::oneflow::user_op::UserOpRegistryMgr::Get() .CheckAndGetOpRegistry("relu") .Input("x") .Output("y") .SetGetSbpFn(&ReluOp::GetSbp) .SetLogicalTensorDescInferFn(&ReluOp::InferLogicalTensorDesc) .SetPhysicalTensorDescInferFn(&ReluOp::InferPhysicalTensorDesc) .SetDataTypeInferFn(&ReluOp::InferDataType);

呼叫流程如下:

CheckAndGetOpRegistry

http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/user_op_registry_manager.cpp#L33 )會建立一個OpRegistry(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/user_op_registry.h#L91 )物件,這個類和UserOpRegisterTrigger(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/user_op_registry_manager.h#L63 )類一樣,只是為構造OpRegistryResult(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/user_op_registry.h#L62 )用的中間型別。

OpRegistry會暫存中間結果並在Finish中設定一些預設推導邏輯。UserOpRegisterTrigger的建構函式會呼叫註冊邏輯。靜態變數就是為了觸發建構函式從而呼叫註冊邏輯,將構造好的OpRegistryResult儲存到UserOpRegistryMgr(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/user_op_registry_manager.h#L29 )(key是op_type,如relu)。

ReluOp表示一個具體的op_type,負責為OpRegistryResult提供Op特有的方法。

OpRegistryResult把不同的Op抽象為一個通用的結構(便於統一註冊管理),主要包含描述資訊,儲存了op的輸入輸出描述,以及資料型別、sbp等的推導邏輯函式。對於relu來說,主要是記錄了幾個推導函式要呼叫ReluOp的靜態方法;op_def主要包含input/output的名字。

1.2 ReluKernel的註冊

ReluKernel在relu_kernel.cpp中註冊,過程和Op的註冊類似。REGISTER_USER_KERNEL巨集產開後如下所示:

static UserOpRegisterTrigger<OpKernelRegistry> g_register_trigger0 = UserOpRegistryMgr::Get(). CheckAndGetOpKernelRegistry("relu"). .SetCreateFn(...) .SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kRelu, "y", "x")) .SetInplaceProposalFn([](const user_op::InferContext&, const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> { OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); return Maybe<void>::Ok(); });

注意SetCreateFn只是把一個如下的lambda表示式賦值給result_.create_fn,這個欄位很重要,後續執行就是通過它獲取kernel。

[]() { return user_op::NewOpKernel<UnaryPrimitiveKernel>( "y", "x", [](user_op::KernelComputeContext* ctx) { const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>( ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(), dst->data_type()); }); }

對於relu來說,NewOpKernel就是new一個UnaryPrimitiveKernel物件並返回函式指標。

最終註冊的結果,會把OpKernelRegistryResult儲存到UserOpRegistryMgr(key是op_type_name,如"relu")。

1.3 Op和Kernel註冊相關的類關係圖

2

UserOpExpr的構造

上一篇提到,functional_api.yaml.cpp中的functional::Relu函式通過find("Relu")獲取預先註冊的PackedFunctor,呼叫其call方法會執行impl::ReluFunctor。

ReluFunctor

http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/functional/impl/activation_functor.cpp#L38 )的核心程式碼如下:

class ReluFunctor { public: ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); } Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const { // 忽略inplace相關邏輯 return OpInterpUtil::Dispatch<Tensor>(*op_, {x}); } private: std::shared_ptr<OpExpr> op_; };

ReluFunctor

http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/functional/impl/activation_functor.cpp#L40 )的建構函式中,主要是構造UserOpExpr(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_expr.h#L131 )。

每一個user op 通過OpBuilder的Build()後,都會生成相應的UserOpExpr,用於儲存屬性、型別/shape/裝置等推導方法,用於接下來op/kernel的實際計算。UserOpExpr包含以下成員:

  • base_attrs_
  • tensor_desc_infer_fn_
  • dtype_infer_fn_
  • device_and_stream_infer_fn_

它們分別用於儲存該user op相關attrs屬性、input/output tensor shape推導方法、資料型別data type推導方法、裝置及計算流推導方法等。除了常用的UserOpExpr、還有一些用於系統op的BuiltinOpExpr。

OpBuilder的Input/Output呼叫主要是操作UserOpConf的proto物件,Build函式內會修改UserOpConf物件,比如根據OpRegistryResult::op_def補充預設值到attr。

之後構造UserOpExpr物件,UserOpConf物件被儲存到UserOpExpr的父類BuiltinOpExprImpl的op_proto_欄位,對於relu來說,op_proto_主要儲存input, output等資訊。UserOpExpr初始化時會從OpRegistryResult拷貝函式變數。

3

Functor的執行

ReluFunctor執行的核心邏輯是呼叫OpInterpUtil::Dispatch。調運順序如下:

整個鏈路很長,本篇筆記只以Eager Local Mode下,對主要執行流程做一些說明。

3.1 根據環境和輸入選擇直譯器

Dispatch呼叫的GetInterpreter(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp#L147 )返回的是一個AutogradInterpreter(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_interpreter.h#L168 )物件,這個類是在其內含的OpExprInterpreter成員變數基礎之上增加了autograd的功能。GetInterpreter內實際構造的是以下3種Interpreter,在Build函式返回時轉為AutogradInterpreter。

  • LazyInterpreter:  用於lazy mode下的分散式靜態圖執行模式
  • EagerLocalInterpreter: 用於eager local mode本地單卡執行模式(和pytorch單卡或DDP對齊)
  • EagerGlobalInterpreter: 用於eager global mode,的分散式動態圖執行模式

各個Interpreter的關係如下:

GetInterpreter的作用是根據輸入和環境等資訊,選擇一個合適的直譯器。

接著在Dispatch中呼叫直譯器的\ AutogradInterpreter::Apply方法,在這個方法內呼叫internal_->Apply(...)(\ http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_interpreter/op_interpreter.cpp#L111 ),也就是上述3個直譯器的Apply方法。

3.2 Apply

通過上面我們知道,EagerLocalInterpreterEagerGlobalnterpreterLazyInterpreter 都將為其包裹上AutogradInterpreter的殼,通過AutogradInterpreter觸發Apply的呼叫。顧名思義,AutogradInterpreter的作用主要是和autograd相關,其主要為eager mode下前向的op節點插入對應的,用於反向計算grad的節點。

下面以最常用的(Eager Mode)模式,講解Apply的執行方法。在Eager Mode(無論是eager local還是eager consistent)模式下,實際都會走到EagerInterpreter的Apply(http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_interpreter/op_interpreter.cpp#L51 )方法:

``` Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple outputs, const OpExprInterpContext& ctx) const { #define APPLY_IF(op_type) \ if (const auto op = dynamic_cast(&op_expr)) { \ return ApplyImpl(*op, inputs, outputs, ctx); \ }

APPLY_IF(UserOp); APPLY_IF(VariableOp); APPLY_IF(CastToLocalOp); APPLY_IF(CastFromLocalOp); APPLY_IF(GlobalToGlobalOp); APPLY_IF(CastToGlobalOp); APPLY_IF(CastFromGlobalOp); APPLY_IF(DistributeSplitOp); APPLY_IF(DistributeCloneOp); APPLY_IF(DistributeConcatOp); APPLY_IF(DistributeAddOp); APPLY_IF(FunctionOp); APPLY_IF(SelectTopNOp)

undef APPLY_IF

OF_UNIMPLEMENTED() << "The type " << op_expr.op_type_name() << " has not been supported in EagerInterpreter::Apply."; } ```

這裡通過巨集定義APPLY_IF,增加了對不同型別op的分支處理,將op_expr dynamic_cast成相應子類op實現的Expr,如對於大多數使用者來說,用到的op都是UserOp型別,所以這裡實際上會走到這個分支中:

if (const auto* op = dynamic_cast<const UserOpExpr*>(&op_expr)) { return ApplyImpl(*op, inputs, outputs, ctx); }

再看看\ EagerLocalInterpreter::ApplyImpl(\ http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp#L209 ):

Maybe<void> EagerLocalInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, const OpExprInterpContext& ctx) const { return NaiveInterpret(op_expr, inputs, outputs, ctx); }

其最終實現是NaiveInterpret( http://github.com/Oneflow-Inc/oneflow/blob/v0.8.1/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp#L88

3.3 NaiveInterpret

NaiveInterpret簡單來說,主要用於做以下四件事:

  • check input tensor的device是否一致
  • 生成output tensor
  • 為output tensor推導和檢查shape/stride/dtype
  • 構建op執行指令,並派發至vm

簡化版的程式碼如下:

``` Maybe NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, const Symbol& default_device, TensorTuple* outputs, const OpExprInterpContext& ctx) {

const auto& attrs = ctx.attrs; // 檢查input tensor是否位於相同device上 ...

// 推導outout tensor的裝置型別 // Infer devices if (!user_op_expr.has_device_and_stream_infer_fn()) { stream = JUST(GetDefaultStreamByDevice(default_device)); for (int i = 0; i < outputs->size(); i++) { auto tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); JUST(tensor_impl->mut_device()) = default_device; } } else { need_check_mem_case = false; stream = JUST(user_op_expr.InferDeviceAndStream(attrs, inputs, outputs)); }

// 推導outout tensor的形狀、資料型別 // Infer shapes and dtypes const auto& device_tag = stream->device()->type(); JUST(user_op_expr.InferPhysicalTensorDesc( attrs, device_tag, & -> const TensorMeta { return CHECK_JUST(TensorImpl4Tensor(inputs[i]))->mut_tensor_meta(); }, & -> TensorMeta { // using thread_local TensorMeta pointer if inplace. // using tensor_impl TensorMeta pointer if not inplace. return output_tensor_metas->at(i); }));

// 為output tensor初始化eager_blob_object for (int i = 0; i < output_eager_blob_objects->size(); i++) { auto tensor_impl = JUST(TensorImpl4Tensor(outputs->at(i))); if (!output_eager_blob_objects->at(i)) { if (!JUST(user_op_expr.SupportNonContiguous())) { std::shared_ptr stride(new Stride(tensor_impl->shape())); tensor_impl->mut_tensor_meta()->set_stride(stride); } const auto& dep_object = NewLocalDepObject(); JUST(tensor_impl->InitEagerBlobObject(dep_object)); output_eager_blob_objects->at(i) = JUST(tensor_impl->eager_blob_object()); } else { // output i is inplaced. // check thread_local TensorMeta and tensor_impl TensorMeta. CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape() == output_tensor_metas->at(i)->shape()); CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype() == output_tensor_metas->at(i)->dtype()); } }

// 從user_op_expr中取出kernel const auto& kernel = JUST(user_op_expr.MutKernel4Stream(stream)); kernel->set_need_check_mem_case(need_check_mem_case);

for (int64_t index : kernel->output_tuple_indexes4mut2_obns()) { output_eager_blob_objects->at(index)->set_is_shape_synced(false); } // kernel dispatch至VM,等待後續實際的排程執行 JUST(PhysicalRun(& -> Maybe { return builder->Call(kernel, input_eager_blob_objects, output_eager_blob_objects, ctx, stream); })); return Maybe::Ok(); } ```

PhysicalRun接受一個lambda functor作為引數,這裡即InstructionsBuilder->Call方法,該方法接受kernel、input/output的eager blob object、kernel執行的上下文作為引數。Call方法實際會完成OpCall指令的構建,並最終將其派發至vm指令列表中,等待VM實際排程執行。

參考資料

  • OneFlow學習筆記:Op註冊
  • http://mp.weixin.qq.com/s/eF-c2irraxnH4iAesURy0Q
  • 從Functor到OpExprInterpreter
  • http://github.com/Oneflow-Inc/oneflow/tree/v0.8.1
  • http://zhuanlan.zhihu.com/p/523884650

(本文經授權後釋出,原文http://segmentfault.com/a/1190000041844858)

歡迎下載體驗 OneFlow v0.8.0 最新版本: http://github.com/Oneflow-Inc/oneflow/