適配PyTorch FX,OneFlow讓量化感知訓練更簡單

語言: CN / TW / HK

 

作者 | 劉耀輝

審稿 | BBuf許嘯宇

 

1

 

背景

 

近年來,量化感知訓練是一個較為熱點的問題,可以大大優化量化後訓練造成精度損失的問題,使得訓練過程更加高效。

 

Torch.fx在這一問題上走在了前列,使用純Python語言實現了對於Torch.nn.Module的解析和向IR的轉換,也可以提供變換後的IR對應的Python程式碼,在外部則是提供了簡潔易用的API,大大方便了量化感知訓練過程的搭建。此外,Torch.fx也有助於消除動態圖和靜態圖之間的Gap,可以比較方便地對圖進行操作以及進行運算元融合。

 

OneFlow緊隨其後添加了針對OneFlow的fx,即One-fx,在安裝One-fx之後,使用者可以直接呼叫oneflow.fx,也可以直接通過import onefx as fx進行使用。

 

one-fx地址:
https://github.com/Oneflow-Inc/one-fx

 

One-fx實現程式碼中絕大部分是對於Torch.fx的fork,但根據OneFlow和PyTorch之間存在的差別進行了一些適配或優化。本文將圍繞One-fx適配方式以及在OneFlow中的應用展開。

 

2

 

FX主要模組

 

  • Symbolioc Trace

  • Graph Module

  • Interpreter

  • Proxy

  • Passes

 

其中,前4個模組共同實現了fx的基本功能,Graph Module和Proxy又是Symbolic Trace的基礎,Passes則是在此基礎上的擴充。

 

 

Symbolic Trace的基本概念如上圖所示,最基本的模型執行過程就是從模型定義到模型執行這樣一個流程。

 

fx則是進行了非侵入式的解析,將模型執行過程轉成一張圖,這張圖中包含了很多個Node,每一個Node都包含了模型中的子模組或者函式呼叫資訊,然後使用者可以很方便地獲取到所有的Node,並對其進行一些變換操作,最後通過GraphModule重新生成一個模型定義,並對其執行。

 

其中,在進行模型解析的時候,節點之間變數傳遞也均使用代理後的變數,如y = oneflow.relu(x),實際上x和y是Proxy(x)Proxy(y)

 

3

 

One-fx實現方式

 

這裡給出一個Fx最簡單的用例,以方便後續對於實現方式的介紹。

 

 
import oneflow
class MyModule(oneflow.nn.Module):    def __init__(self):        super().__init__()        self.linear = oneflow.nn.Linear(512, 512)
    def forward(self, x):        x = self.linear(x)        y = oneflow.ones([2, 3])
        x = oneflow.relu(x)        return y
m = MyModule()
traced = oneflow.fx.symbolic_trace(m)print(traced.code)"""def forward(self, x):    linear = self.linear(x);  x = None    relu = oneflow.relu(linear);  linear = None    _tensor_constant0 = self._tensor_constant0    return _tensor_constant0"""

 

 

 

函式代理

 

代理,即fx中的Proxy模組,目的是在每次進行函式或模組呼叫的時候新增一些額外操作,使得對模型的解析和重建得以進行,而包裝則是適配代理的一種方式。

 

torch.fx中,對於nn.Module的包裝比較易於理解,每當待解析Module中出現了繼承自nn.Module的物件,那麼就將其__call__函式替換成包裝過的函式。然而,對於pytorch的函式的代理的實現要更“繞”一些,是藉助了__torch_function__這一機制

https://github.com/pytorch/pytorch/blob/c7c723897658eda6298bb74d92e4bb18ab4a5fe3/torch/overrides.py),限於篇幅原因這裡不專門對其進行介紹。比較關鍵的點是,OneFlow中沒有這一機制,如果需要新增,那麼會是規模很大的、侵入性的,於是One-fx的實現就需要找其它路徑。

 

我們使用的解決方式是搜尋oneflowoneflow.nn.functionaloneflow._C等模組中的Callable,並去除其中屬於類的部分,然後對其餘函式進行包裝,在每次解析模型之前,會將這些模組的__dict__中對應項替換成包裝後的函式,並且在解析模型之後重新將這些項進行還原。對於constructor型別的函式,如onesrandn等則不進行代理,直接執行,在最終構建圖的時候作為constant來處理。

 

對於函式的包裝部分原始碼實現如下,每次執行代理後的函式,會先判斷該函式的入參中有沒有Proxy變數,如果有,那麼將會建立一個call_function型別的節點並返回Proxy包裝後的節點,否則直接呼叫原函式並返回結果。

 

def _create_wrapped_func(orig_fn):    @functools.wraps(orig_fn)    def wrapped(*args, **kwargs):        # 判斷引數中是否存在proxy變數        proxy = _find_proxy(args, kwargs)        if proxy is not None:            # 如果引數中有Proxy變數,建立節點並返回Proxy包裝後的節點            return_proxy = proxy.tracer.create_proxy(                "call_function", orig_fn, args, kwargs            )            return_proxy.node.meta["is_wrapped"] = True            return return_proxy        # 如果沒有Proxy變數,直接呼叫原函式        return orig_fn(*args, **kwargs)
    return wrapped

 

其中,return_proxy = proxy.tracer.create_proxy("call_function", orig_fn, args, kwargs)這行程式碼指定了使用與入參相同的Tracer來建立節點並返回結果,create_proxy函式定義的主要部分如下,建立節點並在Proxy包裝後返回。

 

def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],                     name: Optional[str] = None, type_expr : Optional[Any] = None,                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):    args_ = self.create_arg(args)    kwargs_ = self.create_arg(kwargs)    assert isinstance(args_, tuple)    assert isinstance(kwargs_, dict)
    # 建立節點    node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
    if not proxy_factory_fn:        proxy = self.proxy(node)    else:        proxy = proxy_factory_fn(node)
    return proxy

 

而其中的create_node方法,實際上是呼叫了Tracer.graph.create_node,在圖中建立節點,主要部分程式碼如下,其中op就是fx IR中的op,代表了節點型別,而target則是節點的操作主體,在上面的例子中就是orig_func

 

因此,當我們自定義的Module中的forward函式中的所有呼叫都被包裝之後,實際上再執行forward的時候,就會依次在Tracer.graph中建立節點,這也正是symbolic_trace的基本思路。

 

def create_node(self, op: str, target: 'Target',                    args: Optional[Tuple['Argument', ...]] = None,                    kwargs: Optional[Dict[str, 'Argument']] = None,                    name: Optional[str] = None,                    type_expr: Optional[Any] = None) -> Node:    # 此處有一些assert
    # 建立一個節點名稱,避免重複    candidate = name if name is not None else self._target_to_str(target)    name = self._graph_namespace.create_name(candidate, None)    # 建立節點    n = Node(self, name, op, target, args, kwargs, type_expr)
    # 建立名稱與節點的對映關係    self._graph_namespace.associate_name_with_obj(name, n)
    return n

 

而對於symbolic_trace過程,其核心就是Tracer.trace。這個方法可以分為兩部分,一個是預處理部分,一個是主幹部分。其中預處理過程大致定義如下,主要任務是初始化Graph、確立模型以及forward函式和建立包裝後的引數。

 

如前面所提及的,symbolic trace的基本思路是藉助Proxy變數以及包裝後的函式,在每次呼叫的時候都建立一個節點,因此,forward函式的輸入也需要用Proxy進行包裝,這一步定義在Tracer.create_args_for_root中。


  

 

def trace(        self,        root: Union[oneflow.nn.Module, Callable[..., Any]],        concrete_args: Optional[Dict[str, Any]] = None,    ) -> Graph:    # 確定模組主體以及forward函式,其中fn即forward函式    if isinstance(root, oneflow.nn.Module):        self.root = root
        assert hasattr(            type(root), self.traced_func_name        ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}"
        fn = getattr(type(root), self.traced_func_name)        self.submodule_paths = {mod: name for name, mod in root.named_modules()}    else:        self.root = oneflow.nn.Module()        fn = root
    tracer_cls: Optional[Type["Tracer"]] = getattr(self, "__class__", None)    # 在Tracer中初始化一張圖    self.graph = Graph(tracer_cls=tracer_cls)        self.tensor_attrs: Dict[oneflow.Tensor, str] = {}    # 這個子函式用於收集模型中所有Tensor型別的變數    def collect_tensor_attrs(m: oneflow.nn.Module, prefix_atoms: List[str]):        for k, v in m.__dict__.items():            if isinstance(v, oneflow.Tensor):                self.tensor_attrs[v] = ".".join(prefix_atoms + [k])        for k, v in m.named_children():            collect_tensor_attrs(v, prefix_atoms + [k])
    collect_tensor_attrs(self.root, [])
    assert isinstance(fn, FunctionType)
    # 獲取fn所在模組的所有可讀變數    fn_globals = fn.__globals__    # 建立包裝後的引數    fn, args = self.create_args_for_root(        fn, isinstance(root, oneflow.nn.Module), concrete_args    )

 

隨後則是trace的主幹部分,這一部分大致程式碼如下,主要任務是對函式、方法、模組進行必要的包裝,然後在Graph中建立節點,完成整個圖的資訊。

 

其中,我們會建立一個Patcher環境並在其中進行這些過程,這是因為對於函式和方法的包裝會直接改變掉某些包中對應函式或方法的行為,為了不讓這種行為的改變溢位到trace的範圍之外,在每次進行包裝的時候會在Patcher中記錄本次操作,然後在_Patcher.__exit__中根據記錄的操作一一還原現場。

 

 
# 下面程式碼仍然是`trace`函式的一部分
# 定義對於`nn.Module`的getattr方法的包裝@functools.wraps(_orig_module_getattr)def module_getattr_wrapper(mod, attr):    attr_val = _orig_module_getattr(mod, attr)    return self.getattr(attr, attr_val, parameter_proxy_cache)
# 定義對於`nn.Module`的forward方法的包裝@functools.wraps(_orig_module_call)def module_call_wrapper(mod, *args, **kwargs):    def forward(*args, **kwargs):        return _orig_module_call(mod, *args, **kwargs)
    _autowrap_check(        patcher,        getattr(getattr(mod, "forward", mod), "__globals__", {}),        self._autowrap_function_ids,    )    return self.call_module(mod, forward, args, kwargs)# 這裡Patcher的作用是在退出這一環境的時候恢復現場,避免包裝函式、方法的影響溢位到`trace`之外。with _Patcher() as patcher:    # 對`__getattr__`和`nn.Module.__call__`這兩個方法預設進行包裝    patcher.patch_method(        oneflow.nn.Module,        "__getattr__",        module_getattr_wrapper,        deduplicate=False,    )    patcher.patch_method(        oneflow.nn.Module, "__call__", module_call_wrapper, deduplicate=False    )    # 對預定好需要進行包裝的函式進行包裝    _patch_wrapped_functions(patcher)    _autowrap_check(patcher, fn_globals, self._autowrap_function_ids)    # 遍歷所有需要對其中函式進行自動包裝的package    for module in self._autowrap_search:        if module is oneflow:            dict = {}            # 當package為oneflow時,對此進行特殊處理,單獨分出一個字典存放原本`oneflow.__dict__`中的內容            for name, value in module.__dict__.items():                if not isinstance(value, oneflow.nn.Module) and not value in _oneflow_no_wrapped_functions:                    dict[name] = value            _autowrap_check_oneflow(                patcher, dict, module.__dict__, self._autowrap_function_ids            )        else:            _autowrap_check(                patcher, module.__dict__, self._autowrap_function_ids            )    # 建立節點,這裡的`create_node`呼叫實際上只是建立了最後一個節點,即輸出節點。    # 但是這裡`fn`就是forward函式,在執行這一函式的時候,就會如前面所說依次建立節點。    self.create_node(        "output",        "output",        (self.create_arg(fn(*args)),),        {},        type_expr=fn.__annotations__.get("return", None),    )

 

其中,_patch_wrapped_functions的實現如下:

 

def _patch_wrapped_functions(patcher: _Patcher):    # `_wrapped_fns_to_patch`中包含了所有需要自動包裝的函式    for frame_dict, name in _wrapped_fns_to_patch:        if name not in frame_dict:            if hasattr(builtins, name):                # 對於built-in函式,不存在於frame_dict中,單獨進行處理來根據名稱獲取函式本身                orig_fn = getattr(builtins, name)            else:                # 如果是oneflow中指定需要包裝的函式,那麼就進行獲取,否則丟擲名稱無法識別的異常                is_oneflow_wrapped_function, func = is_oneflow_wrapped_function_and_try_get(name)                if is_oneflow_wrapped_function:                    orig_fn = func                else:                    raise NameError("Cannot deal with the function %s."%name)        else:            # 如果函式名稱已經存在於frame_dict中,直接通過字典查詢來獲得函式            orig_fn = frame_dict[name]        # 建立包裝後的函式並進行`patch`,即定義當trace過程結束的時候,如何還原現場        patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn))        # 對於類中的方法,直接包裝並patch。    for cls, name in _wrapped_methods_to_patch:        patcher.patch_method(cls, name, _create_wrapped_method(cls, name))

 

 

全域性包裝

 

在模型的forward函式中,我們有時不僅會用到框架自帶的模組或者函式,有點時候還需要用到自定義的函式或者built-in函式,對於這種情況如果不進行處理,那麼自然無法接受Proxy(x)的入參。fx中提供了fx.wrap這一API,當用戶需要呼叫這部分函式的時候,可以實現使用fx.wrap(func)使其被包裝。

 

例如:

 

 
import oneflow
oneflow.fx.wrap(len)class MyModule(oneflow.nn.Module):    def __init__(self):        super().__init__()        self.linear = oneflow.nn.Linear(512, 512)
    def forward(self, x):        x = self.linear(x) + len(x.shape)        return x
traced = oneflow.fx.symbolic_trace(MyModule())print(traced.code)"""def forward(self, x):    linear = self.linear(x)    getattr_1 = x.shape;  x = None    len_1 = len(getattr_1);  getattr_1 = None    add = linear + len_1;  linear = len_1 = None    return add"""

 

但是其侷限性在於,如果Module的原始碼是來自其它庫,那麼在呼叫的地方使用fx.wrap是不起作用的,在oneflow和torch中都會有這一問題。然而flowvision中有多處使用了built-in function,因此我們添加了一個API,即global_wrap,原理比較簡單,就是直接對某個函式所在的包的__dict__進行修改,用法如下:

 

 
# MyModule來自其它包with oneflow.fx.global_wrap(len):    m = MyModule()
    traced = oneflow.fx.symbolic_trace(m)    print(traced.code)    """    def forward(self, x):        linear = self.linear(x);  x = None        getattr_1 = linear.shape        len_1 = len(getattr_1);  getattr_1 = None        relu = oneflow.relu(linear);  linear = None        add = relu + len_1;  relu = len_1 = None        return add    """

 

使用with關鍵字的原因是這種實現方式是直接修改了某個包的__dict__,對於其它地方的呼叫也會產生影響,因此需要將其限制在一定範圍內。此外,包裝後的函式包含了對型別的判定等一系列操作,也會極大影響built-in函式的效能。

 

 

其它適配

 

其它地方的處理都比較簡單,不需要對實現方式做修改,只需要將細節部分對齊即可,這也體現出oneflow和pytorch在前端部分的高度相容性。

 

4

 

IR設計

 

fx的IR設計遵循以下幾個原則:

 

  • 避免支援長尾分佈,複雜的樣例。主要關注經典模型的程式捕獲和變換。

     

  • 使用機器學習從業者已經熟悉的工具和概念,例如Python的資料結構和 PyTorch 中公開記錄的運算元 。

     

  • 使程式捕獲過程具有高度可配置性,以便使用者可以為長尾需求實現自己的解決方案。

     

 

fx的IR主要由幾個部分組成;

 

  • opcode:即當前操作的型別,可以是placeholder, get_attr, call_function, call_method, call_module, output

     

  • name:即給當前操作的命名。

     

  • target:當前操作的實體,例如對於call_function型別的操作,可能這一屬性會是

    <built-in function len>

     

  • args和kwargs:指定當前操作的引數。

     

 

通過print_tabular這一API可以很方便美觀地打印出fx中的IR,例如對於以下的MyModule模型,我們可以打印出其IR:

 

import oneflow
class MyModule(oneflow.nn.Module):    def __init__(self, do_activation : bool = False):        super().__init__()        self.do_activation = do_activation        self.linear = oneflow.nn.Linear(512, 512)
    def forward(self, x):        x = self.linear(x)        y = oneflow.ones([2, 3])
        x = oneflow.topk(x, 10)        return x.relu() + y
traced = oneflow.fx.symbolic_trace(MyModule())traced.graph.print_tabular()
"""opcode         name               target                    args                       kwargs-------------  -----------------  ------------------------  -------------------------  --------placeholder    x                  x                         ()                         {}call_module    linear             linear                    (x,)                       {}call_function  topk               <built-in function topk>  (linear, 10)               {}call_method    relu               relu                      (topk,)                    {}get_attr       _tensor_constant0  _tensor_constant0         ()                         {}call_function  add                <built-in function add>   (relu, _tensor_constant0)  {}output         output             output                    (add,)                     {}"""

 

儘管fx的IR不算強大(例如不能處理動態控制流),但是定義非常簡潔,實現簡單,對於使用者來講上手門檻相對低很多。

 

5

 

One-fx應用舉例

 

 

OP替換

 

下面的例子展示瞭如何將add操作全部替換成mul操作。

 

import oneflowfrom oneflow.fx import symbolic_traceimport operator
class M(oneflow.nn.Module):    def forward(self, x, y):        return x + y, oneflow.add(x, y), x.add(y)
if __name__ == '__main__':    traced = symbolic_trace(M())
    patterns = set([operator.add, oneflow.add, "add"])
    for n in traced.graph.nodes:        if any(n.target == pattern for pattern in patterns):            with traced.graph.inserting_after(n):                new_node = traced.graph.call_function(oneflow.mul, n.args, n.kwargs)                n.replace_all_uses_with(new_node)            traced.graph.erase_node(n)
    traced.recompile()
    traced.graph.print_tabular()
    print(traced.code)

 

 

效能分析

 

以下程式碼展示如何使用fx進行模型的效能分析,將原本的模型通過symbolic_trace解析成各個節點,再在其中插入測試效能的操作。

 

import oneflowimport flowvision.models as modelsimport statistics, tabulate, timefrom typing import Any, Dict, List
class ProfilingInterpreter(oneflow.fx.Interpreter):    def __init__(self, mod : oneflow.nn.Module):        gm = oneflow.fx.symbolic_trace(mod)        super().__init__(gm)
        # 記錄總執行時間        self.total_runtime_sec : List[float] = []        # 記錄各個節點執行時間        self.runtimes_sec : Dict[oneflow.fx.Node, List[float]] = {}
    # 重寫`run`方法,本質上是對基類`run`方法的簡單封裝,在執行前後記錄時間點。    # 這一方法是Graph整體執行的入口。    def run(self, *args) -> Any:        t_start = time.time()        return_val = super().run(*args)        t_end = time.time()        self.total_runtime_sec.append(t_end - t_start)        return return_val
    # 同上,重寫`run_node`方法,不需要自己寫細節實現,只需要在對基類的`run_node`呼叫前後記錄時間點即可    # 這一方法是Graph中執行每個Node的入口。    def run_node(self, n : oneflow.fx.Node) -> Any:        t_start = time.time()        return_val = super().run_node(n)        t_end = time.time()        self.runtimes_sec.setdefault(n, [])        self.runtimes_sec[n].append(t_end - t_start)        return return_val
    # 定義如何列印效能測試結果    def summary(self, should_sort : bool = False) -> str:        # 儲存每個節點的列印資訊        node_summaries : List[List[Any]] = []        # 由於模組會被呼叫多次,所以這裡計算一下平均的執行總時長        mean_total_runtime = statistics.mean(self.total_runtime_sec)
        for node, runtimes in self.runtimes_sec.items():            mean_runtime = statistics.mean(runtimes)            # 計算節點執行時間佔總時間的比例            pct_total = mean_runtime / mean_total_runtime * 100            # 記錄節點資訊、節點平均執行時長和節點執行時間佔總時間的比例            node_summaries.append(                [node.op, str(node), mean_runtime, pct_total])
        # 如果需要,安按照執行時間進行排序        if should_sort:            node_summaries.sort(key=lambda s: s[2], reverse=True)
        # 以下是藉助tabulate庫進行格式化來美化顯示效果        headers : List[str] = [            'Op type', 'Op', 'Average runtime (s)', 'Pct total runtime'        ]        return tabulate.tabulate(node_summaries, headers=headers)

if __name__ == '__main__':    rn18 = models.resnet18()    rn18.eval()    input = oneflow.randn(5, 3, 224, 224)    output = rn18(input)    interp = ProfilingInterpreter(rn18)    interp.run(input)    print(interp.summary(True))


   

 

效果如下:

 

 

 

運算元融合

 

以下程式碼演示如何藉助fx將模型中的卷積層和BN層進行融合,對於這種組合,並不需要引入新的運算元,只需要對原本conv的權重進行操作即可。可以參考:https://nenadmarkus.com/p/fusing-batchnorm-and-conv/

 

 
import sysimport oneflowimport oneflow.nn as nnimport numpy as npimport copyfrom typing import Dict, Any, Tuple
# 通過直接對權重進行運算的方式進行Conv和BN的融合def fuse_conv_bn_eval(conv, bn):    assert(not (conv.training or bn.training)), "Fusion only for eval!"    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)
    return fused_conv
# 權重融合方式def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):    if conv_b is None:        conv_b = oneflow.zeros_like(bn_rm)    if bn_w is None:        bn_w = oneflow.ones_like(bn_rm)    if bn_b is None:        bn_b = oneflow.zeros_like(bn_rm)    bn_var_rsqrt = oneflow.rsqrt(bn_rv + bn_eps)
    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
    return oneflow.nn.Parameter(conv_w), oneflow.nn.Parameter(conv_b)
# 根據字串對名稱進行分割,比如`foo.bar.baz` -> (`foo.bar`, `baz`)def _parent_name(target : str) -> Tuple[str, str]:    *parent, name = target.rsplit('.', 1)    return parent[0] if parent else '', name
def replace_node_module(node: oneflow.fx.Node, modules: Dict[str, Any], new_module: oneflow.nn.Module):    assert(isinstance(node.target, str))    parent_name, name = _parent_name(node.target)    setattr(modules[parent_name], name, new_module)
# 定義對模型進行融合操作的過程def fuse(model: oneflow.nn.Module) -> oneflow.nn.Module:    model = copy.deepcopy(model)    # 先通過fx.symbolic_trace獲取一個GraphModule    fx_model: oneflow.fx.GraphModule = oneflow.fx.symbolic_trace(model)    modules = dict(fx_model.named_modules())
    # 遍歷GraphModule中的所有節點,分別進行操作    for node in fx_model.graph.nodes:        # 跳過所有不是module的節點        if node.op != 'call_module':            continue        # 檢測到conv+bn的結構後進行融合操作        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:            # conv的輸出同時被其它節點使用,即conv後連線兩個節點時無法融合            if len(node.args[0].users) > 1:                continue            conv = modules[node.args[0].target]            bn = modules[node.target]            fused_conv = fuse_conv_bn_eval(conv, bn)            replace_node_module(node.args[0], modules, fused_conv)            # 對圖中的邊進行置換,對於用到bn輸出的節點,要更改它們的輸入            node.replace_all_uses_with(node.args[0])            # 移除舊的節點            fx_model.graph.erase_node(node)    fx_model.graph.lint()    # 重新建圖(構造模型)    fx_model.recompile()    return fx_model

if __name__ == '__main__':    # 以下引入flowvision中的resnet 18模型,並進行融合前後的benchmark比較    import flowvision.models as models    import time
    rn18 = models.resnet18().cuda()    rn18.eval()
    inp = oneflow.randn(10, 3, 224, 224).cuda()    output = rn18(inp)
    def benchmark(model, iters=20):        for _ in range(10):            model(inp)        oneflow.cuda.synchronize()        begin = time.time()        for _ in range(iters):            model(inp)        return str(time.time()-begin)
    fused_rn18 = fuse(rn18)    unfused_time = benchmark(rn18)    fused_time = benchmark(fused_rn18)    print("Unfused time: ", benchmark(rn18))    print("Fused time: ", benchmark(fused_rn18))    assert unfused_time > fused_time

 

6

 

未來計劃

 

  • 基於fx進行8bit量化感知訓練和部署

  • 基於fx進行運算元融合

  • eager模式下基於fx獲得模型更精確的FLOPs和MACs結果

 

 

參考文獻

1.https://pytorch.org/docs/stable/fx.html

2.https://github.com/Oneflow-Inc/one-fx

3.https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html

4.https://pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html

5.https://zhuanlan.zhihu.com/p/449908382

 


其他人都在看

歡迎Star、試用OneFlow最新版本:https://github.com/Oneflow-Inc/oneflow/

 


 

本文分享自微信公眾號 - OneFlow(OneFlowTechnology)。
如有侵權,請聯絡 [email protected] 刪除。
本文參與“OSC源創計劃”,歡迎正在閱讀的你也加入,一起分享。