Vector-Jacobian Product
At the core of autodiff is the vector-Jacobian product (VJP), or in PyTensor’s archaic terminology, the L-Operator (because the vector is on the left). The Jacobian is the matrix (or tensor) of all first-order partial derivatives. For now, let us ignore what the vector means and ask: how do we compute the product of a general vector with the Jacobian matrix?
import pytensor.tensor as pt
from pytensor.gradient import Lop, jacobian as jacobian_raw
from pytensor.graph import rewrite_graph
import numpy as np

def jacobian(*args, vectorize=True, **kwargs):
    # Build Jacobians with the vectorized (batched) strategy by default
    return jacobian_raw(*args, vectorize=vectorize, **kwargs)

def simplify_print(graph, **print_options):
    # Apply the "fast_run" rewrites (minus inplace and BLAS ones) so the printed graph is readable
    rewrite_graph(graph, include=("fast_run",), exclude=("inplace", "BlasOpt")).dprint(**print_options)
Elementwise operations
The naive way is to create the full Jacobian matrix and then multiply it by the vector from the left. Let’s look at a concrete example for the elementwise operation log(x).
x = pt.vector("x")
log_x = pt.log(x)
If x has length 3, the Jacobian of log(x) with respect to x is a 3x3 matrix, since there are 3 outputs and 3 inputs.
Each entry contains the partial derivative of one of the outputs (rows) with respect to one of the inputs (columns).
For the elementwise operation log(x), the Jacobian is a diagonal matrix, as each input affects only the corresponding output. For the log operation the partial derivatives are given by \(\frac{1}{x_i}\), so the Jacobian is:
\[
J = \begin{bmatrix} \frac{1}{x_1} & 0 & 0 \\ 0 & \frac{1}{x_2} & 0 \\ 0 & 0 & \frac{1}{x_3} \end{bmatrix}
\]
J_log = jacobian(log_x, x)
simplify_print(J_log)
True_div [id A]
├─ Eye{dtype='float64'} [id B]
│ ├─ Shape_i{0} [id C]
│ │ └─ x [id D]
│ ├─ Shape_i{0} [id C]
│ │ └─ ···
│ └─ 0 [id E]
└─ ExpandDims{axis=0} [id F]
└─ x [id D]
J_log.eval({"x": [1.0, 2.0, 3.0]})
array([[1. , 0. , 0. ],
[0. , 0.5 , 0. ],
[0. , 0. , 0.33333333]])
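To get the vector-Jacobian product, we left-multiply the Jacobian by a vector v. In this case, it simplifies to an elementwise division of the vector v by the input vector x:
\[
v^\top J = \begin{bmatrix} v_1 & v_2 & v_3 \end{bmatrix} \begin{bmatrix} \frac{1}{x_1} & 0 & 0 \\ 0 & \frac{1}{x_2} & 0 \\ 0 & 0 & \frac{1}{x_3} \end{bmatrix} = \begin{bmatrix} \frac{v_1}{x_1} & \frac{v_2}{x_2} & \frac{v_3}{x_3} \end{bmatrix}
\]
There is no need to compute the whole Jacobian matrix and then perform a vector-matrix multiplication. Lop returns the efficient computation directly: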
v = pt.vector("v")
vjp_log = Lop(log_x, wrt=x, eval_points=v)
simplify_print(vjp_log)
True_div [id A]
├─ v [id B]
└─ x [id C]
vjp_log.eval({"x": [1.0, 2.0, 3.0], "v": [4.0, 5.0, 6.0]})
array([4. , 2.5, 2. ])
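As a cross-check outside PyTensor, here is a minimal NumPy sketch. It assumes only NumPy, and the helpers naive_vjp_log and fast_vjp_log are hypothetical names for illustration. It contrasts the naive route (a dense diagonal Jacobian multiplied by v on the left) with the elementwise shortcut v / x:

```python
import numpy as np


def naive_vjp_log(x, v):
    """Build the dense diagonal Jacobian of log(x), then compute v @ J."""
    J = np.diag(1.0 / x)  # n x n matrix, mostly zeros
    return v @ J


def fast_vjp_log(x, v):
    """Exploit the diagonal structure: v @ diag(1 / x) equals v / x."""
    return v / x  # only n divisions, no matrix needed


x = np.array([1.0, 2.0, 3.0])
v = np.array([4.0, 5.0, 6.0])

# Both routes agree with each other (and with the Lop result shown above)
assert np.allclose(naive_vjp_log(x, v), fast_vjp_log(x, v))
print(fast_vjp_log(x, v))
```

The naive version needs O(n²) memory for J, while the shortcut needs only O(n).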
Cumsum operation
A pattern that will become obvious in this notebook is that we can often exploit some structure of the Jacobian matrix, together with the fact that we only need its product with a vector, to compute the VJP cheaply. Let’s take a look at the cumulative sum operation.
x = pt.vector("x")
cumsum_x = pt.cumsum(x)
cumsum_x.eval({"x": [1.0, 2.0, 3.0]})
array([1., 3., 6.])
The Jacobian of the cumulative sum operation is a lower triangular matrix of ones. The first input affects all outputs additively, the second input affects all outputs but the first, and so on, until the last input, which only affects the last output. If x has length 3:
\[
J = \begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix}
\]
PyTensor autodiff builds this Jacobian in a roundabout way. Starting from an identity matrix, it flips the columns, performs a cumulative sum along each row, and then flips the columns back. A more direct way would be a cumulative sum down the columns (axis 0) of the identity matrix. Since a flip is just a view (no copy needed), the detour doesn’t actually cost much.
J_cumsum = jacobian(cumsum_x, x)
simplify_print(J_cumsum)
Subtensor{:, ::step} [id A]
├─ CumOp{1, add} [id B]
│ └─ Subtensor{:, ::step} [id C]
│ ├─ Eye{dtype='float64'} [id D]
│ │ ├─ Shape_i{0} [id E]
│ │ │ └─ x [id F]
│ │ ├─ Shape_i{0} [id E]
│ │ │ └─ ···
│ │ └─ 0 [id G]
│ └─ -1 [id H]
└─ -1 [id H]
J_cumsum.eval({"x": [1.0, 2.0, 3.0]}).astype(int)
array([[1, 0, 0],
[1, 1, 0],
[1, 1, 1]])
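The flip, cumsum, flip construction described above can be replayed in NumPy as a sanity check. This sketch reproduces the same array manipulations but is not PyTensor’s code. It also confirms that a cumulative sum down the columns of the identity gives the same matrix directly:

```python
import numpy as np

n = 3
eye = np.eye(n)

# Mimic the graph above: flip the columns, cumsum along each row (axis 1), flip back
J_flip = np.cumsum(eye[:, ::-1], axis=1)[:, ::-1]

# The more direct construction: cumsum down the columns (axis 0) of the identity
J_direct = np.cumsum(eye, axis=0)

# Both equal the lower triangular matrix of ones
J_expected = np.tril(np.ones((n, n)))
assert np.array_equal(J_flip, J_expected)
assert np.array_equal(J_direct, J_expected)
print(J_flip.astype(int))
```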
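The left-multiplication of the Jacobian by a vector v has a special structure as well. Let’s write it down:
\[
v^\top J = \begin{bmatrix} v_1 & v_2 & v_3 \end{bmatrix} \begin{bmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix} = \begin{bmatrix} v_1 + v_2 + v_3 & v_2 + v_3 & v_3 \end{bmatrix}
\]
The result is a cumulative sum of v taken from the last entry backwards, that is, cumsum(v[::-1])[::-1].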
v = pt.vector("v")
vjp_cumsum = Lop(cumsum_x, x, v)
simplify_print(vjp_cumsum)
Subtensor{::step} [id A]
├─ CumOp{None, add} [id B]
│ └─ Subtensor{::step} [id C]
│ ├─ v [id D]
│ └─ -1 [id E]
└─ -1 [id E]
vjp_cumsum.eval({"x": [1.0, 2.0, 3.0], "v": [1, 1, 1]})
array([3., 2., 1.])
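The same identity can be checked numerically with a short NumPy sketch. It uses NumPy only, and naive_vjp_cumsum and fast_vjp_cumsum are hypothetical helper names. The sketch compares the dense product v @ J against the reversed cumulative sum:

```python
import numpy as np


def naive_vjp_cumsum(v):
    """Dense route: build the lower triangular Jacobian, then compute v @ J."""
    n = v.shape[0]
    J = np.tril(np.ones((n, n)))
    return v @ J


def fast_vjp_cumsum(v):
    """Reverse cumulative sum: entry j is v[j] + v[j+1] + ... + v[n-1]."""
    return np.cumsum(v[::-1])[::-1]


v = np.array([1.0, 2.0, 3.0])
assert np.allclose(naive_vjp_cumsum(v), fast_vjp_cumsum(v))
print(fast_vjp_cumsum(v))  # entries are 1+2+3, 2+3 and 3
```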
Convolution operation
Next, we look at an operation with two inputs: the discrete convolution.
x = pt.vector("x")
y = pt.vector("y", shape=(2,))
convolution_xy = pt.signal.convolve1d(x, y, mode="full")
convolution_xy.eval({"x": [0, 1, 2], "y": [1, -1]})
array([ 0., 1., 1., -2.])
If you’re not familiar with convolution, we get those four numbers by padding x with zeros and then taking the inner product of the flipped y with each window of two consecutive values:
x_padded = np.array([0, 0, 1, 2, 0])
res = np.array([
x_padded[0:2] @ [-1, 1],
x_padded[1:3] @ [-1, 1],
x_padded[2:4] @ [-1, 1],
x_padded[3:5] @ [-1, 1],
])
res
array([ 0, 1, 1, -2])
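The same recipe generalizes to any kernel length. This small NumPy sketch uses a hypothetical helper, full_convolve, written only for illustration. It pads x with len(y) - 1 zeros on each side, slides the flipped y across it, and checks the result against np.convolve in "full" mode:

```python
import numpy as np


def full_convolve(x, y):
    """Full discrete convolution via zero-padding and a sliding inner product.

    Pads x with len(y) - 1 zeros on each side and takes the inner product
    of every length-len(y) window with the flipped y.
    """
    k = len(y)
    x_padded = np.pad(x, (k - 1, k - 1))  # constant zero padding by default
    y_flipped = y[::-1]
    n_out = len(x) + k - 1
    return np.array([x_padded[i : i + k] @ y_flipped for i in range(n_out)])


x = np.array([0.0, 1.0, 2.0])
y = np.array([1.0, -1.0])
assert np.allclose(full_convolve(x, y), np.convolve(x, y, mode="full"))
print(full_convolve(x, y))
```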
Let’s focus on the Jacobian with respect to y, as that’s smaller. If you look at the expression above, you’ll see that it implies the following Jacobian:
\[
J = \begin{bmatrix} x_1 & 0 \\ x_2 & x_1 \\ x_3 & x_2 \\ 0 & x_3 \end{bmatrix}
\]
The constant zeros come from the padding. Curious how PyTensor builds this sort of Jacobian?
J_convolution = jacobian(convolution_xy, y)
simplify_print(J_convolution)
Blockwise{Convolve1d, (n),(k),()->(o)} [id A]
├─ Eye{dtype='float64'} [id B]
│ ├─ Add [id C]
│ │ ├─ 1 [id D]
│ │ └─ Shape_i{0} [id E]
│ │ └─ x [id F]
│ ├─ Add [id C]
│ │ └─ ···
│ └─ 0 [id G]
├─ ExpandDims{axis=0} [id H]
│ └─ Subtensor{::step} [id I]
│ ├─ x [id F]
│ └─ -1 [id J]
└─ [False] [id K]
It performs a batched “valid” convolution between the rows of eye(4) and the flipped x vector. In a valid convolution there is no padding. The kernel is only applied where it fully overlaps the signal, so convolving a length-4 row with the length-3 flipped x yields 2 values.
J_convolution.eval({"x": [0, 1, 2]})
array([[0., 0.],
[1., 0.],
[2., 1.],
[0., 2.]])
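Both views of this Jacobian can be reproduced with np.convolve in a short NumPy sketch; this assumes NumPy only and is not PyTensor’s implementation. Column j is the convolution of x with the j-th unit vector (x shifted down by j). Row by row, it is the batched "valid" convolution of eye(4) with the flipped x, as in the graph above:

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0])
k = 2  # length of y
n_out = len(x) + k - 1  # length of the "full" convolution output

# Column j of the Jacobian is d conv(x, y) / d y_j = conv(x, e_j): x shifted down by j
J_columns = np.stack([np.convolve(x, e, mode="full") for e in np.eye(k)], axis=1)

# Graph-style construction: "valid" convolution of each row of eye(n_out) with x flipped
J_rows = np.stack([np.convolve(row, x[::-1], mode="valid") for row in np.eye(n_out)])

assert J_columns.shape == (n_out, k)
assert np.allclose(J_columns, J_rows)
print(J_columns)
```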
Following the theme, is there any special structure in this Jacobian that can be exploited to compute the VJP efficiently?
v = pt.vector("v", shape=(4,))
vjp_convolution = Lop(convolution_xy, y, v)
simplify_print(vjp_convolution)
Convolve1d [id A]
├─ v [id B]
├─ Subtensor{::step} [id C]
│ ├─ x [id D]
│ └─ -1 [id E]
└─ ScalarFromTensor [id F]
└─ False [id G]
It’s just the “valid” convolution between v and the flipped x. Our Jacobian has a Toeplitz structure (constant along each diagonal), and multiplying a vector by such a matrix is equivalent to a discrete convolution!
vjp_convolution.eval({"v": [1, 2, 3, 4], "x": [0, 1, 2]})
array([ 8., 11.])
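To see the Toeplitz shortcut at work without PyTensor, a NumPy sketch (NumPy only, illustrative) can compare the dense product v @ J with a single "valid" convolution of v and the flipped x:

```python
import numpy as np

x = np.array([0.0, 1.0, 2.0])
v = np.array([1.0, 2.0, 3.0, 4.0])

# Dense route: the 4 x 2 Jacobian of conv(x, y) w.r.t. y (column j is x shifted by j)
J = np.stack([np.convolve(x, e, mode="full") for e in np.eye(2)], axis=1)
naive = v @ J

# Structured route: a "valid" convolution of v with x flipped, no Jacobian needed
fast = np.convolve(v, x[::-1], mode="valid")

assert np.allclose(naive, fast)
print(fast)  # entries 8.0 and 11.0, the same values Lop produced above
```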
Transpose operation
For a final example, let’s look at matrix transposition. This is a simple operation, but its input and output are matrices rather than vectors.
A = pt.matrix("A", shape=(2, None))
transpose_A = A.T
transpose_A.type.shape
(None, 2)
To reason about the Jacobian (and then the VJP), we need to look at this operation in terms of raveled (flattened) inputs and outputs.
transpose_A.ravel().eval({"A": np.arange(6).reshape(2, 3)})
array([0., 3., 1., 4., 2., 5.])
The Jacobian is then a (6 x 6) permutation matrix like this:
J_transpose = jacobian(transpose_A.ravel(), A).reshape((6, 6))
J_transpose.eval({"A": np.zeros((2, 3))})
array([[1., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0.],
[0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0.],
[0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 1.]])
PyTensor builds this Jacobian with two reshapes and a transpose of an identity (eye) matrix.
simplify_print(J_transpose)
Reshape{2} [id A]
├─ Transpose{axes=[0, 2, 1]} [id B]
│ └─ Reshape{3} [id C]
│ ├─ Eye{dtype='float64'} [id D]
│ │ ├─ Mul [id E]
│ │ │ ├─ 2 [id F]
│ │ │ └─ Shape_i{1} [id G]
│ │ │ └─ A [id H]
│ │ ├─ Mul [id E]
│ │ │ └─ ···
│ │ └─ 0 [id I]
│ └─ MakeVector{dtype='int64'} [id J]
│ ├─ Mul [id E]
│ │ └─ ···
│ ├─ Shape_i{1} [id G]
│ │ └─ ···
│ └─ 2 [id F]
└─ [6 6] [id K]
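The reshape, transpose, reshape trick in the graph above can be mirrored in NumPy. In this sketch, which assumes NumPy only and is not PyTensor’s code, A has shape (m, n) = (2, 3). The result is compared against a direct definition of the permutation: output i of A.T.ravel() picks entry (i % m) * n + i // m of A.ravel().

```python
import numpy as np

m, n = 2, 3  # A has shape (m, n); A.T has shape (n, m)

# Mirror the graph: reshape eye(m * n) to (m * n, n, m), swap the last two axes, flatten back
J_structured = np.eye(m * n).reshape(m * n, n, m).transpose(0, 2, 1).reshape(m * n, m * n)

# Direct definition of the permutation matrix
J_direct = np.zeros((m * n, m * n))
for i in range(m * n):
    J_direct[i, (i % m) * n + i // m] = 1.0

assert np.array_equal(J_structured, J_direct)

# Applying J to the raveled input reproduces the raveled transpose
A = np.arange(6.0).reshape(m, n)
assert np.array_equal(J_direct @ A.ravel(), A.T.ravel())
print(J_structured)
```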
To recreate the outcome of Lop, we ravel the V matrix, multiply it with the Jacobian defined on the raveled function, and reshape the result to the original input shape.
V = pt.matrix("V", shape=(3, 2))
naive_vjp_transpose = (V.ravel() @ J_transpose).reshape(V.shape[::-1])
vjp_eval_dict = {"V": np.arange(6).reshape((3, 2)), "A": np.zeros((2, 3))}
naive_vjp_transpose.eval(vjp_eval_dict)
array([[0., 2., 4.],
[1., 3., 5.]])
Because J is a permutation matrix, the multiplication with it simply rearranges the entries of V. What’s more, after the reshape, we end up with a simple transposition of the original V matrix!
Unsurprisingly, Lop takes the direct shortcut:
Lop(transpose_A, A, V).dprint()
Transpose{axes=[1, 0]} [id A]
└─ V [id B]
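The equivalence between the naive VJP and a plain transpose can also be checked in NumPy. The sketch below assumes NumPy only and reuses the same (2, 3) shape for A and the same V as above:

```python
import numpy as np

m, n = 2, 3
# Permutation Jacobian of A.T.ravel() with respect to A.ravel() (as printed above)
J = np.eye(m * n).reshape(m * n, n, m).transpose(0, 2, 1).reshape(m * n, m * n)

V = np.arange(6.0).reshape(n, m)  # same shape as A.T

# Naive VJP: ravel V, multiply by the Jacobian, reshape to the shape of A
naive = (V.ravel() @ J).reshape(m, n)

# Shortcut taken by Lop: just transpose V
assert np.array_equal(naive, V.T)
print(naive)
```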
VJP and auto-diff
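It is time to reveal the meaning of the mysterious vector (or reshaped tensor) v. In the context of auto-diff, it is the vector that accumulates the partial derivatives of intermediate computations. It holds the gradient of the final scalar output with respect to the output of the current operation, and the VJP turns it into the gradient with respect to that operation’s input. If you chain the VJP for each operation in your graph, you obtain reverse-mode autodiff.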
Let’s look at a simple example with the operations we discussed already:
x = pt.vector("x")
log_x = pt.log(x)
cumsum_log_x = pt.cumsum(log_x)
grad_out_wrt_x = pt.grad(cumsum_log_x.sum(), x)
simplify_print(grad_out_wrt_x)
True_div [id A]
├─ Subtensor{::step} [id B]
│ ├─ CumOp{None, add} [id C]
│ │ └─ Alloc [id D]
│ │ ├─ [1.] [id E]
│ │ └─ Shape_i{0} [id F]
│ │ └─ x [id G]
│ └─ -1 [id H]
└─ x [id G]
You may recognize the gradient components from the examples above. The gradient simplifies to cumsum(ones_like(x))[::-1] / x.
We can build the same graph manually by chaining two Lop calls. The initial grad_vec is set to ones with the shape of cumsum_log_x, because the gradient of a sum with respect to its inputs is a vector of ones.
grad_vec = pt.ones_like(cumsum_log_x)
grad_out_wrt_x = Lop(log_x, x, Lop(cumsum_log_x, log_x, grad_vec))
simplify_print(grad_out_wrt_x)
True_div [id A]
├─ Subtensor{::step} [id B]
│ ├─ CumOp{None, add} [id C]
│ │ └─ Alloc [id D]
│ │ ├─ [1.] [id E]
│ │ └─ Shape_i{0} [id F]
│ │ └─ x [id G]
│ └─ -1 [id H]
└─ x [id G]
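The same chain can be written out by hand in NumPy, outside PyTensor. This is an illustrative sketch in which f and grad_f are hypothetical names. Starting from a vector of ones, it applies the cumsum VJP and then the log VJP, and checks the result against central finite differences:

```python
import numpy as np


def f(x):
    """Scalar output of the example graph: sum(cumsum(log(x)))."""
    return np.sum(np.cumsum(np.log(x)))


def grad_f(x):
    """Reverse mode: chain the VJPs from the output back to the input."""
    grad_vec = np.ones_like(x)           # d sum(z) / d z is a vector of ones
    g = np.cumsum(grad_vec[::-1])[::-1]  # VJP of cumsum: reversed cumulative sum
    return g / x                         # VJP of log: elementwise division by x


x = np.array([1.0, 2.0, 4.0])

# Central finite differences as an independent check
eps = 1e-6
fd = np.array([(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in np.eye(len(x))])
assert np.allclose(grad_f(x), fd, atol=1e-5)
print(grad_f(x))
```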
Similarly, forward-mode autodiff makes use of the R-Operator (Rop) or Jacobian-vector product (JVP) to accumulate the partial derivatives from inputs to outputs.
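For intuition, here is a NumPy sketch (not PyTensor’s Rop) of the forward-mode chain for the same log-then-cumsum composition. It pushes a tangent vector u from the input towards the output and checks that it is consistent with the reverse-mode product, since both multiply by the same Jacobian. The vectors u and v are arbitrary illustrative values:

```python
import numpy as np

x = np.array([1.0, 2.0, 4.0])
u = np.array([0.5, -1.0, 2.0])  # tangent vector on the input side
v = np.array([1.0, 1.0, 1.0])   # cotangent vector on the output side

# Forward mode (JVP): J_cumsum @ (J_log @ u), i.e. divide by x, then cumsum
jvp = np.cumsum(u / x)

# Reverse mode (VJP): (v @ J_cumsum) @ J_log, i.e. reversed cumsum, then divide by x
vjp = np.cumsum(v[::-1])[::-1] / x

# Both are products with the same Jacobian J, so v . (J u) == (v J) . u
assert np.isclose(v @ jvp, vjp @ u)

# Dense check: J = J_cumsum @ J_log, with J_log = diag(1 / x)
J = np.tril(np.ones((3, 3))) @ np.diag(1.0 / x)
assert np.allclose(jvp, J @ u)
assert np.allclose(vjp, v @ J)
```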
Conclusion
We hope this sheds some light on how PyTensor (and most auto-diff frameworks) implement vector-Jacobian products efficiently, avoiding both building the full Jacobian and multiplying by it.