1. 程式人生 > >谷歌開源 TensorFlow 的簡化庫 JAX

谷歌開源 TensorFlow 的簡化庫 JAX

  

谷歌開源了一個 TensorFlow 的簡化庫 JAX。


JAX 結合了 Autograd 和 XLA,專門用於高效能機器學習研究。

憑藉 Autograd,JAX 可以求導迴圈、分支、遞迴和閉包函式,並且它可以進行三階求導。通過 grad,它支援自動模式反向求導(反向傳播)和正向求導,且二者可以任何順序任意組合。

得力於 XLA,可以在 GPU 和 TPU 上編譯和執行 NumPy 程式。預設情況下,編譯發生在底層,庫呼叫實時編譯和執行。但是 JAX 還允許使用單一函式 API jit 將 Python 函式及時編譯為 XLA 優化的核心。編譯和自動求導可以任意組合,因此可以在 Python 環境下實現複雜的演算法並獲得最大的效能。

demo:

import jax.numpy as np
from jax import grad, jit, vmap
from functools import partial

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads

更深入地看,JAX 實際上是一個可擴充套件的可組合函式轉換系統,grad 和 jit 都是這種轉換的例項。

專案地址:https://github.com/google/JAX