TorchDynamo初探:Python ByteCode的動態修改
作者|strint
1
背景
深度學習框架編譯優化時,需要先根據計算邏輯形成一個邏輯計算圖,然後再改寫計算圖,最後執行改寫後的計算圖。其中生成邏輯計算圖方式有兩種。
一種計算圖生成是基於 trace tensor 的,跟蹤 tensor 的執行路徑。tensor 執行時,基於函式過載,可以落到支援 tensor 計算的框架自定義函式,該函式一般是 c++ 層的。c++ 層的自定義函式中,功能是用於生成一個 Operation 的符號表達。比如一個對於加法運算,trace 就是記錄一個符號化的加法運算元。如此一連串的運算就被轉換了符號化的計算圖。
另外一種計算圖生成是基於 AST(抽象語法樹) 解析的。在程式碼執行前,直接根據 Python 文字程式碼得到 Python AST,然後根據 AST 來翻譯成計算圖(也叫做中間程式碼 IR)。
Python(特指 CPython)直譯器執行,第一階段會先把 Python 原始碼解析成 AST,第二階段根據 AST 生成和優化 ByteCode(位元組碼),第三階段在虛擬機器中執行 ByteCode。
基於 AST 解析的計算圖生成,發生在這裡的第一階段;基於 trace tensor 的計算圖生成,發生在第三階段之後。
TorchDynamo 特別的地方在於其工作在第二階段,動態修改 Python ByteCode,這樣第三階段執行的已經是修改後的 ByteCode了。
2
TorchDynamo 概述
TorchDynamo 是 PyTorch 新實驗的 JIT 編譯介面,支援使用 Python 在執行時修改動態執行邏輯,修改的時機是 CPython 的 ByteCode 執行前。這個思想類似 DynamoRIO(https://dynamorio.org) 專案,DynamoRIO 可以動態的修改 x86 機器碼。
CPython 的每次函式呼叫會生成一個 Frame(或者叫 Stack),Frame 中帶有的程式碼部分就是 ByteCode。CPython 執行時支援基於現有的 Frame 去設定一個自定義的 Frame,然後後面執行的就是自定義的 Frame。
TorchDynamo 的工作原理就是在執行時設定一個自定義的 Frame,該 Frame 中的 ByteCode 支援 CallBack 到 Python 層去修改。其提供的典型的修改介面是 FX Graph,也就是說 TorchDynamo 會分析 ByteCode,生成對應的 FX Graph,然後提供 FX Graph 的介面供使用者自定義計算圖。這種做法有如下優點:
-
可以支援所有的 Python 語法,因為如果在自定義 Frame 過程中的任何一點發現不支援,都可以選擇不修改 Frame 而回退到原 Frame;
-
開銷少,劫持發生在 Python 執行比較早的階段(ByteCode 生成和優化階段),而非 Python ByteCode 執行後的階段,有時可以減少 Python ByteCode 的執行開銷(猜測如果很多次 ByteCode 層面的函式呼叫被融合層成一次函式呼叫,的確可以縮減開銷);
-
可以做到不增加編譯帶來的延遲(之前的基於 tensor trace 或者 ast 解析的做法,一般都有先編譯執行所以編譯開銷無法掩蓋,但是改寫 ByteCode 這個做法,猜測是可以在識別出熱點程式碼後,單獨開一個執行緒去做編譯,而不影響主執行緒工作。Python ByteCode 改寫的 API 中有這種延遲編譯的樣例,peps.python.org/pep-052 )。
之前計算圖生成機制(基於 trace tensor、基於 AST 解析的)中的幾個問題,得到了緩解:
-
存在無法靜態化的操作,之前一般需要顯式的移除靜態化作用域,現在總是允許不做編譯,直接執行原 Python 程式碼,這樣使得靜態化標註變得簡單;
-
開啟靜態圖編譯優化,之前編譯時一般無法掩蓋,現在有辦法部分掩蓋;
-
動態 shape 問題,因為有了編譯時和執行時的掩蓋,也可以得到緩解。
這種儘量優化、動態優化的設計,最大程度了照顧了程式碼開發的體驗,讓編譯優化上手變得更簡單了。這是 TorchDynamo 帶來的最主要的好處。這種做法非常符合 PyTorch 的 Python First、Eager First、User Experience First的偏好。但是這個設計對於尋求最好的效能、最方便的靜態化部署這兩個目標並沒有改善。
3
CPython 的標準執行流程
上文提到了 CPython 的執行從 Python 文字程式碼,到 AST,到 ByteCode。這裡用一個示例展開看一下。Python 的標準組件非常易用,可以在 Python 層用 ast 元件來檢視 AST,可以用 compile 內建函式來編譯 ByteCode,可以用 exec 系統函式來執行 ByteCode。我們先在程式碼開頭匯入相關元件:
import ast
import dis
import sys
然後我們構造一個 python 程式碼,可以看到 src_code 就是普通的字串。其中包含了一段普通的 python 內建的乘法,一段深度學習的 tensor scalar 加法,最後一段是當前Python Frame 中的 ByteCode 關聯物件的列印(用於一個檢驗,後面會提到)。
print("=== source code ===")
src_code = """
# normal python operation
x = 1
x = x * 2
# tensor operation
y = dl_framework.ones((1, 2))
z = x + y
print(z)
# print python frame
f = sys._getframe()
# print the code object
print(f.f_code)
"""
print(src_code)
然後使用 ast 元件來生成這段程式碼的 AST。
print("=== source code to ast ===")
# 把原始碼解析成 AST
ast_obj = ast.parse(src_code)
# 列印 AST
print(ast.dump(ast_obj))
可以得到 AST,這裡展示的結果額外做了格式化,另外刪減掉了和計算邏輯無關的列印 frame 的部分,程式碼和其 AST 的對應關係參見注釋。AST解析是純文字層面的,`dl_framework` 還沒有被 import 進來,AST解析仍然可以正常工作。AST 基本是一個多叉樹的結構,每個節點對應一個表示式,節點子節點代表子表示式。以 `x = x + 2` 為例,Assign 是一個節點,是賦值運算,被賦值的是 `x`,賦值的值是一個二元乘法運算。
Module(body=[
# x = 1
Assign(targets=[Name(id='x', ctx=Store())],
value=Constant(value=1, kind=None),
type_comment=None),
# x = x * 2
Assign(targets=[Name(id='x', ctx=Store())],
value=BinOp(left=Name(id='x', ctx=Load()), op=Mult(), right=Constant(value=2, kind=None)), type_comment=None),
# y = dl_framework.ones((1, 2))
Assign(targets=[Name(id='y', ctx=Store())],
# dl_framework.ones((1, 2))
value=Call(func=Attribute(value=Name(id='dl_framework', ctx=Load()),
attr='ones', ctx=Load()),
args=[Tuple(elts=[Constant(value=1, kind=None),
Constant(value=2, kind=None)], ctx=Load())], keywords=[]), type_comment=None),
# z = x + y
Assign(targets=[Name(id='z', ctx=Store())],
# x + y
value=BinOp(left=Name(id='x', ctx=Load()),
op=Add(),
right=Name(id='y', ctx=Load())), type_comment=None),
# print(z)
Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='z', ctx=Load())], keywords=[])),
# 省略了列印 frame 的程式碼
],
type_ignores=[]
)
Python AST 生成後,可以利用系統函式 `compile` 把它轉成 ByteCode 位元組碼。直譯器執行也存在編譯的環節,只不過是編譯成位元組碼。
print("=== ast to bytecode ===")
# 編譯成 ByteCode
code_obj = compile(ast_obj, filename="", mode="exec")
print(code_obj)
# 展示 ByteCode 的語法糖
byte_obj = dis.Bytecode(code_obj)
print(byte_obj.dis())
`print(code_obj)`的結果是 `<code object <module> at 0x7ff79bb5c660, file "", line 3>`,這裡可以看到生成的 code object 物件的指標是 `0x7ff79bb5c660`,後面我們在執行位元組碼時,會再次看到這個指標。
`print(byte_obj.dis())` 的結果如下,每一行對應一條位元組碼,也即一條指令, 通過字面含義基本可以看出是在做什麼:
# x = 1
3 0 LOAD_CONST 0 (1)
2 STORE_NAME 0 (x)
# x = x * 2
4 4 LOAD_NAME 0 (x)
6 LOAD_CONST 1 (2)
8 BINARY_MULTIPLY
10 STORE_NAME 0 (x)
# y = dl_framework.ones((1, 2))
7 12 LOAD_NAME 1 (dl_framework)
14 LOAD_METHOD 2 (ones)
16 LOAD_CONST 2 ((1, 2))
18 CALL_METHOD 1
20 STORE_NAME 3 (y)
# x = x + y
8 22 LOAD_NAME 0 (x)
24 LOAD_NAME 3 (y)
26 BINARY_ADD
28 STORE_NAME 4 (z)
# print(z)
9 30 LOAD_NAME 5 (print)
32 LOAD_NAME 4 (z)
34 CALL_FUNCTION 1
36 POP_TOP
# 省略了列印 frame 的程式碼
得到 ByteCode 之後,就可以傳遞給 Python VM 執行了。在真正執行前,先做了一下 ByteCode 中指令的列印,實際 Python VM 執行時,也基本是這樣遍歷每一行指令,然後執行指令。可以想象,如果這些指令被修改,就可以讓 Python VM 執行自定義的指令了。
print("=== execute bytecode ===")
# print instruction
for instr in byte_obj:
print(instr.opname, instr.opcode)
# You can also do `import torch as dl_framework``
import oneflow as dl_framework
# execute bytecode
exec(code_obj)
位元組碼的執行結果如下。只需要在真正執行前,把 `dl_framework`匯入就好,然後可以看到 tensor 計算的結果,是符合預期的。
frame(或者叫 stack)是執行時的物件,對應一個函式呼叫的棧,在執行時被建立。frame 中要執行的指令就是之前建立的 ByteCode。
在執行時之前,像我們之前看到的,存在一個編譯時進行 AST 和 ByteCode 的編譯,之前編譯時生成的 code object 物件的指標是 `0x7ff79bb5c660`。
在執行時,可以獲取當前的 frame,然後通過 `frame.f_code`拿到當前 frame 裡面包含的 ByteCode(即 code object),可以發現它的指標就是之前編譯時生成的那個。
# print(z) 的結果
tensor([[3., 3.]], dtype=oneflow.float32)
# 執行時獲取當前 frame ,然後列印 frame 中的 ByteCode 物件的結果
# f = sys._getframe()
# print(f.f_code)
<code object <module> at 0x7f5cea7f1660, file "", line 3>
到此,窺見了一下 Python 原始碼到 AST, AST 到 ByteCode,ByteCode 到 Frame 執行這個預設的 Python 執行流程。TorchDynamo 用下圖做了簡單的介紹:
其中 foo 對應一個 Python 函式,即上文介紹的 Python Source Code。PyCodeObject 是上文介紹的 code object (ByteCode)在 C 程式碼層面對應的類。PyFrameObject 是上文介紹的 Frame 在 C 程式碼層面對應的類,它包含了程式碼段 PyCodeObject。_PyEval_EvalFrameDefault 對應上文介紹的 exec,它執行一個 Frame,即執行 Frame 帶有的 `PyCodeObject`。
現在我們看一下 CPython 在 C 層面的執行 Frame 的實現,對應 _PyEval_EvalFrameDefault(https://github.com/python/cpython/blob/d48ecebad5ac78a1783e09b0d32c211d9754edf4/Python/ceval.c#L757)。它的主邏輯就是取 ByteCode 指令和執行指令(https://github.com/python/cpython/blob/d48ecebad5ac78a1783e09b0d32c211d9754edf4/Python/ceval.c#L1080):
co = f->f_code; // 從 PyFrameObject* f 中取出 PyCodeObject* ,放到 co 中
names = co->co_names;
consts = co->co_consts;
fastlocals = f->f_localsplus;
freevars = f->f_localsplus + co->co_nlocals;
// 從 co 中取出第一條指令
first_instr = (_Py_CODEUNIT *) PyBytes_AS_STRING(co->co_code);
next_instr = first_instr;
#define NEXTOPARG() do { \
_Py_CODEUNIT word = *next_instr; \
opcode = _Py_OPCODE(word); \
oparg = _Py_OPARG(word); \
// 指向下一條指令
next_instr++; \
} while (0)
// 迴圈執行指令
for (;;) {
// 從當前的指令 next_instr 中獲取 opcode
NEXTOPARG();
switch (opcode) {
// 執行 op code,參見下個部分
}
}
每個指令型別對應一個 opcode,它是一個數值,執行 opcode(https://github.com/python/cpython/blob/d48ecebad5ac78a1783e09b0d32c211d9754edf4/Python/ceval.c#L1266),這裡的 opcode 可以清晰的看到和之前我們列印的 ByteCode 的型別對應關係:
#define TARGET(opcode) \
case opcode:
switch (opcode) {
// TARGET 就是一個 case
// load
TARGET(LOAD_FAST) {
PyObject *value = GETLOCAL(oparg);
if (value == NULL) {
format_exc_check_arg(PyExc_UnboundLocalError,
UNBOUNDLOCAL_ERROR_MSG,
PyTuple_GetItem(co->co_varnames, oparg));
goto error;
}
Py_INCREF(value);
PUSH(value);
FAST_DISPATCH();
}
// store
TARGET(STORE_FAST) {
PyObject *value = POP();
SETLOCAL(oparg, value);
FAST_DISPATCH();
}
// 二元加法
TARGET(BINARY_ADD) {
PyObject *right = POP();
PyObject *left = TOP();
PyObject *sum;
if (PyUnicode_CheckExact(left) &&
PyUnicode_CheckExact(right)) {
sum = unicode_concatenate(left, right, f, next_instr);
/* unicode_concatenate consumed the ref to left */
}
else {
sum = PyNumber_Add(left, right);
Py_DECREF(left);
}
Py_DECREF(right);
SET_TOP(sum);
if (sum == NULL)
goto error;
DISPATCH();
}
// 函式呼叫
TARGET(CALL_FUNCTION) {
PyObject **sp, *res;
PCALL(PCALL_ALL);
sp = stack_pointer;
res = call_function(&sp, oparg, NULL);
stack_pointer = sp;
PUSH(res);
if (res == NULL) {
goto error;
}
DISPATCH();
}
}
以上總結了 Python的預設執行流程。
4
TorchDynamo 的工作流程
TorchDynamo 在標準的 Python 執行流程中做的主要改變就是支援修改 Frame 執行前的 ByteCode。我們暫時不關注 AST 生成,看 Python 的執行流程,是 Python Source Code -> ByteCode -> Evaluate. TorchDynamo 支援 Python Source Code -> ByteCode -> [ByteCode rewrite] -> Evaluate。
ByteCode rewrite 的工作方式是把一段 ByteCode 轉成 FX Graph,然後呼叫使用者自定義的 FX Graph 改寫執行邏輯,生成一個可以經過編譯的執行函式。然後把該段 ByteCode 替換成函式呼叫 ByteCode,而呼叫的函式就是經過編譯的執行函式。從而實現編譯優化的功能。
FX Graph 支援了在 Python 層做程式碼改寫,提高了寫編譯 Pass 的便利性,這裡不做深入,可以參考資料1(https://pytorch.org/docs/stable/fx.html)和2(https://zhuanlan.zhihu.com/p/416165157)。
ByteCode rewrite 發生在 ByteCode 執行前。同樣的 Source Code,每次執行都會走到這個步驟,都可以選擇是否進行 ByteCode rewrite,或者選擇進行什麼樣的 rewrite,還可以支援 rewrite 結果的快取和複用。這體現了 Dynamo 的動態性。
下面看一個 TorchDynamo 下 fn() 函式編譯的的例子:
# 一個普通的函式
def fn(a, b):
x = a + b
x = x / 2.0
if x.sum() < 0:
return x * -1.0
return x
# torchdynamo 函式介面
with torchdynamo.optimize(custom_compiler):
fn(torch.randn(10), torch.randn(10))
fn() 函式對應的原始的 python ByteCode,和程式碼對應的關係參見其中的註釋:
# x = a + b
0 LOAD_FAST 0 (a)
2 LOAD_FAST 1 (b)
4 BINARY_ADD
6 STORE_FAST 2 (x)
# x = x / 2.0
8 LOAD_FAST 2 (x)
10 LOAD_CONST 1 (2.0)
12 BINARY_TRUE_DIVIDE
14 STORE_FAST 2 (x)
# if x.sum() < 0:
16 LOAD_FAST 2 (x)
18 LOAD_METHOD 0 (sum)
20 CALL_METHOD 0
22 LOAD_CONST 2 (0)
24 COMPARE_OP 0 (<)
26 POP_JUMP_IF_FALSE 36
# return x * -1.0
28 LOAD_FAST 2 (x)
30 LOAD_CONST 3 (-1.0)
32 BINARY_MULTIPLY
34 RETURN_VALUE
# return x
36 LOAD_FAST 2 (x)
38 RETURN_VALUE
經過 TorchDynamo 動態改寫後的 ByteCode:
# x = a + b
# x = x / 2.0
# x.sum() < 0
# 上面兩行被轉換成了 __compiled_fn_0
# __compiled_fn_0 會返回 x 和 x.sum() < 0 組成的 tuple
0 LOAD_GLOBAL 1 (__compiled_fn_0)
2 LOAD_FAST 0 (a)
4 LOAD_FAST 1 (b)
6 CALL_FUNCTION 2
8 UNPACK_SEQUENCE 2
10 STORE_FAST 2 (x)
12 POP_JUMP_IF_FALSE 22
# x * -1.0 被轉換成了 __compiled_fn_1
14 LOAD_GLOBAL 2 (__compiled_fn_1)
16 LOAD_FAST 2 (x)
18 CALL_FUNCTION 1
20 RETURN_VALUE
# return x
22 LOAD_FAST 2 (x)
24 RETURN_VALUE
可以看到新增了兩個函式呼叫, `__compiled_fn_0` 和 `__compiled_fn_1` ,這兩個函式對應的程式碼邏輯參見 bytecode 中的註釋。這兩個函式對應的 fx graph 如下:
__compiled_fn_0:
opcode name target args kwargs
------------- ------- --------------------------- ---------------- --------
placeholder a_0 a_0 () {}
placeholder b_1 b_1 () {}
call_function add <built-in function add> (a_0, b_1) {}
call_function truediv <built-in function truediv> (add, 2.0) {}
call_method sum_1 sum (truediv,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
__compiled_fn_1:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder x_4 x_4 () {}
call_function mul <built-in function mul> (x_4, -1.0) {}
output output output (mul,) {}
在 ByteCode rewrite 的最後,TorchDynamo 為這一段程式碼的輸入建立兩個 Guard:
-
區域性引數 a 必須是一個 Tensor
-
區域性引數 b 必須是一個 Tensor
該 fn 函式被再次呼叫時,如果符合這兩個條件,則可以命中快取的 TrochDynamo 處理結果;否則下次 fn 執行時,會觸發新的 ByteCode 分析和變換。
另外,對於和 tensor 無關的、比較特別的 python 程式碼,其 ByteCode 會保持原狀。這樣就達到了不需要使用者標註區域、自動尋找優化機會的設計目標。
現在看下 TorchDynamo 執行的流程總結:
可以看到它把原來的 PyFrameObject 替換成了 Patched PyFrameObject,這個是 CPython 支援的特性。這個 Patched PyFrameObject 中最主要的改動就是 Frame 中的 ByteCode (即 PyCodeObject)被修改了,原來的 PyCodeObject 變成了 Transformed PyCodeObject。而這個被改寫的 PyCodeObject 如上文和上圖所示,主要是部分 ByteCode 被替換成了呼叫被編譯過函式。這個被編譯過的函式,支援自定義編譯邏輯,當前預設的編譯介面是 FX Graph。
這部分基本參考了Dynamo的官方介紹(https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)。
5
TorchDynamo 修改 Python ByteCode 的實現
Python ByteCode 修改主要依賴 PEP 523(https://peps.python.org/pep-0523/) 提供的執行自定義 Frame Evaluation API。預設的 Eval Frame 邏輯入口函式是 _PyEval_EvalFrame,預設情況,它會直接呼叫 _PyEval_EvalFrameDefault() 來處理沒被修改的 frame,但是如果發現存在一個自定義的 Eval Frame 函式,就會執行自動線的函式。
CPython _PyEval_EvalFrame 函式實現(https://github.com/python/cpython/blob/76449350b3467b85bcb565f9e2bf945bd150a66e/Include/internal/pycore_ceval.h#L84),所以只要在 ByteCode 執行前,設定一個自定義的 eval frame 函式即可:
static inline PyObject*
_PyEval_EvalFrame(PyThreadState *tstate, struct _PyInterpreterFrame *frame, int throwflag)
{
EVAL_CALL_STAT_INC(EVAL_CALL_TOTAL);
if (tstate->interp->eval_frame == NULL) {
// 這是預設的 eval frame
return _PyEval_EvalFrameDefault(tstate, frame, throwflag);
}
// 如果存在 eval_frame 就會被執行
return tstate->interp->eval_frame(tstate, frame, throwflag);
}
可以看到 TorchDynamo 正是這麼做的。第一步,在 Python 層基於 ContextManger 在進入 Dynamo 作用域時,就觸發 eval_frame 的設定,實現(https://github.com/pytorch/pytorch/blob/4068c5467d496cd3c09a841f40adacedf3ab41a0/torch/_dynamo/eval_frame.py#L128):
# torch._dynamo.optimize(...) 對應的 context manager.
class _TorchDynamoContext:
def __init__(
self,
callback: DynamoCallback,
):
super().__init__()
assert callable(callback) or callback is False or callback is None
self.callback: DynamoCallback = callback
self.prior: Union[Unset, DynamoCallback] = unset
def __enter__(self):
# 設定 eval_frame,記錄之前的 eval frame
self.prior = set_eval_frame(self.callback)
def __exit__(self, exc_type, exc_val, exc_tb):
assert self.prior is not unset
# 恢復之前的 eval frame
set_eval_frame(self.prior)
這裡先大致認為設定的 DynamoCallback 對應一個自定義的 eval frame 所需的引數,通常是自定義的 eval frame 中所需的編譯邏輯。
看下 set_eval_frame ,C 程式碼層面的實現(https://github.com/pytorch/pytorch/blob/eaf4fe3d2b7096579b05b52d543756f74d0e91e7/torch/csrc/dynamo/eval_frame.c#L446),它有點繞但最終走到了這裡(https://github.com/pytorch/pytorch/blob/eaf4fe3d2b7096579b05b52d543756f74d0e91e7/torch/csrc/dynamo/eval_frame.c#L121),也是設定 tstate->interp->eval_frame ,把 eval_frame 設定成自定義的 custom_eval_frame_shim:
// custom_eval_frame_shim 是自定義的 frame
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
if (tstate->interp->eval_frame != &custom_eval_frame_shim) {
// First call
// 設定自定義的 eval frame
tstate->interp->eval_frame = &custom_eval_frame_shim;
}
}
現在回頭看一下 PEP 523 提供的 Python JIT 編譯器的自定義 frame 執行的樣例,它提供了一個比較標準的模版(注意筆者對例子做了微調,原文有多餘和不合理的地方)。在自定義 eval frame 之前,一般還需要自定義一個存放自定義 ByteCode 的資料結構,可以認為是自定義編譯結果,比如樣例中自定義編譯結果包括3個欄位:
-
exec_count, 代表改 frame 被執行的次數;
-
jit_failed, 代表之前 jit 編譯是否失敗過;
-
jit_code,代表 jit 編譯過後的自定義 ByteCode;
據此,來看下自定義 eval frame 的樣例:
# 輸入原始的 frame
def eval_frame(frame, throw_flag):
# 獲取 frame 中的 code object 中的存放自定義編譯結果的欄位
pyjion_code = frame.code.co_extra
if not pyjion_code:
# 不如不存在,就設定一個空的預設值
frame.code.co_extra = PyjionJittedCode()
elif not pyjion_code.jit_failed:
# 如果之前 jit 執行成功
if pyjion_code.jit_code:
# 如果存在 jit 生成的 bytecode,就執行它
return pyjion_code.eval(pyjion_code.jit_code, frame)
elif pyjion_code.exec_count > 20000:
# 沒有 jit 編譯過,且 frame 被執行超過 20000 次,就嘗試進行 jit 編譯
# 如果不存在 jit 生成的 bytecode,就 jit 編譯生成它
if jit_compile(frame):
# 如果 jit 編譯成功,就執行 jit 編譯的 bytecode
return pyjion_code.eval(pyjion_code.jit_code, frame)
else:
# 如果 jit 編譯失敗,就記錄下,後面不再編譯
pyjion_code.jit_failed = True
# 增加 frame 執行次數計數
pyjion_code.exec_count += 1
# 執行預設的 frame
return _PyEval_EvalFrameDefault(frame, throw_flag)
下面接著看 TorchDynamo 自定義 evale frame 的實現。在瞭解具體的自定義 frame 執行邏輯前,有個前置知識是 PyFrameObject 中的 PyCodeObject 為了執行自定義 frame 增加了一個 co_extra 欄位,用來讓使用者放置自定義的資料,一般是存放自定義編譯結果(https://peps.python.org/pep-0523/#expanding-pycodeobject)。
typedef struct {
...
void *co_extra; /* 自定義的 frame 需要的自定義資料 */
} PyCodeObject;
TorchDynamo 在自定義編譯結果的型別是 CacheEntry,其中最重要的欄位是 code,是被編譯器修改後的 ByteCode:
typedef struct cache_entry {
// check the guards: lambda: <locals of user function>: bool
PyObject* check_fn;
// modified user bytecode (protected by check_fn's guards)
PyCodeObject* code;
// on a cache miss, linked list of next thing to try
struct cache_entry* next;
} CacheEntry;
現在看下自定義的 eval frame 邏輯 custom_eval_frame_shim(https://github.com/pytorch/pytorch/blob/eaf4fe3d2b7096579b05b52d543756f74d0e91e7/torch/csrc/dynamo/eval_frame.c#L342):
static PyObject* _custom_eval_frame(PyThreadState* tstate, PyFrameObject* frame, int throw_flag, PyObject* callback) {
// 獲取當前 frame 的 PyCodeObject 的 extra 欄位用於後面設定
// 該欄位用於放置自定義的編譯結果
CacheEntry* extra = get_extra(frame->f_code);
// callback 即上文說的自定義編譯器
// 使用 callback 進行 bytecode 的修改,即編譯
// 編譯結果寫在了 frame->f_code中的 extra 中
PyObject* result =
call_callback(callback, (PyObject*)frame, cache_size(extra));
if (result != Py_None) {
// 快取編譯結果
extra = create_cache_entry(extra, result);
Py_DECREF(result);
// 執行自定義的 frame
// eval_custom_code 最終會呼叫 CPython 介面 _PyEval_EvalFrameDefault 來執行計算
// 其中 extra->code 中存放的就自定義編譯器生成的 ByteCode
// 所以最終 _PyEval_EvalFrameDefault 執行的是編譯器生成的 ByteCode
return eval_custom_code(tstate, frame, extra->code, throw_flag);
}
}
inline static PyObject* eval_custom_code(PyThreadState* tstate, PyFrameObject* frame, PyCodeObject* custom_code, int throw_flag) {
// 使用 custom_code 建立一個自定義的 frame
PyFrameObject* shadow_frame = PyFrame_New(tstate, custom_code, frame->f_globals, NULL);
// 呼叫 Python 的 frame 執行自定義 frame
return _PyEval_EvalFrameDefault(tstate, shadow_frame, throw_flag);
}
到這裡,已經清楚了修改 Python ByteCode 執行的主線邏輯。
6
小結
這裡對 Python 的執行和 TorchDynamo 的主要原理做了初探,主要是自定義 Eval Frame 的實現技巧。其它相關的 Python ByteCode 標準,ByteCode 到 FX Graph 的轉換,ByteCode 的改寫等內容還沒涉及。
參考資料
-
tenthousandmeters.com/b (https://tenthousandmeters.com/blog/python-behind-the-scenes-1-how-the-cpython-vm-works/)
-
peps.python.org/pep-052 (https://peps.python.org/pep-0523/)
-
dev-discuss.pytorch.org (https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)
(原文:https://zhuanlan.zhihu.com/p/589115427)
其他人都在看
歡迎Star、試用OneFlow最新版本:https://github.com/Oneflow-Inc/oneflow/
本文分享自微信公眾號 - OneFlow(OneFlowTechnology)。
如有侵權,請聯絡 [email protected] 刪除。
本文參與“OSC源創計劃”,歡迎正在閱讀的你也加入,一起分享。
- OneFlow原始碼解析:Eager模式下的裝置管理與併發執行
- OpenAI創始人:GPT-4的研究起源和構建心法
- GPT-4創造者:第二次改變AI浪潮的方向
- NCCL原始碼解析①:初始化及ncclUniqueId的產生
- GPT-4問世;LLM訓練指南;純瀏覽器跑Stable Diffusion
- 適配PyTorch FX,OneFlow讓量化感知訓練更簡單
- 超越ChatGPT:大模型的智慧極限
- ChatGPT作者John Schulman:我們成功的祕密武器
- YOLOv5全面解析教程⑤:計算mAP用到的Numpy函式詳解
- GPT-3/ChatGPT復現的經驗教訓
- ChatGPT背後:從0到1,OpenAI的創立之路
- 一塊GPU搞定ChatGPT;ML系統入坑指南;理解GPU底層架構
- YOLOv5全面解析教程④:目標檢測模型精確度評估
- ChatGPT資料集之謎
- OneFlow原始碼解析:Eager模式下的SBP Signature推導
- YOLOv5全面解析教程③:更快更好的邊界框迴歸損失
- ChatGPT背後的經濟賬
- Sam Altman的成功學|升維指南
- 開源機器學習軟體對AI的發展意味著什麼?
- “一鍵”模型遷移,效能翻倍,多語言AltDiffusion推理速度超快