如何看待PyTorch 2.0?

語言: CN / TW / HK


 

作者|吳育昕
 

1

為什麼是TorchDynamo
 

Graph capture 把用户 Python 寫的模型代碼變成 graph,是一切編譯的根基。而 PyTorch 在試了這麼多方案之後似乎已經鎖定 TorchDynamo 作為 graph capture 的未來方向了,所以寫一點關於 TorchDynamo 的內容,主要是解釋到底為什麼要做這個東西(離開FB一年了,內容主要憑自己的猜測和理解)。

 

一句話儘量解釋 TorchDynamo 幹了什麼:利用 PEP523(https://peps.python.org/pep-0523/) 的 API 在用户執行每個 python frame 前, 拿到這個 frame 的 bytecode,把其中認識的部分用 tracing 的方式提取出 graph (並送給後端編譯)不認識的部分維持原樣。把修改後的 bytecode還給 CPython 跑。

 

由於 LazyTensor 和 TorchDynamo 都做 tracing,並且都是 best-effort graph capture,即只編譯自己能 capture 的部分,capture 不到的用 Python 跑 (aka Python fallback),所以觀感上兩者可能會差不多。

 

然而,這兩個方案的差別正是 TorchDynamo 關鍵的地方:

 

LazyTensor 是個純靠 tracing 的方案,不可避免的問題是「只能看見 trace 到的部分,只有 trace 一下才知道哪裏不能 trace」。而每次執行模型的時候,不能 trace 的部分可能不太一樣。為了保證正確性,LazyTensor 就不得不每次執行都要重新 trace。舉個極端的例子,模型裏寫了一個torch.add(tensor, random.random()) ,其中 random 是個 LazyTensor 看不見摸不着的 Python 函數,那只有重新 trace 才能保證正確性。

 

而當 TorchDynamo 修改 bytecode 的時候,事情就不太一樣了:

 

  1. 在 bytecode 裏能夠看得見所有需要的信息,所以能夠證明「這段模型代碼沒有用到奇怪的東西所以不需要重新 trace」。

     

  2. 光證明了「不需要 trace」不代表可以真的不 trace因為用户的代碼還是一行行給 Python 來跑的。但是 TorchDynamo 又來了:CPython 到底跑什麼 bytecode 是可以被它換掉的!

 

因此它可以做到這麼一件事:當用户 call 一個被 capture 過的模型時模型裏大部分 Python 代碼都相當於不存在了,連 symbolic execution 的 overhead 都沒有而被換成了編譯後的 native code。這一點在以前所有的 partial graph capture 的方案裏是做不到的:
 

  • LazyTensor 即使編譯過的 graph 也要每次重新在 Python 裏 trace 一遍,才能發現「哦,這個 graph 我曾見過的」。

  • @torch.jit.script 、@tf.function、 @jax.jit 可以把裝飾的 python code 換成編譯後的,但是這都依賴用户把這個 subgraph refactor 出來放到一個單獨的函數裏。而 TorchDynamo 是全自動不需要用户改代碼的。

 

  • 這種 refactor 除了增加額外的工作量之外還可能與用户的代碼結構衝突,因為 「用來編譯的graph的邊界」與「用户代碼需要的抽象邊界」很可能不 match:例如用户本來希望寫三個函數但是最佳的優化是把其中兩個半函數變成一個 graph這會讓用户很尷尬。

這只是一個最直接的例子。由於能夠讀寫 bytecode,理論上 TorchDynamo 能 access 更多 LazyTensor 根本沒有的信息做更多事情(後面會提到)。而讀寫 bytecode 的難度比 source code要低不少所以成為了一個可行的方案。

 

2
whole-graph capture用處不大?

 

有的人可能會説上面提到的東西對 whole-graph capture 沒太大用啊。 

我覺得確實是這樣:TorchDynamo 是一個對 partial-graph capture 追求極致的方案能夠對幾乎所有的 Python 實現的模型開箱即用有加速不用改代碼——前提是還要跑 Python 作為 fallback。但是部署一般需要的是 whole-graph capture 整個模型在一個 graph 裏不能用 Python。

 

用 tracing 做 whole-graph capture 的前提是用户要在 Python 代碼裏避免所有不能被 trace 的東西最常見的用户要做的三件事是:使用 symbolic shape使用 symbolic control flow,禁用除了當前 tensor library之外的所有其它 library。如果用户做到了這些那隻要一個普通的 symbolic tracing 就能 capture 到完整的 graph 了不需要 TorchDynamo 這麼複雜的機制。TorchDynamo 可能可以略微簡化用户做這些的工作量但我感覺不會有本質不同。

 

我個人的觀點是從實用角度出發要求用户做上面幾件事不算是太複雜的要求:禁用其他 library 理所應當就不説了;即使今天 PyTorch 還沒有很好的 symbolic {shape, control flow}但是隻要用 @torch.jit.script_if_tracing 來處理少量的 symbolic shape 和 symbolic control flow大多數模型都是可以正確的被 torch.jit.tracecapture 的。Meta 應該有幾十上百個 vision 模型實現在 detectron2/d2go 裏, 目前基本都是走這條路部署的(我另有篇文章https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/介紹這裏面的細節)。

 

TensorFlow 的 whole-graph capture 就簡單了:TF 從第一天就有很好的 symbolic shape 和 symbolic control flow,用就完了。tf.autograph 甚至還自動化了一部分 control flow 的改寫工作。

 

所以用户少量改代碼仍然是必須的。當然,TorchDynamo 畢竟有着"改變用户要跑的 bytecode" 的超能力。所以如果願意的話,理論上可以讓用户的 whole-graph capture 工作變得更簡單。例如:
 

  • 模型中間的一些像 if x.shape[0] > 100 的分支有的可以通過 shape inference 等價轉移到模型開頭的。這樣的話就可以 capture 到更大的沒有分支的 subgraph。 這件事在 TorchDynamo 裏現在叫做 "guard"。 

     

  • 理論上可以把 python control flow 自動替換成 symbolic 的,類似tf.autograph 做的事情只不過輸入是 bytecode 而不是 source code。 
     

目前 TorchDynamo 的 "nopython" 模式就是 whole-graph capture 了。不過似乎還不是工作重心 (以下引用自https://docs.google.com/document/d/1tlgPcR2YmC3PcQuYDPUORFmEaBPQEmo8dsh4eUjnlyI/edit#heading=h.rmxeybu31e0):

 

PT2 will provide infrastructure for a no python export mode for edge and performance sensitive serving cases. The PT2 team won’t drive this end to end stack, but we will keep a feedback loop with the teams in charge of this and ensure the components we build are reusable in these situations.

 

不過與此同時PyTorch 2.0 最近在完善 symbolic shape 的支持;functorch 裏也加入了少量 control flow operator。這算是利好 whole-graph capture 的消息。

 

3
總結

 

總的來説由於 TorchDynamo 在 bytecode 層面做文章能做到一些其他方案做不到的事情。它的優點主要為 partial graph capture 服務: 讓用户的 Python 模型代碼在 0 修改的情況下就能 capture 並獲得加速。這體現了 PyTorch 對於 "Python first" 哲學的執念。這種執着是否有必要見仁見智。

 

TorchDynamo 的主要優勢來自對 bytecode 的讀寫。JIT scripting compiler 的失敗表明在 source code level 做不了太多事TorchDynamo 能在 bytecode level 做事情確實很巧妙。 不過 要完整的復刻 CPython bytecode interpreter 它的工作量、 維護難度(以及出 bug 的概率)都是不小的。

 

另外TorchDynamo 對 whole-graph capture 沒有很大的幫助。 對於複雜的模型用户該做的改寫還是得做。不過我估計 2.0 至少能對「用户該做什麼」有個清晰的説法。

 

當然最後 PT2 到底能不能把 compiler 做好還有很多其他因素:IR 怎麼設計何時specialize/recompile,各種 backend 不同的特性等等。比如 TorchDynamo 和 LazyTensor 使用的 IR 其實也不一樣。但是本文只討論 graph capture,其他問題就不提了。

(本文經授權後發佈。原文:https://www.zhihu.com/question/570220953/answer/2798657470)

 

其他人都在看

歡迎Star、試用OneFlow最新版本:https://github.com/Oneflow-Inc/oneflow/


 

 

本文分享自微信公眾號 - OneFlow(OneFlowTechnology)。
如有侵權,請聯繫 [email protected] 刪除。
本文參與“OSC源創計劃”,歡迎正在閲讀的你也加入,一起分享。