如何看待PyTorch 2.0?

語言: CN / TW / HK

作者|吳育昕

1

為什麼是TorchDynamo

Graph capture 把用户 Python 寫的模型代碼變成 graph,是一切編譯的根基。而 PyTorch 在試了這麼多方案之後似乎已經鎖定 TorchDynamo 作為 graph capture 的未來方向了,所以寫一點關於 TorchDynamo 的內容,主要是解釋到底為什麼要做這個東西(離開FB一年了,內容主要憑自己的猜測和理解)。

一句話儘量解釋 TorchDynamo 幹了什麼: 利用 PEP523(https://peps.python.org/pep-0523/ ) 的 API 在用户執行每個 python frame 前, 拿到這個 frame 的 bytecode,把其中認識的部分用 tracing 的方式提取出 graph (並送給後端編譯),不認識的部分維持原樣。把修改後的 bytecode還給 CPython 跑。

由於 LazyTensor 和 TorchDynamo 都做 tracing,並且都是 best-effort graph capture,即只編譯自己能 capture 的部分,capture 不到的用 Python 跑 (aka Python fallback),所以觀感上兩者可能會差不多。

然而,這兩個方案的差別正是 TorchDynamo 關鍵的地方:

LazyTensor 是個純靠 tracing 的方案,不可避免的問題是「只能看見 trace 到的部分,只有 trace 一下才知道哪裏不能 trace」。 而每次執行模型的時候,不能 trace 的部分可能不太一樣。為了保證正確性,LazyTensor 就不得不每次執行都要重新 trace。舉個極端的例子,模型裏寫了一個torch.add(tensor, random.random()) ,其中 random 是個 LazyTensor 看不見摸不着的 Python 函數,那只有重新 trace 才能保證正確性。

而當 TorchDynamo 修改 bytecode 的時候,事情就不太一樣了:

  1. 在 bytecode 裏能夠看得見所有需要的信息,所以能夠證明「這段模型代碼沒有用到奇怪的東西所以不需要重新 trace」。
  2. 光證明了「不需要 trace」不代表可以真的不 trace,因為用户的代碼還是一行行給 Python 來跑的。但是 TorchDynamo 又來了:CPython 到底跑什麼 bytecode 是可以被它換掉的!

因此它可以做到這麼一件事:當用户 call 一個被 capture 過的模型時,模型裏大部分 Python 代碼都相當於不存在了,連 symbolic execution 的 overhead 都沒有,而被換成了編譯後的 native code。這一點在以前所有的 partial graph capture 的方案裏是做不到的:

  • LazyTensor 即使編譯過的 graph 也要每次重新在 Python 裏 trace 一遍,才能發現「哦,這個 graph 我曾見過的」。
  • @torch.jit.script 、@tf.function、 @jax.jit 可以把裝飾的 python code 換成編譯後的,但是這都依賴用户把這個 subgraph refactor 出來放到一個單獨的函數裏。而 TorchDynamo 是全自動不需要用户改代碼的。
  • 這種 refactor 除了增加額外的工作量之外,還可能與用户的代碼結構衝突,因為 「用來編譯的graph的邊界」與「用户代碼需要的抽象邊界」很可能不 match:例如用户本來希望寫三個函數,但是最佳的優化是把其中兩個半函數變成一個 graph,這會讓用户很尷尬。

這只是一個最直接的例子。由於能夠讀寫 bytecode,理論上 TorchDynamo 能 access 更多 LazyTensor 根本沒有的信息,做更多事情(後面會提到)。而讀寫 bytecode 的難度比 source code要低不少,所以成為了一個可行的方案。

2
whole-graph capture用處不大?

有的人可能會説,上面提到的東西對 whole-graph capture 沒太大用啊。

我覺得確實是這樣:TorchDynamo 是一個對 partial-graph capture 追求極致的方案,能夠對幾乎所有的 Python 實現的模型開箱即用有加速,不用改代碼——前提是還要跑 Python 作為 fallback。但是部署一般需要的是 whole-graph capture 整個模型在一個 graph 裏不能用 Python。

用 tracing 做 whole-graph capture 的前提是用户要在 Python 代碼裏避免所有不能被 trace 的東西,最常見的用户要做的三件事是:使用 symbolic shape,使用 symbolic control flow,禁用除了當前 tensor library之外的所有其它 library。如果用户做到了這些,那隻要一個普通的 symbolic tracing 就能 capture 到完整的 graph 了, 不需要 TorchDynamo 這麼複雜的機制。TorchDynamo 可能可以略微簡化用户做這些的工作量,但我感覺不會有本質不同。

我個人的觀點是,從實用角度出發,要求用户做上面幾件事不算是太複雜的要求:禁用其他 library 理所應當就不説了;即使今天 PyTorch 還沒有很好的 symbolic {shape, control flow},但是隻要用 @
torch.jit.script_if_tracing 來處理少量的 symbolic shape 和 symbolic control flow,大多數模型都是可以正確的被 torch.jit.tracecapture 的。Meta 應該有幾十上百個 vision 模型實現在 detectron2/d2go 裏, 目前基本都是走這條路部署的(我另有篇文章https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/ 介紹這裏面的細節)。

TensorFlow 的 whole-graph capture 就簡單了:TF 從第一天就有很好的 symbolic shape 和 symbolic control flow,用就完了。tf.autograph 甚至還自動化了一部分 control flow 的改寫工作。

所以,用户少量改代碼仍然是必須的。當然,TorchDynamo 畢竟有着"改變用户要跑的 bytecode" 的超能力。所以如果願意的話,理論上可以讓用户的 whole-graph capture 工作變得更簡單。例如:

  • 模型中間的一些像 if x.shape[0] > 100 的分支,有的可以通過 shape inference 等價轉移到模型開頭的。這樣的話就可以 capture 到更大的沒有分支的 subgraph。 這件事在 TorchDynamo 裏現在叫做 "guard"。
  • 理論上可以把 python control flow 自動替換成 symbolic 的,類似tf.autograph 做的事情,只不過輸入是 bytecode 而不是 source code。

目前 TorchDynamo 的 "nopython" 模式就是 whole-graph capture 了。不過似乎還不是工作重心 (以下引用自https://docs.google.com/document/d/1tlgPcR2YmC3PcQuYDPUORFmEaBPQEmo8dsh4eUjnlyI/edit#heading=h.rmxeybu31e0 ):

PT2 will provide infrastructure for a no python export mode for edge and performance sensitive serving cases. The PT2 team won’t drive this end to end stack, but we will keep a feedback loop with the teams in charge of this and ensure the components we build are reusable in these situations.

不過與此同時,PyTorch 2.0 最近在完善 symbolic shape 的支持;functorch 裏也加入了少量 control flow operator。這算是利好 whole-graph capture 的消息。

3
總結

總的來説,由於 TorchDynamo 在 bytecode 層面做文章,能做到一些其他方案做不到的事情。它的優點主要為 partial graph capture 服務: 讓用户的 Python 模型代碼在 0 修改的情況下就能 capture 並獲得加速。這體現了 PyTorch 對於 "Python first" 哲學的執念。這種執着是否有必要,見仁見智。

TorchDynamo 的主要優勢來自對 bytecode 的讀寫。JIT scripting compiler 的失敗表明在 source code level 做不了太多事,TorchDynamo 能在 bytecode level 做事情確實很巧妙。不過,要完整的復刻 CPython bytecode interpreter,它的工作量、維護難度(以及出 bug 的概率)都是不小的。

另外,TorchDynamo 對 whole-graph capture 沒有很大的幫助。 對於複雜的模型,用户該做的改寫還是得做。不過我估計 2.0 至少能對「用户該做什麼」有個清晰的説法。

當然,最後 PT2 到底能不能把 compiler 做好,還有很多其他因素:IR 怎麼設計,何時specialize/recompile,各種 backend 不同的特性等等。比如 TorchDynamo 和 LazyTensor 使用的 IR 其實也不一樣。但是本文只討論 graph capture,其他問題就不提了。
(本文經授權後發佈。原文:https://www.zhihu.com/question/570220953/answer/2798657470)

歡迎 Star、試用 OneFlow 最新版本:
https://github.com/Oneflow-Inc/oneflow/