opt_einsum¶
Optimized einsum can significantly reduce the overall execution time of einsum-like expressions by optimizing the expression’s contraction order and dispatching many operations to canonical BLAS, cuBLAS, or other specialized routines. Optimized einsum is agnostic to the backend and can handle NumPy, Dask, PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as potentially any library which conforms to a standard API.
Features¶
The algorithms found in this repository often power the einsum optimizations in many of the above projects. For example, the optimization of np.einsum has been passed upstream, and most of the same features that can be found in this repository can be enabled with numpy.einsum(..., optimize=True). However, this repository often has more up-to-date algorithms for complex contractions.
Several advanced features are as follows:
Inspect detailed information about the path chosen.
Perform contractions with numerous backends, including on the GPU and with libraries such as TensorFlow and PyTorch.
Generate reusable expressions, potentially with constant tensors, that can be compiled for greater performance.
Use an arbitrary number of indices to find contractions for hundreds or even thousands of tensors.
Share intermediate computations among multiple contractions.
Compute gradients of tensor contractions using Autograd or JAX.
Example¶
Take the following einsum-like expression:
$$M_{pqrs} = C_{pi} C_{qj} I_{ijkl} C_{rk} C_{sl}$$
and consider two different algorithms:
import numpy as np
dim = 10
I = np.random.rand(dim, dim, dim, dim)
C = np.random.rand(dim, dim)
def naive(I, C):
    # N^8 scaling
    return np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
def optimized(I, C):
    # N^5 scaling
    K = np.einsum('pi,ijkl->pjkl', C, I)
    K = np.einsum('qj,pjkl->pqkl', C, K)
    K = np.einsum('rk,pqkl->pqrl', C, K)
    K = np.einsum('sl,pqrl->pqrs', C, K)
    return K
>>> np.allclose(naive(I, C), optimized(I, C))
True
Most einsum functions do not consider building intermediate arrays; therefore, helping einsum functions by creating these intermediate arrays can result in considerable cost savings even for small N (N=10):
%timeit naive(I, C)
1 loops, best of 3: 829 ms per loop
%timeit optimized(I, C)
1000 loops, best of 3: 445 µs per loop
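To see where the N^8 and N^5 labels come from, count how many iterations the loop nest behind each einsum call performs: one per combination of the distinct indices that call touches. The sketch below uses a hypothetical loop_count helper and is a deliberately simplified cost measure (it ignores how many multiplications happen per iteration, so it is not the FLOP count opt_einsum reports):

```python
from math import prod

def loop_count(terms, sizes):
    """Iterations of one loop nest over every distinct index in `terms`."""
    return prod(sizes[ix] for ix in set("".join(terms)))

N = 10
sizes = dict.fromkeys("pqrsijkl", N)

# naive: a single loop nest over all 8 indices
naive = loop_count(["pi", "qj", "ijkl", "rk", "sl", "pqrs"], sizes)

# optimized: four steps, each touching only 5 indices
steps = [
    ["pi", "ijkl", "pjkl"],
    ["qj", "pjkl", "pqkl"],
    ["rk", "pqkl", "pqrl"],
    ["sl", "pqrl", "pqrs"],
]
stepwise = sum(loop_count(step, sizes) for step in steps)

print(naive, stepwise, naive // stepwise)  # 100000000 400000 250
```

At N=10 this gives 10^8 iterations for the one-shot call against 4 × 10^5 for the four separate steps.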
The index transformation is a well-known contraction that leads to straightforward intermediates. This contraction can be further complicated by considering that the shapes of the C matrices need not be the same; in this case, the order in which the indices are transformed matters significantly. Logic can be built that optimizes the order; however, this is a lot of time and effort for a single expression.
The opt_einsum package is typically a drop-in replacement for einsum functions and can handle this logic and path finding for you:
from opt_einsum import contract
dim = 30
I = np.random.rand(dim, dim, dim, dim)
C = np.random.rand(dim, dim)
%timeit optimized(I, C)
10 loops, best of 3: 65.8 ms per loop
%timeit contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
100 loops, best of 3: 16.2 ms per loop
The above will automatically find the optimal contraction order (in this case, identical to that of the optimized function above) and compute the products for you. Additionally, contract can use vendor BLAS with the numpy.dot function under the hood to exploit additional parallelism and performance.
Details about the optimized contraction order can be explored:
>>> import opt_einsum as oe
>>> path_info = oe.contract_path('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)
>>> print(path_info[0])
[(0, 2), (0, 3), (0, 2), (0, 1)]
>>> print(path_info[1])
Complete contraction: pi,qj,ijkl,rk,sl->pqrs
Naive scaling: 8
Optimized scaling: 5
Naive FLOP count: 8.000e+08
Optimized FLOP count: 8.000e+05
Theoretical speedup: 1000.000
Largest intermediate: 1.000e+04 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
5 GEMM ijkl,pi->jklp qj,rk,sl,jklp->pqrs
5 GEMM jklp,qj->klpq rk,sl,klpq->pqrs
5 GEMM klpq,rk->lpqr sl,lpqr->pqrs
5 GEMM lpqr,sl->pqrs pqrs->pqrs
Citation¶
If this code has benefited your research, please support us by citing:
Daniel G. A. Smith and Johnnie Gray, opt_einsum - A Python package for optimizing contraction order for einsum-like expressions. Journal of Open Source Software, 2018, 3(26), 753
Install opt_einsum¶
You can install opt_einsum with conda, with pip, or by installing from source.
Conda¶
You can install or update opt_einsum using conda:
conda install opt_einsum -c conda-forge
This installs opt_einsum and the NumPy dependency.
The opt_einsum package is maintained on the conda-forge channel.
Pip¶
To install opt_einsum with pip there are a few options, depending on which dependencies you would like to keep up to date:
pip install opt_einsum
Input Format¶
The opt_einsum package was originally designed as a drop-in replacement for the np.einsum function and supports all input formats that np.einsum supports. Two styles of input are accepted; a basic introduction to both can be found in the documentation for numpy.einsum(). In addition, opt_einsum extends the allowed index labels to unicode characters or arbitrary hashable, comparable objects in order to handle large contractions with many indices.
‘Equation’ Input¶
As with numpy.einsum(), here you specify an equation as a string, followed by the array arguments:
>>> import opt_einsum as oe
>>> eq = 'ijk,jkl->li'
>>> x, y = np.random.rand(2, 3, 4), np.random.rand(3, 4, 5)
>>> z = oe.contract(eq, x, y)
>>> z.shape
(5, 2)
However, in addition to the standard alphabet, opt_einsum also supports unicode characters:
>>> eq = "αβγ,βγδ->δα"
>>> oe.contract(eq, x, y).shape
(5, 2)
This enables access to thousands of possible index labels. One way to access these programmatically is through the function get_symbol():
>>> oe.get_symbol(805)
'α'
which maps an int to a unicode character. Note that, as with numpy.einsum(), if the output is not specified with -> it will default to the sorted order of all indices appearing once:
>>> eq = "αβγ,βγδ" # "->αδ" is implicit
>>> oe.contract(eq, x, y).shape
(2, 5)
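Both behaviours can be sketched in plain Python. symbol_for is a hypothetical stand-in for get_symbol(): the assumption that labels beyond the 52 ASCII letters are code points offset by 140 is chosen only because it reproduces the 805 → 'α' example above, and the real function may differ (for example for very large integers). implicit_output reproduces the default output rule: indices appearing exactly once, in sorted order.

```python
from collections import Counter

LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

def symbol_for(i):
    """Hypothetical int -> index-label mapping (assumed, not get_symbol's source)."""
    if i < len(LETTERS):
        return LETTERS[i]
    # assumption: later labels are offset code points, so 805 -> chr(945) == 'α'
    return chr(i + 140)

def implicit_output(inputs):
    """Default output for 'inputs' with no '->': indices seen exactly once, sorted."""
    counts = Counter(inputs.replace(",", ""))
    return "".join(sorted(ix for ix, n in counts.items() if n == 1))

print(symbol_for(805))              # α
print(implicit_output("αβγ,βγδ"))   # αδ
print(implicit_output("ijk,jkl"))   # il
```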
‘Interleaved’ Input¶
The other input format is to ‘interleave’ the array arguments with their index labels (‘subscripts’) in pairs, optionally specifying the output indices as a final argument. As with numpy.einsum(), integers are allowed as these index labels:
>>> oe.contract(x, [1, 2, 3], y, [2, 3, 4], [4, 1]).shape
(5, 2)
with the default output order again specified by the sorted order of indices appearing once. However, unlike numpy.einsum(), in opt_einsum you can also put anything hashable and comparable, such as str, in the subscript list. A simple example of this syntax is:
>>> x, y, z = np.ones((1, 2)), np.ones((2, 2)), np.ones((2, 1))
>>> oe.contract(x, ('left', 'bond1'), y, ('bond1', 'bond2'), z, ('bond2', 'right'), ('left', 'right'))
array([[4.]])
The subscripts need to be hashable so that opt_einsum can efficiently process them, and they should also be comparable so as to allow a default sorted output. For example:
>>> x = np.array([[0, 1], [2, 0]])
>>> oe.contract(x, (0, 1)) # original matrix
array([[0, 1],
[2, 0]])
>>> oe.contract(x, (1, 0)) # the transpose
array([[0, 2],
[1, 0]])
>>> oe.contract(x, ('a', 'b')) # original matrix, consistent behavior
array([[0, 1],
[2, 0]])
>>> oe.contract(x, ('b', 'a')) # the transpose, consistent behavior
array([[0, 2],
[1, 0]])
>>> oe.contract(x, (0, 'a')) # relative sequence undefined, can't determine output
TypeError: For this input type lists must contain either Ellipsis or hashable and comparable object (e.g. int, str)
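One way to see why the labels must be hashable and comparable is to translate interleaved input into an equation string: each distinct label is mapped to a letter (a dictionary lookup, hence hashable), and the default output is built by sorting (hence comparable). interleaved_to_equation is a hypothetical helper written for illustration, not opt_einsum's internal converter, and it is limited to 52 distinct labels:

```python
from collections import Counter

LETTERS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

def interleaved_to_equation(subscripts, output=None):
    """Turn per-operand label sequences (plus optional output labels) into an
    equation string. Hypothetical helper; supports up to 52 distinct labels."""
    all_labels = [lab for sub in subscripts for lab in sub]
    # sorted() needs comparable labels: mixing 0 and 'a' raises TypeError here
    ordered = sorted(set(all_labels))
    # the dict needs hashable labels
    symbol = {lab: LETTERS[n] for n, lab in enumerate(ordered)}
    if output is None:
        counts = Counter(all_labels)
        output = [lab for lab in ordered if counts[lab] == 1]
    inputs = ",".join("".join(symbol[lab] for lab in sub) for sub in subscripts)
    return inputs + "->" + "".join(symbol[lab] for lab in output)

subs = [("left", "bond1"), ("bond1", "bond2"), ("bond2", "right")]
print(interleaved_to_equation(subs, ("left", "right")))  # ca,ab,bd->cd
print(interleaved_to_equation([(1, 0)]))                 # ba->ab (the transpose)
try:
    interleaved_to_equation([(0, "a")])
except TypeError as err:
    print("TypeError:", err)
```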
Backends & GPU Support¶
opt_einsum is quite agnostic to the type of n-dimensional arrays (tensors) it uses, since finding the contraction path only relies on getting the shape attribute of each array supplied. It can perform the underlying tensor contractions with various libraries. In fact, any library that provides a tensordot() and transpose() implementation can perform most normal contractions, while more specialized functionality, such as axis reduction, relies on an einsum() implementation.
The following is a brief overview of libraries which have been tested with opt_einsum:
tensorflow: compiled tensor expressions that can run on GPU.
theano: compiled tensor expressions that can run on GPU.
cupy: numpy-like api for GPU tensors.
dask: larger-than-memory tensor computations, distributed scheduling, and potential reuse of intermediaries.
sparse: sparse tensors.
pytorch: numpy-like api for GPU tensors.
autograd: automatic derivative computation for tensor expressions
jax: compiled GPU tensor expressions including autograd-like functionality
Note
For a contraction to be possible without using a backend einsum, it must satisfy the following rule: in the full expression (including output indices) each index must appear twice. In other words, each dimension must be contracted with one other dimension, or left alone.
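The rule can be checked mechanically. The hypothetical helper below reads "each index must appear twice" as "each index appears in exactly two different terms, counting the output as a term", and additionally rejects indices repeated within a single term, since a trace such as 'ii->' cannot be expressed with tensordot alone:

```python
from collections import Counter

def tensordot_only(equation):
    """True if an explicit 'inputs->output' equation satisfies the rule above:
    no index repeats within a term, and every index occurs in exactly two terms
    (the output counts as a term). Hypothetical helper."""
    inputs, output = equation.split("->")
    terms = inputs.split(",") + [output]
    if any(len(set(term)) != len(term) for term in terms):
        return False  # e.g. the trace 'ii->' needs einsum
    counts = Counter(ix for term in terms for ix in term)
    return all(n == 2 for n in counts.values())

print(tensordot_only("pi,qj,ijkl,rk,sl->pqrs"))  # True
print(tensordot_only("ab->"))                    # False: reduction over a and b
print(tensordot_only("ab,ab->ab"))               # False: each index in three terms
```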
Backend agnostic contractions¶
By default, the backend is detected automatically from the first supplied array; this can be overridden by specifying the correct backend argument for the type of arrays supplied when calling contract(). For example, if you had a library installed called 'foo' which provided an ndarray-like object with a .shape attribute as well as foo.tensordot and foo.transpose, then you could contract them with something like:
contract(einsum_str, *foo_arrays, backend='foo')
Behind the scenes opt_einsum will find the contraction path, perform pairwise contractions using e.g. foo.tensordot, and finally return the canonical type those functions return.
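A minimal sketch of how detection from the first array and lookup by name could work, assuming the backend name is the top-level module that defines the array's type and that functions such as tensordot are then fetched from that module. This scheme is an assumption for illustration; guess_backend, backend_func, FooArray and the stand-in foo module (registered in sys.modules purely for the demo) are all hypothetical:

```python
import importlib
import sys
import types

def guess_backend(array):
    """Assumed rule: the backend is the top-level module defining the array's type."""
    return type(array).__module__.split(".")[0]

def backend_func(backend, name):
    """Look up e.g. foo.tensordot by importing the backend module by name."""
    return getattr(importlib.import_module(backend), name)

# A stand-in 'foo' library, registered in sys.modules for this demo only.
foo = types.ModuleType("foo")

class FooArray:
    def __init__(self, shape):
        self.shape = shape  # the only attribute path finding needs

FooArray.__module__ = "foo.core"  # pretend the class is defined in foo.core
foo.ndarray = FooArray
foo.tensordot = lambda a, b, axes: "foo.tensordot called"  # placeholder
foo.transpose = lambda a, axes: "foo.transpose called"     # placeholder
sys.modules["foo"] = foo

x = FooArray((3, 4))
print(guess_backend(x))                                              # foo
print(backend_func(guess_backend(x), "tensordot") is foo.tensordot)  # True
```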
Dask¶
dask is an example of a library which satisfies these requirements. For example:
>>> import opt_einsum as oe
>>> import dask.array as da
>>> shapes = (3, 200), (200, 300), (300, 4)
>>> dxs = [da.random.normal(0, 1, shp, chunks=(100, 100)) for shp in shapes]
>>> dxs
[dask.array<da.random.normal, shape=(3, 200), dtype=float64, chunksize=(3, 100)>,
dask.array<da.random.normal, shape=(200, 300), dtype=float64, chunksize=(100, 100)>,
dask.array<da.random.normal, shape=(300, 4), dtype=float64, chunksize=(100, 4)>]
>>> dy = oe.contract("ab,bc,cd", *dxs) # will infer backend='dask'
>>> dy
dask.array<transpose, shape=(3, 4), dtype=float64, chunksize=(3, 4)>
>>> dy.compute()
array([[ 470.71404665, 2.44931372, -28.47577265, 424.37716615],
[ 64.38328345, -287.40753131, 144.46515642, 324.88169821],
[-142.07153553, -180.41739259, 125.0973783 , -239.16754541]])
In this case, dask arrays in means dask array out, since dask arrays have a shape attribute and opt_einsum can find dask.array.tensordot and dask.array.transpose.
Sparse¶
The sparse library also fits the requirements and is supported. An example:
>>> import sparse as sp
>>> shapes = (3, 200), (200, 300), (300, 4)
>>> sxs = [sp.random(shp) for shp in shapes]
>>> sxs
[<COO: shape=(3, 200), dtype=float64, nnz=6, sorted=False, duplicates=True>,
<COO: shape=(200, 300), dtype=float64, nnz=600, sorted=False, duplicates=True>,
<COO: shape=(300, 4), dtype=float64, nnz=12, sorted=False, duplicates=True>]
>>> sy = oe.contract("ab,bc,cd", *sxs)
>>> sy
<COO: shape=(3, 4), dtype=float64, nnz=0, sorted=False, duplicates=False>
Autograd¶
The autograd library is a drop-in for numpy that can automatically compute the gradients of array expressions. opt_einsum automatically dispatches the autograd arrays correctly, enabling a simple way to compute gradients of tensor contractions:
>>> import numpy as np
>>> import autograd
>>> shapes = [(2, 3), (3, 4), (4, 2)]
>>> x, y, z = [np.random.rand(*s) for s in shapes]
>>> # make single arg function as autograd takes derivative of first arg
>>> def foo(xyz):
... return oe.contract('ij,jk,ki->', *xyz)
>>> foo([x, y, z])
array(4.90422159)
>>> # wrap foo with autograd to compute gradients instead
>>> dfoo = autograd.grad(foo)
>>> dx, dy, dz = dfoo([x, y, z])
>>> dx, dy, dz
(array([[1.10056194, 1.25078356, 1.48211494],
[1.38945961, 1.5572077 , 1.65234003]]),
array([[0.41710717, 0.63202881, 0.84573502, 0.95069975],
[0.42706777, 0.73630994, 0.99328938, 0.77415267],
[0.40773334, 0.61693475, 0.82545726, 0.93132302]]),
array([[0.78747828, 1.28979012],
[1.26051133, 1.48835538],
[0.46896666, 0.55003072],
[1.10840828, 1.16722494]]))
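As a sanity check on what dfoo returns: the contraction 'ij,jk,ki->' is the trace of the matrix product x y z, so its gradient with respect to x is (y z)^T (and similarly (z x)^T for y and (x y)^T for z). The plain-Python sketch below does not use autograd; it compares that closed form for dx against central finite differences, using hypothetical helpers matmul, trace_xyz and numeric_grad_x:

```python
import random

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def trace_xyz(x, y, z):
    """The scalar 'ij,jk,ki->' computes: the trace of x @ y @ z."""
    p = matmul(matmul(x, y), z)
    return sum(p[i][i] for i in range(len(p)))

def numeric_grad_x(x, y, z, eps=1e-6):
    """Central finite differences of trace_xyz with respect to each entry of x."""
    grad = [[0.0] * len(x[0]) for _ in x]
    for i in range(len(x)):
        for j in range(len(x[0])):
            orig = x[i][j]
            x[i][j] = orig + eps
            up = trace_xyz(x, y, z)
            x[i][j] = orig - eps
            down = trace_xyz(x, y, z)
            x[i][j] = orig  # restore the entry
            grad[i][j] = (up - down) / (2 * eps)
    return grad

rng = random.Random(0)
shapes = [(2, 3), (3, 4), (4, 2)]
x, y, z = [[[rng.random() for _ in range(c)] for _ in range(r)] for r, c in shapes]

analytic = transpose(matmul(y, z))  # closed form for dx, shape (2, 3)
numeric = numeric_grad_x(x, y, z)
max_err = max(abs(a - n) for ra, rn in zip(analytic, numeric) for a, n in zip(ra, rn))
print(max_err < 1e-6)  # True
```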
Jax¶
jax is itself a drop-in for autograd that additionally uses XLA to compile the expressions, particularly for the GPU. Using it with opt_einsum is very simple:
>>> import jax
>>> # generate a compiled version of the above function
>>> jit_foo = jax.jit(foo)
>>> jit_foo([x, y, z])
DeviceArray(4.9042215, dtype=float32)
>>> # generate a compiled version of the gradient function
>>> jit_dfoo = jax.jit(jax.grad(foo))
>>> jit_dfoo([x, y, z])
[DeviceArray([[1.10056198, 1.25078356, 1.48211491],
[1.38945973, 1.5572077, 1.65234005]], dtype=float32),
DeviceArray([[0.41710716, 0.63202882, 0.84573501, 0.95069975],
[0.42706776, 0.73630995, 0.99328935, 0.7741527 ],
[0.40773335, 0.61693472, 0.82545722, 0.93132305]],
dtype=float32),
DeviceArray([[0.78747827, 1.28979015],
[1.2605114 , 1.4883554 ],
[0.46896666, 0.55003077],
[1.10840821, 1.16722488]], dtype=float32)]
Note
jax defaults to converting all arrays to single precision. This behaviour can be changed by running
from jax.config import config; config.update("jax_enable_x64", True)
before jax has been used at all.
Special (GPU) backends for numpy arrays¶
A particular case is when numpy arrays are required for the input and output, but a more performant backend is desired, such as performing the contraction on a GPU. Unless the specified backend works on numpy arrays, this requires converting to and from the backend array type. Currently opt_einsum can handle this automatically for tensorflow, theano, cupy, pytorch and jax, all of which offer GPU support. Since tensorflow and theano both require compiling the expression, this functionality is encapsulated in generating a ContractExpression using contract_expression(), which can then be called using numpy arrays whilst specifying backend='tensorflow' etc.
Additionally, if arrays are marked as constant (see Specifying Constants), then these arrays will be kept on the device for optimal performance.
Theano¶
If theano is installed, using it as backend is as simple as specifying backend='theano':
>>> shapes = (3, 200), (200, 300), (300, 4)
>>> expr = oe.contract_expression("ab,bc,cd", *shapes)
>>> expr
<ContractExpression('ab,bc,cd')>
>>> import numpy as np
>>> # GPU advantage mainly for low precision numbers
>>> xs = [np.random.randn(*shp).astype(np.float32) for shp in shapes]
>>> expr(*xs, backend='theano') # might see some fluff on first run
...
array([[ 129.28352 , -128.00702 , -164.62917 , -335.11682 ],
[-462.52344 , -121.12657 , -67.847626 , 624.5457 ],
[ 5.2838974, 36.441578 , 81.62851 , 703.1576 ]],
dtype=float32)
Note that you can still supply theano.tensor.TensorType directly to opt_einsum (with backend='theano'), and it will return the relevant theano type.
Tensorflow¶
To run the expression with tensorflow, you need to register a default session:
>>> import tensorflow as tf
>>> sess = tf.Session() # might see some fluff
...
>>> with sess.as_default(): out = expr(*xs, backend='tensorflow')
>>> out
array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ],
[-462.52362 , -121.12659 , -67.84769 , 624.5455 ],
[ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]],
dtype=float32)
Note that you can still supply this expression with, for example, a tensorflow.placeholder using backend='tensorflow', and then no conversion would take place; instead you would get a tensorflow.Tensor back.
Version 1.9 of tensorflow also added support for eager execution of computations. If compiling the tensorflow graph for the contraction expression takes up a substantial amount of time, then it can be advantageous to use this, especially since tensor contractions are quite compute-bound. This is achieved by running the following snippet:
import tensorflow as tf
tf.enable_eager_execution()
After this, opt_einsum will automatically detect eager mode if backend='tensorflow' is supplied to a ContractExpression.
Pytorch & Cupy¶
Both pytorch and cupy offer numpy-like, GPU-enabled arrays which execute eagerly rather than requiring any compilation. If they are installed, no steps are required to utilize them other than specifying the backend keyword:
>>> expr(*xs, backend='torch')
array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ],
[-462.52362 , -121.12659 , -67.84769 , 624.5455 ],
[ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]],
dtype=float32)
>>> expr(*xs, backend='cupy')
array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ],
[-462.52362 , -121.12659 , -67.84769 , 624.5455 ],
[ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]],
dtype=float32)
And as with the other GPU backends, if raw cupy or pytorch arrays are supplied, the returned array will be of the same type, with no conversion to or from numpy arrays.
Jax¶
jax, as introduced above, can compile tensor functions, in doing so often achieving better performance. opt_einsum expressions can handle this behind the scenes, so again just the backend keyword needs to be supplied:
>>> expr(*xs, backend='jax')
array([[ 129.28357 , -128.00684 , -164.62903 , -335.1167 ],
[-462.52362 , -121.12659 , -67.84769 , 624.5455 ],
[ 5.2839584, 36.44155 , 81.62852 , 703.15784 ]],
dtype=float32)
Contracting arbitrary objects¶
There is one more explicit backend that can handle arbitrary arrays of objects, so long as the objects themselves just support multiplication and addition (the __mul__ and __add__ dunder methods respectively). Use it by supplying backend='object'.
For example, imagine we want to perform a contraction of arrays made up of sympy symbols:
>>> import opt_einsum as oe
>>> import numpy as np
>>> import sympy
>>> # define the symbols
>>> a, b, c, d, e, f, g, h, i, j, k, l = [sympy.symbols(oe.get_symbol(i)) for i in range(12)]
>>> a * b + c * d
a*b + c*d
>>> # define the tensors (you might explicitly specify ``dtype=object``)
>>> X = np.array([[a, b], [c, d]])
>>> Y = np.array([[e, f], [g, h]])
>>> Z = np.array([[i, j], [k, l]])
>>> # contract the tensors!
>>> oe.contract('uv,vw,wu->u', X, Y, Z, backend='object')
array([i*(a*e + b*g) + k*(a*f + b*h), j*(c*e + d*g) + l*(c*f + d*h)],
dtype=object)
There are a few things to note here:
The returned array is a numpy.ndarray, but since it has dtype=object it can really hold any python objects.
We had to explicitly use backend='object', since numpy.einsum() would otherwise have been dispatched to, which can't handle dtype=object (though numpy.tensordot() in fact can).
Although an optimized pairwise contraction order is used, the looping in each single contraction is performed in python, so performance will be drastically lower than for numeric dtypes!
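The requirement that elements only support * and + can be illustrated with a single pairwise contraction written in pure Python. object_matmul is a hypothetical helper, not the backend='object' implementation; it uses exact fractions.Fraction values as the "arbitrary objects" and folds with functools.reduce so that no numeric zero is needed:

```python
import operator
from fractions import Fraction
from functools import reduce

def object_matmul(a, b):
    """'ij,jk->ik' on nested lists of arbitrary objects, using only * and +.

    reduce() sums the products without needing a zero of the element type.
    """
    return [[reduce(operator.add, (a[i][k] * b[k][j] for k in range(len(b))))
             for j in range(len(b[0]))]
            for i in range(len(a))]

X = [[Fraction(1, 2), Fraction(1, 3)], [Fraction(1, 4), Fraction(1, 5)]]
Y = [[Fraction(2), Fraction(3)], [Fraction(5), Fraction(7)]]
for row in object_matmul(X, Y):
    print([str(v) for v in row])
# ['8/3', '23/6']
# ['3/2', '43/20']
```

The nested Python loops are also why this route is so much slower than numeric dtypes.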
Reusing Paths¶
If you expect to use a particular contraction repeatedly, it can make things simpler and more efficient not to compute the path each time. Instead, supplying contract_expression() with the contraction string and the shapes of the tensors generates a ContractExpression which can then be repeatedly called with any matching set of arrays. For example:
>>> my_expr = oe.contract_expression("abc,cd,dbe->ea", (2, 3, 4), (4, 5), (5, 3, 6))
>>> print(my_expr)
<ContractExpression('abc,cd,dbe->ea')>
1. 'dbe,cd->bce' [GEMM]
2. 'bce,abc->ea' [GEMM]
The ContractExpression can be called with 3 arrays that match the original shapes without having to recompute the path:
>>> x, y, z = (np.random.rand(*s) for s in [(2, 3, 4), (4, 5), (5, 3, 6)])
>>> my_expr(x, y, z)
array([[ 3.08331541, 4.13708916],
[ 2.92793729, 4.57945185],
[ 3.55679457, 5.56304115],
[ 2.6208398 , 4.39024187],
[ 3.66736543, 5.41450334],
[ 3.67772272, 5.46727192]])
Note that few checks are performed when calling the expression, and while it will work for a set of arrays with the same ranks as the original shapes but differing sizes, it might no longer be optimal.
Specifying Constants¶
Often one generates contraction expressions where some of the tensor arguments will remain constant across many calls. contract_expression() allows you to specify the indices of these constant arguments, allowing opt_einsum to build and then reuse as many constant contractions as possible. Take for example the equation:
>>> eq = "ij,jk,kl,lm,mn->ni"
where we know that only the first and last tensors will vary between calls. We can specify this by marking the middle three as constant; we then need to supply the actual arrays rather than just the shapes to contract_expression():
>>> # A B C D E
>>> shapes = [(9, 5), (5, 5), (5, 5), (5, 5), (5, 8)]
>>> # mark the middle three arrays as constant
>>> constants = [1, 2, 3]
>>> # generate the constant arrays
>>> B, C, D = [np.random.randn(*shapes[i]) for i in constants]
>>> # supplied ops are now mix of shapes and arrays
>>> ops = (9, 5), B, C, D, (5, 8)
>>> expr = oe.contract_expression(eq, *ops, constants=constants)
>>> expr
<ContractExpression('ij,[jk,kl,lm],mn->ni', constants=[1, 2, 3])>
The expression now only takes the remaining two arrays as arguments (the tensors with 'ij' and 'mn' indices), and will store as many reusable constant contractions as possible.
>>> A1, E1 = np.random.rand(*shapes[0]), np.random.rand(*shapes[-1])
>>> out1 = expr(A1, E1)
>>> out1.shape
(8, 9)
>>> A2, E2 = np.random.rand(*shapes[0]), np.random.rand(*shapes[-1])
>>> out2 = expr(A2, E2)
>>> out2.shape
(8, 9)
>>> np.allclose(out1, out2)
False
>>> print(expr)
<ContractExpression('ij,[jk,kl,lm],mn->ni', constants=[1, 2, 3])>
1. 'jm,mn->jn' [GEMM]
2. 'jn,ij->ni' [GEMM]
Where we can see that the expression now only has to perform two contractions to compute the output.
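What marking constants buys can be sketched with nested lists: the product of the three constant matrices ('jk,kl,lm->jm') is formed once up front, so each call performs only the two contractions listed above. Integer entries keep the comparison exact. matmul, transpose, expr and full are hypothetical helpers, not opt_einsum internals:

```python
import random

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(row) for row in zip(*a)]

def rand_int_matrix(rng, rows, cols):
    return [[rng.randint(-3, 3) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(1)
B, C, D = (rand_int_matrix(rng, 5, 5) for _ in range(3))  # constants jk, kl, lm

BCD = matmul(matmul(B, C), D)  # 'jk,kl,lm->jm', computed once

def expr(A, E):
    """Two contractions per call, mirroring the printed path."""
    jn = matmul(BCD, E)              # 'jm,mn->jn'
    return transpose(matmul(A, jn))  # 'jn,ij->ni'

def full(A, E):
    """Reference result without folding: four contractions per call."""
    return transpose(matmul(matmul(matmul(matmul(A, B), C), D), E))

A, E = rand_int_matrix(rng, 9, 5), rand_int_matrix(rng, 5, 8)
out = expr(A, E)
print(len(out), len(out[0]), out == full(A, E))  # 8 9 True
```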
Note
The constant part of an expression is lazily generated upon the first call (specific to each backend), though it can also be explicitly built by calling evaluate_constants().
We can confirm the advantage of using expressions and constants by timing the following scenarios, first setting A = np.random.rand(*shapes[0]) and E = np.random.rand(*shapes[-1]).
contract from scratch:
>>> %timeit oe.contract(eq, A, B, C, D, E)
239 µs ± 5.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
contraction with an expression but no constants:
>>> expr_no_consts = oe.contract_expression(eq, *shapes)
>>> %timeit expr_no_consts(A, B, C, D, E)
76.7 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
contraction with an expression and constants marked:
>>> %timeit expr(A, E)
40.8 µs ± 1.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
This gives a rough idea, though the efficiency savings are hugely dependent on the size of the contraction and the number of possible constant contractions.
We also note that even if there are no constant contractions to perform, it can be very advantageous to specify constant tensors for particular backends. For instance, if a GPU backend is used, the constant tensors will be kept on the device rather than being transferred each time.
Sharing Intermediates¶
If you want to compute multiple similar contractions with common terms, you can embed them in a shared_intermediates() context. Computations of subexpressions in this context will be memoized, and will be garbage collected when the context exits.
For example, suppose we want to compute marginals at each point in a factor chain:
import numpy as np
from opt_einsum import contract, shared_intermediates
inputs = 'ab,bc,cd,de,ef'
factors = [np.random.rand(1000, 1000) for _ in range(5)]
%%timeit
marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
             for output in 'abcdef'}
1 loop, best of 3: 5.82 s per loop
To share this computation, we can perform all contractions in a shared context:
%%timeit
with shared_intermediates():
    marginals = {output: contract('{}->{}'.format(inputs, output), *factors)
                 for output in 'abcdef'}
1 loop, best of 3: 1.55 s per loop
If it is difficult to fit your code into a context, you can instead save the sharing cache for later reuse.
with shared_intermediates() as cache:  # create a cache
    pass
marginals = {}
for output in 'abcdef':
    with shared_intermediates(cache):  # reuse a common cache
        marginals[output] = contract('{}->{}'.format(inputs, output), *factors)
del cache  # garbage collect intermediates
Note that sharing contexts can be nested, so it is safe to use shared_intermediates() in library code without leaking intermediates into user caches.
Note
By default a cache is thread-safe; to share intermediates between threads, explicitly pass the same cache to each thread.
Introduction¶
Performing an optimized tensor contraction to speed up einsum involves two key stages:
Finding a pairwise contraction order, or ‘path’.
Performing the sequence of contractions given this path.
The better the quality of path found in the first step, the quicker the actual contraction in the second step can be, often dramatically. However, finding the optimal path is an NP-hard problem that can quickly become intractable, meaning that a balance must be struck between the time spent finding a path and its quality. opt_einsum handles this by using several path finding algorithms, which can be manually specified using the optimize keyword. These are:
The 'optimal' strategy - an exhaustive search of all possible paths
The 'dynamic-programming' strategy - a near-optimal search based on dynamic programming
The 'branch' strategy - a more restricted search of many likely paths
The 'greedy' strategy - finds a path one step at a time using a cost heuristic
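To give a feel for the 'greedy' idea, here is a simplified greedy search. It is not opt_einsum's implementation: the cost heuristic (the product of the sizes of every index a pairwise step touches) and the rule of only considering outer products when no pair shares an index are assumptions chosen for illustration, and greedy_path and step_cost are hypothetical names. On the earlier 'pi,qj,ijkl,rk,sl->pqrs' example it happens to reproduce the path printed by contract_path:

```python
from itertools import combinations
from math import prod

def step_cost(a, b, size):
    """Simplified cost: product of the sizes of every index either term uses."""
    return prod(size[ix] for ix in set(a) | set(b))

def greedy_path(equation, size):
    """Repeatedly contract the cheapest pair; returns (path, total_cost)."""
    inputs, output = equation.split("->")
    terms = inputs.split(",")
    path, total = [], 0
    while len(terms) > 1:
        pairs = list(combinations(range(len(terms)), 2))
        # prefer pairs sharing an index; fall back to outer products if none do
        connected = [(i, j) for i, j in pairs if set(terms[i]) & set(terms[j])]
        candidates = connected or pairs
        i, j = min(candidates, key=lambda p: step_cost(terms[p[0]], terms[p[1]], size))
        total += step_cost(terms[i], terms[j], size)
        b, a = terms.pop(j), terms.pop(i)  # pop the higher position first
        keep = set("".join(terms)) | set(output)
        terms.append("".join(sorted((set(a) | set(b)) & keep)))
        path.append((i, j))
    return path, total

sizes = dict.fromkeys("pqrsijkl", 10)
print(greedy_path("pi,qj,ijkl,rk,sl->pqrs", sizes))
# ([(0, 2), (0, 3), (0, 2), (0, 1)], 400000)
```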
By default (optimize='auto'), contract() will select the best of these it can while aiming to keep path finding times below around 1ms. An analysis of each of these approaches' performance can be found at the bottom of this page.
For large and complex contractions, there is the 'random-greedy' approach, which samples many (by default 32) greedy paths and can be customized to explicitly spend a maximum amount of time searching. Another preset, 'random-greedy-128', uses 128 paths for a more exhaustive search. See The Random-Greedy Path page for more details on configuring these.
Finally, there is the 'auto-hq' preset, which targets a much larger search time (~1sec) in return for finding very high quality paths, dispatching to the 'optimal', 'dynamic-programming' and then 'random-greedy-128' paths depending on contraction size.
If you want to find the path separately from performing the contraction, or just inspect information about the path found, you can use the function contract_path().
Examining the Path¶
As an example, consider the following expression found in a perturbation theory (one of ~5,000 such expressions):
'bdik,acaj,ikab,ajac,ikbd'
At first, it would appear that this scales like N^7 as there are 7 unique indices; however, we can define an intermediate to reduce this scaling.
a = 'bdik,ikab,ikbd' (N^5 scaling)
result = 'acaj,ajac,a' (N^3 scaling)
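The scaling of a contraction carried out in one loop nest is N raised to the number of distinct indices it involves, which is easy to check mechanically with a hypothetical scaling helper:

```python
def scaling(expr):
    """Exponent k in N^k: the number of distinct index letters in an einsum input."""
    return len(set(expr) - {","})

print(scaling("bdik,acaj,ikab,ajac,ikbd"))  # 7
print(scaling("bdik,ikab,ikbd"))            # 5
```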
This is a single possible path to the final answer (and notably, not the most optimal) out of many possible paths. Now, let opt_einsum compute the optimal path:
import opt_einsum as oe
# Take a complex string
einsum_string = 'bdik,acaj,ikab,ajac,ikbd->'
# Build random views to represent this contraction
unique_inds = set(einsum_string) - {',', '-', '>'}
index_size = [10, 17, 9, 10, 13, 16, 15, 14, 12]
sizes_dict = dict(zip(unique_inds, index_size))
views = oe.helpers.build_views(einsum_string, sizes_dict)
path, path_info = oe.contract_path(einsum_string, *views)
>>> print(path)
[(0, 4), (1, 3), (0, 1), (0, 1)]
>>> print(path_info)
Complete contraction: bdik,acaj,ikab,ajac,ikbd->
Naive scaling: 7
Optimized scaling: 4
Naive FLOP count: 2.387e+8
Optimized FLOP count: 8.068e+4
Theoretical speedup: 2958.354
Largest intermediate: 1.530e+3 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
4 0 ikbd,bdik->ikb acaj,ikab,ajac,ikb->
4 GEMV/EINSUM ikb,ikab->a acaj,ajac,a->
3 0 ajac,acaj->a a,a->
1 DOT a,a-> ->
We can then check that actually performing the contraction produces the expected result:
import numpy as np
einsum_result = np.einsum("bdik,acaj,ikab,ajac,ikbd->", *views)
contract_result = oe.contract("bdik,acaj,ikab,ajac,ikbd->", *views)
>>> np.allclose(einsum_result, contract_result)
True
By contracting terms in the correct order we can see that this expression can be computed with N^4 scaling. Even with the overhead of finding the best order or ‘path’ and small dimensions, opt_einsum is roughly 3000 times faster than pure einsum for this expression.
Format of the Path¶
Let us look at the structure of a canonical einsum path found in NumPy and its optimized variant:
einsum_path = [(0, 1, 2, 3, 4)]
opt_path = [(1, 3), (0, 2), (0, 2), (0, 1)]
In opt_einsum each element of the list represents a single contraction.
In the above example the einsum_path would effectively compute the result as a single contraction identical to that of einsum, while the opt_path would perform four contractions in order to reduce the overall scaling. The first tuple in the opt_path, (1, 3), pops the second and fourth terms, then contracts them together to produce a new term, which is then appended to the list of terms; this is continued until all terms are contracted.
An example should illuminate this:
---------------------------------------------------------------------------------
scaling GEMM current remaining
---------------------------------------------------------------------------------
terms = ['bdik', 'acaj', 'ikab', 'ajac', 'ikbd'] contraction = (1, 3)
3 False ajac,acaj->a bdik,ikab,ikbd,a->
terms = ['bdik', 'ikab', 'ikbd', 'a'] contraction = (0, 2)
4 False ikbd,bdik->bik ikab,a,bik->
terms = ['ikab', 'a', 'bik'] contraction = (0, 2)
4 False bik,ikab->a a,a->
terms = ['a', 'a'] contraction = (0, 1)
1 DOT a,a-> ->
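The pop-and-append procedure can be replayed in a few lines. replay is a hypothetical helper; it lists each step's operands in their positional order and sorts the indices of each new term, so the display can differ cosmetically from opt_einsum's (for example 'acaj,ajac' rather than 'ajac,acaj'), but the scalings and intermediates match the table above:

```python
def replay(equation, path):
    """Apply a path: each (i, j) pops two terms, contracts them, and appends the
    result. Returns (scaling, 'a,b->new') for each step."""
    inputs, output = equation.split("->")
    terms = inputs.split(",")
    steps = []
    for pair in path:
        i, j = sorted(pair)
        b, a = terms.pop(j), terms.pop(i)  # pop the higher position first
        involved = set(a) | set(b)
        keep = set("".join(terms)) | set(output)  # indices still needed later
        new = "".join(sorted(involved & keep))
        steps.append((len(involved), f"{a},{b}->{new}"))
        terms.append(new)
    return steps

opt_path = [(1, 3), (0, 2), (0, 2), (0, 1)]
for scale, contraction in replay("bdik,acaj,ikab,ajac,ikbd->", opt_path):
    print(scale, contraction)
# 3 acaj,ajac->a
# 4 bdik,ikbd->bik
# 4 ikab,bik->a
# 1 a,a->
```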
A path specified in this format can be supplied directly to contract() using the optimize keyword:
contract_result = oe.contract("bdik,acaj,ikab,ajac,ikbd->", *views, optimize=opt_path)
>>> np.allclose(einsum_result, contract_result)
True
Performance Comparison¶
The following graphs should give some indication of the tradeoffs between path finding time and path quality. They are generated by finding paths with each possible algorithm for many randomly generated networks of n tensors with varying connectivity.
First we have the time to find each path as a function of the number of terms in the expression:
Clearly the exhaustive ('optimal', 'branch-all') and exponential ('branch-2') searches eventually scale badly, but for modest numbers of terms they incur only a small overhead. The 'random-greedy' approach is not shown here as it is simply max_repeats times slower than the 'greedy' approach, at least if not parallelized.
Next we can look at the average FLOP speedup (as compared to the easiest path to find, 'greedy'):
One can see that the hierarchy of path qualities is:
'optimal' (used by auto for n <= 4)
'branch-all' (used by auto for n <= 6)
'branch-2' (used by auto for n <= 8)
'branch-1' (used by auto for n <= 14)
'greedy' (used by auto for anything larger)
Note
The performance of the 'random-greedy' approach (which is never used automatically) can be found separately in The Random-Greedy Path section.
There are a few important caveats to note with this graph. Firstly, the benefits of more advanced path finding are very dependent on the complexity of the expression. For ‘simple’ contractions, all the different approaches will mostly find the same path (as here). However, for ‘tricky’ contractions, there will be certain cases where the more advanced algorithms will find much better paths. As such, while this graph gives a good idea of the relative performance of each algorithm, the ‘average speedup’ is not a perfect indicator since worst-case performance might be more critical.
Note that the speedups for any of the methods as compared to a standard einsum or a naively chosen path (such as path=[(0, 1), (0, 1), ...]) are all exponentially large and not shown.
The Optimal Path¶
The optimal path can be found by searching through every possible way to contract the tensors together; this includes all combinations with the new intermediate tensors as well.
While this algorithm scales like N!, and can often become more costly to compute than the unoptimized contraction itself, it provides an excellent benchmark.
The function that computes this path in opt_einsum is called optimal() and works by performing a recursive, depth-first search. By keeping track of the best path found so far, in terms of total estimated FLOP count, the search can then quickly prune many paths as soon as they exceed this best.
This optimal strategy is used by default with the optimize='auto' mode of opt_einsum for 4 tensors or fewer, though it can handle expressions of up to 9-10 tensors in a matter of seconds.
Let us look at an example:
Contraction: abc,dc,ac->bd
Build a list with tuples that have the following form:
iteration 0:
"(cost, path, list of input sets remaining)"
[ (0, [], [set(['a', 'c', 'b']), set(['d', 'c']), set(['a', 'c'])]) ]
Since this is iteration zero, we have the initial list of input sets. We can consider three possible combinations where we contract list positions (0, 1), (0, 2), or (1, 2) together:
iteration 1:
[ (9504, [(0, 1)], [set(['a', 'c']), set(['a', 'c', 'b', 'd']) ]),
(1584, [(0, 2)], [set(['c', 'd']), set(['c', 'b']) ]),
(864, [(1, 2)], [set(['a', 'c', 'b']), set(['a', 'c', 'd']) ])]
We have now run through the three possible combinations, computed the cost of the contraction up to this point, and appended the resulting indices from the contraction to the list. As all contractions only have two remaining input sets, the only possible contraction is (0, 1):
iteration 2:
[ (28512, [(0, 1), (0, 1)], [set(['b', 'd']) ]),
(3168, [(0, 2), (0, 1)], [set(['b', 'd']) ]),
(19872, [(1, 2), (0, 1)], [set(['b', 'd']) ])]
The final contraction cost is computed, and we choose the second path from the list as the overall cost is the lowest.
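The search above can be sketched in pure Python. This is a minimal illustration, not the optimal() implementation. It assumes a pairwise cost equal to the size of all involved indices, doubled when at least one index is summed over. It also assumes the dimension sizes a=12, b=11, c=6, d=12, which we chose because they reproduce every cost listed in the walkthrough. The names pair_cost and optimal_path are hypothetical.

```python
from math import prod


def pair_cost(a, b, keep, sizes):
    """Hypothetical helper: (result indices, FLOP estimate) for contracting sets a and b.

    Indices in `keep` (needed by another tensor or by the output) survive.
    The cost is the size of all involved indices, doubled if any is summed.
    """
    union = a | b
    result = union & keep
    factor = 2 if union - result else 1
    return result, prod(sizes[ix] for ix in union) * factor


def optimal_path(inputs, output, sizes):
    """Hypothetical sketch: exhaustive depth-first search, pruned against the best total."""
    best = {"cost": float("inf"), "path": None}

    def search(remaining, path, cost):
        if cost >= best["cost"]:
            return  # prune: this branch can no longer beat the best path
        if len(remaining) == 1:
            best["cost"], best["path"] = cost, path
            return
        for i in range(len(remaining)):
            for j in range(i + 1, len(remaining)):
                others = [s for k, s in enumerate(remaining) if k not in (i, j)]
                keep = set(output).union(*others)
                result, step = pair_cost(remaining[i], remaining[j], keep, sizes)
                # the contracted pair is removed and its result appended at the end
                search(others + [result], path + [(i, j)], cost + step)

    search([set(term) for term in inputs], [], 0)
    return best["path"], best["cost"]


sizes = {"a": 12, "b": 11, "c": 6, "d": 12}
path, cost = optimal_path(["abc", "dc", "ac"], "bd", sizes)
print(path, cost)  # [(0, 2), (0, 1)] 3168
```

The candidates are visited in the same order as the walkthrough. Once the (0, 2) branch establishes a total of 3168, the (1, 2) branch is pruned as soon as its running total (864 + 19008) exceeds it.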
The Branching Path¶
While the optimal path is guaranteed to find the smallest estimated FLOP cost, it spends a lot of time exploring paths which are not likely to result in an optimal path. For instance, outer products are usually not advantageous unless absolutely necessary. Additionally, by trying a ‘good’ path first, it should be possible to quickly establish a threshold FLOP cost which can then be used to prune many bad paths.
The branching strategy (provided by branch()) does this by taking the recursive, depth-first approach of optimal(), whilst also sorting potential contractions based on a heuristic cost, as in greedy().
There are two main flavours:
optimize='branch-all': explore all inner products, starting with those that look best according to the cost heuristic.
optimize='branch-2': similar, but at each step only explore the estimated best two possible contractions, leading to a maximum of 2^N paths assessed.
In both cases, branch() takes an active approach to pruning paths well before they hit the best total FLOP count, by comparing them to the FLOP count (times some factor) achieved by the best path at the same point in the contraction.
There is also 'branch-1', which, since it only explores a single path at each step, does not really ‘branch’; this is essentially the approach of 'greedy'.
In comparison, 'branch-1' will be slower for large expressions, but for small to medium expressions it might find slightly higher quality contractions due to considering individual flop costs at each step.
The default optimize='auto' mode of opt_einsum will use 'branch-all' for 5 or 6 tensors, though it should be able to handle 12-13 tensors in a matter of seconds. Likewise, 'branch-2' will be used for 7 or 8 tensors, though it should be able to handle 20-22 tensors in a matter of seconds. Finally, 'branch-1' will be used by 'auto' for expressions of up to 14 tensors.
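A pure-Python toy version of the branching idea is sketched below. It runs the same kind of depth-first search as an exhaustive search. At each step, however, it sorts the candidate pairs (pairs sharing an index first, then those that shrink the total size the most) and only explores the best nbranch of them, or all of them when nbranch is None. This is not the branch() implementation. Outer products are tried last rather than excluded, and the per-step cutoff against the best path's FLOP count at the same point is omitted. The dimension sizes are the ones that reproduce the walkthrough costs in The Optimal Path section, and branch_path is a hypothetical name.

```python
from math import prod


def branch_path(inputs, output, sizes, nbranch=None):
    """Hypothetical sketch: depth-first search following only the `nbranch` best pairs.

    Pairs sharing an index come first, then those that shrink the total size
    the most; nbranch=None explores every pair in that order.
    """

    def size(ixs):
        return prod(sizes[ix] for ix in ixs)

    best = {"cost": float("inf"), "path": None}

    def search(remaining, path, cost):
        if cost >= best["cost"]:
            return  # prune against the best complete path found so far
        if len(remaining) == 1:
            best["cost"], best["path"] = cost, path
            return
        candidates = []
        for i in range(len(remaining)):
            for j in range(i + 1, len(remaining)):
                others = [s for k, s in enumerate(remaining) if k not in (i, j)]
                union = remaining[i] | remaining[j]
                result = union & set(output).union(*others)
                step = size(union) * (2 if union - result else 1)
                # heuristic score: negative when the contraction reduces total size
                score = size(result) - size(remaining[i]) - size(remaining[j])
                outer = not (remaining[i] & remaining[j])
                candidates.append(((outer, score, i, j), result, step, others))
        candidates.sort(key=lambda c: c[0])
        for (_, _, i, j), result, step, others in candidates[:nbranch]:
            search(others + [result], path + [(i, j)], cost + step)

    search([set(term) for term in inputs], [], 0)
    return best["path"], best["cost"]


sizes = {"a": 12, "b": 11, "c": 6, "d": 12}
print(branch_path(["abc", "dc", "ac"], "bd", sizes, nbranch=1))
# ([(0, 2), (0, 1)], 3168)
```

On this small example, even a single branch per step (the 'branch-1' behaviour) lands on the optimal path, because the heuristic ranks the (0, 2) pair first.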
Customizing the Branching Path¶
The ‘branch and bound’ path can be customized by creating a custom BranchBound instance. For example:
optimizer = oe.BranchBound(nbranch=3, minimize='size', cutoff_flops_factor=None)
path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer)
You could then tweak the settings (e.g. optimizer.nbranch = 4) and the best bound found so far will persist and be used to prune paths on the next call:
optimizer.nbranch = 4
path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer)
The Greedy Path¶
The 'greedy' approach provides a very efficient strategy for finding contraction paths for expressions with large numbers of tensors.
It does this by eagerly choosing contractions in three stages:
Eagerly compute any Hadamard products (in arbitrary order, as this is commutative).
Greedily contract pairs of remaining tensors, at each step choosing the pair that maximizes reduced_size; these are generally inner products.
Greedily compute any pairwise outer products, at each step choosing the pair that minimizes sum(input_sizes).
The cost heuristic reduced_size is simply the size of the pair of potential tensors to be contracted, minus the size of the resulting tensor.
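The three stages can be sketched in pure Python as below. This is a simplified illustration, not the greedy() implementation. Unlike the O(n * k) algorithm described next, it re-scores every pair at every step. Stage 2 considers any pair that shares an index, and greedy_path is a hypothetical name. With the sizes from the example in Optimal Scaling Misses, it makes the same choices as the 'greedy' output shown there: first ytpf,xyf, then tpfx,xtf.

```python
from math import prod


def greedy_path(inputs, output, sizes):
    """Hypothetical sketch: simplified three-stage greedy search returning a linear path."""
    remaining = [frozenset(term) for term in inputs]
    path = []

    def size(ixs):
        return prod(sizes[ix] for ix in ixs)

    def result_of(i, j):
        # surviving indices: those needed by another tensor or by the output
        others = [s for k, s in enumerate(remaining) if k not in (i, j)]
        return (remaining[i] | remaining[j]) & set(output).union(*others)

    def reduced_size(pair):
        i, j = pair
        return size(remaining[i]) + size(remaining[j]) - size(result_of(i, j))

    def contract(i, j):
        new = result_of(i, j)
        path.append((i, j))
        del remaining[j], remaining[i]  # j > i, so remove j first
        remaining.append(new)

    while len(remaining) > 1:
        n = len(remaining)
        pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
        hadamard = [(i, j) for i, j in pairs if remaining[i] == remaining[j]]
        sharing = [(i, j) for i, j in pairs if remaining[i] & remaining[j]]
        if hadamard:  # stage 1: identical index sets
            contract(*hadamard[0])
        elif sharing:  # stage 2: maximize reduced_size
            contract(*max(sharing, key=reduced_size))
        else:  # stage 3: outer products, minimize sum(input_sizes)
            contract(*min(pairs, key=lambda p: size(remaining[p[0]]) + size(remaining[p[1]])))
    return path


sizes = {"x": 35, "y": 37, "f": 59, "t": 51, "p": 51, "r": 27}
print(greedy_path(["xyf", "xtf", "ytpf", "fr"], "tpr", sizes))
# [(0, 2), (0, 2), (0, 1)]
```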
The greedy algorithm has space and time complexity O(n * k), where n is the number of input tensors and k is the maximum number of tensors that share any dimension (excluding dimensions that occur in the output or in every tensor). As such, the algorithm scales well to very large sparse contractions of low-rank tensors, and indeed often finds the optimal, or close to optimal, path in such cases.
The greedy functionality is provided by greedy(), and is selected by the default optimize='auto' mode of opt_einsum for expressions with many inputs. Expressions of up to a thousand tensors should still take well under a second to find paths for.
Optimal Scaling Misses¶
The greedy algorithm, while inexpensive, can occasionally miss optimal scaling in some circumstances, as seen below. The greedy algorithm prioritizes expressions which remove the largest indices first; in this particular case this is the incorrect choice, and it is difficult for any heuristic algorithm to “see ahead” as would be needed here.
It should be stressed that these cases are quite rare, and by default contract uses the optimal path for four or fewer inputs, as the cost of evaluating the optimal path is similar to that of the greedy path. Similarly, for 5-8 inputs, contract uses one of the branching strategies, which can find higher quality paths.
>>> M = np.random.rand(35, 37, 59)
>>> A = np.random.rand(35, 51, 59)
>>> B = np.random.rand(37, 51, 51, 59)
>>> C = np.random.rand(59, 27)
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy")
>>> print(desc)
Complete contraction: xyf,xtf,ytpf,fr->tpr
Naive scaling: 6
Optimized scaling: 5
Naive FLOP count: 2.146e+10
Optimized FLOP count: 4.165e+08
Theoretical speedup: 51.533
Largest intermediate: 5.371e+06 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
5 False ytpf,xyf->tpfx xtf,fr,tpfx->tpr
4 False tpfx,xtf->tpf fr,tpf->tpr
4 GEMM tpf,fr->tpr tpr->tpr
>>> path, desc = oe.contract_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal")
>>> print(desc)
Complete contraction: xyf,xtf,ytpf,fr->tpr
Naive scaling: 6
Optimized scaling: 4
Naive FLOP count: 2.146e+10
Optimized FLOP count: 2.744e+07
Theoretical speedup: 782.283
Largest intermediate: 1.535e+05 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
4 False xtf,xyf->tfy ytpf,fr,tfy->tpr
4 False tfy,ytpf->tfp fr,tfp->tpr
4 TDOT tfp,fr->tpr tpr->tpr
So we can see that the greedy algorithm finds a path which is about 15 times slower than the optimal one. In such cases, it might be worth using one of the more exhaustive optimization strategies: 'optimal', 'branch-all' or 'branch-2' (all of which will find the optimal path in this example).
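The FLOP counts in the two outputs above can be reproduced by hand. The minimal sketch below assumes that each pairwise step costs the product of the sizes of all indices involved, doubled when at least one index is summed over. This model matches every count printed above, but treat it as our reconstruction rather than a documented formula. The helper step_cost is hypothetical.

```python
from math import prod

sizes = {"x": 35, "y": 37, "f": 59, "t": 51, "p": 51, "r": 27}


def step_cost(step):
    """Hypothetical helper: FLOP estimate for one pairwise step like 'ytpf,xyf->tpfx'."""
    inputs, result = step.split("->")
    involved = set(inputs.replace(",", ""))
    summed = involved - set(result)  # indices contracted away in this step
    return prod(sizes[ix] for ix in involved) * (2 if summed else 1)


greedy_steps = ["ytpf,xyf->tpfx", "tpfx,xtf->tpf", "tpf,fr->tpr"]
optimal_steps = ["xtf,xyf->tfy", "tfy,ytpf->tfp", "tfp,fr->tpr"]

greedy_cost = sum(map(step_cost, greedy_steps))
optimal_cost = sum(map(step_cost, optimal_steps))
print(greedy_cost, optimal_cost)  # 416487726 27436062
print(round(greedy_cost / optimal_cost, 2))  # 15.18
```

The ratio of about 15.2 equals the ratio of the two theoretical speedups (782.283 / 51.533). Almost all of the greedy path's cost comes from its first step, which forms the five-index intermediate tpfx.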
Customizing the Greedy Path¶
The greedy path is a local optimizer in that it only ever assesses pairs of tensors to contract, assigning each a heuristic ‘cost’ and then choosing the ‘best’ of these. Custom greedy approaches can be implemented by supplying callables to the cost_fn and choose_fn arguments of greedy().
The Random-Greedy Path¶
For large and complex contractions the exhaustive approaches will be too slow, while the greedy path might be very far from optimal. In this case you might want to consider the 'random-greedy' path optimizer. This samples many greedy paths and selects the best one found, which can often be exponentially better than the average.
import opt_einsum as oe
import numpy as np
import math
eq, shapes = oe.helpers.rand_equation(40, 5, seed=1, d_max=2)
arrays = list(map(np.ones, shapes))
path_greedy = oe.contract_path(eq, *arrays, optimize='greedy')[1]
print(math.log2(path_greedy.opt_cost))
# 36.04683022558587
path_rand_greedy = oe.contract_path(eq, *arrays, optimize='random-greedy')[1]
print(math.log2(path_rand_greedy.opt_cost))
# 32.203616699170865
So here the random-greedy approach has found a path roughly 14 times quicker (2^(36.05 - 32.20) ≈ 14).
This approach works by randomly choosing from the best n contractions at each step, weighted by a Boltzmann factor with respect to the contraction with the ‘best’ cost. As such, contractions with very similar costs will be explored with nearly equal probability, whereas those with higher costs will be less likely, but still possible. In this way, the optimizer can randomly explore the huge space of possible paths, but in a guided manner.
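Below is a minimal sketch of this weighting. It assumes that each candidate's weight is exp(-(cost - best_cost) / temperature). The function boltzmann_choice is hypothetical, and RandomGreedy's actual weighting may differ in detail, for example in how rel_temperature normalizes the temperature.

```python
import math
import random


def boltzmann_choice(costs, temperature, rng=random):
    """Hypothetical helper: return the index of one candidate, favouring low costs.

    Each candidate is weighted by exp(-(cost - best) / temperature): near-best
    candidates are almost as likely as the best, much worse ones are rare.
    """
    best = min(costs)
    weights = [math.exp(-(c - best) / temperature) for c in costs]
    return rng.choices(range(len(costs)), weights=weights, k=1)[0]


rng = random.Random(0)
costs = [5.0, 1.0, 3.0]
cold = [boltzmann_choice(costs, 1e-9, rng) for _ in range(1000)]
hot = [boltzmann_choice(costs, 1e9, rng) for _ in range(3000)]
print(set(cold))  # {1}: a very low temperature always picks the cheapest
print([hot.count(k) for k in range(3)])  # a very high temperature is close to uniform
```

The temperature interpolates between a purely greedy choice (T close to 0) and a uniformly random one (very large T). This matches the role of the temperature parameter described in Customizing the Random-Greedy Path.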
The following graph roughly demonstrates the potential benefits of the 'random-greedy' algorithm, here for large randomly generated contractions, with either 8, 32 (the default), or 128 repeats:
Note
Bear in mind that such speed-ups are not guaranteed - it very much depends on how structured or complex your contractions are.
Customizing the Random-Greedy Path¶
The random-greedy optimizer can be customized by instantiating your own RandomGreedy object. Here you can control:
temperature - how far to stray from the locally ‘best’ contractions
rel_temperature - whether to normalize the temperature
nbranch - how many contractions (branches) to consider at each step
cost_fn - how to cost potential contractions
There are also the main RandomOptimizer options:
max_repeats - the maximum number of repeats
max_time - the maximum amount of time to run for (in seconds)
minimize - whether to minimize for total 'flops' or 'size' of the largest intermediate
For example, here we’ll create an optimizer, then change its temperature whilst reusing it. We’ll also set a high max_repeats and instead use a maximum time to terminate the search:
optimizer = oe.RandomGreedy(max_time=2, max_repeats=1_000_000)
for T in [1000, 100, 10, 1, 0.1]:
    optimizer.temperature = T
    path_rand_greedy = oe.contract_path(eq, *arrays, optimize=optimizer)[1]
    print(math.log2(optimizer.best['flops']))
# 32.81709395639357
# 32.67625007170783
# 31.719756871539033
# 31.62043317835677
# 31.253305891247
print(len(optimizer.costs)) # the total number of trials so far
# 2555
So we have improved a bit on the standard 'random-greedy' (which does 32 repeats by default). The optimizer object now stores both the best path found so far, optimizer.path, as well as the list of flop-costs and maximum sizes found for each trial, optimizer.costs and optimizer.sizes respectively.
Parallelizing the Random-Greedy Search¶
Since each greedy attempt is independent, the random-greedy approach is naturally suited to parallelization. This can be automatically handled by specifying the parallel keyword like so:
# use same number of processes as cores
optimizer = oe.RandomGreedy(parallel=True)
# or use specific number of processes
optimizer = oe.RandomGreedy(parallel=4)
Warning
The pool-executor used to perform this parallelization is the ProcessPoolExecutor from the concurrent.futures module. This is only part of the standard library in Python 3. For Python 2 consider installing the backport of this module or see below.
For full control over the parallelization you can supply any pool-executor like object, which should have an API matching the Python 3 concurrent.futures module:
from concurrent.futures import ProcessPoolExecutor
pool = ProcessPoolExecutor()
optimizer = oe.RandomGreedy(parallel=pool, max_repeats=128)
path_rand_greedy = oe.contract_path(eq, *arrays, optimize=optimizer)[1]
print(math.log2(optimizer.best['flops']))
# 31.64992600300931
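Because the trials are independent, the underlying pattern is simply to submit every trial and then take the minimum. The minimal sketch below uses a hypothetical random_trial stand-in that draws a random "cost" instead of building a path. It uses ThreadPoolExecutor only so that the example runs anywhere. For CPU-bound pure-Python trials, a ProcessPoolExecutor, as above, sidesteps the GIL.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def random_trial(seed):
    """Hypothetical stand-in for one randomized path search: returns (cost, seed)."""
    rng = random.Random(seed)
    return rng.uniform(1.0, 100.0), seed


def best_of_trials(n_trials, pool):
    """Submit independent trials to any executor-like pool and keep the cheapest."""
    futures = [pool.submit(random_trial, r) for r in range(n_trials)]
    return min(f.result() for f in futures)


with ThreadPoolExecutor(max_workers=4) as pool:
    best_cost, best_seed = best_of_trials(128, pool)
print(best_cost, best_seed)
```

Seeding each trial by its index keeps the result reproducible regardless of which worker runs it. The same idea appears in the trial numbering r discussed under Custom Random Optimizers.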
Other examples of such pools include:
The Dynamic Programming Path¶
The dynamic programming (DP) approach described in reference [1] provides an efficient way to find an asymptotically optimal contraction path by running the following steps:
Compute all traces, i.e. summations over indices occurring in exactly one input.
Decompose the contraction graph of inputs into disconnected subgraphs. Two inputs are connected if they share at least one summation index.
Find the contraction path for each of the disconnected subgraphs using a DP approach: the optimal contraction path for all sets of n (ranging from 1 to the number of inputs) connected tensors is found by combining sets of m and n-m tensors.
Note that computing all the traces in the very beginning can never lead to a non-optimal contraction path.
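The third step, combining sets of m and n-m tensors, can be sketched as a dynamic program over bitmask subsets of the inputs. This is a minimal illustration, not the DynamicProgramming implementation. It uses the same pairwise cost model as the walkthrough in The Optimal Path section (size of all involved indices, doubled if any index is summed). It also treats traces as free, never forms outer products, and has no cost cap. It returns None when the inputs do not form one connected graph. The function dp_optimal is hypothetical.

```python
from itertools import combinations
from math import prod


def dp_optimal(inputs, output, sizes):
    """Hypothetical sketch: subset DP returning (cost, contraction tree) or None."""
    n = len(inputs)
    sets = [frozenset(term) for term in inputs]
    full = (1 << n) - 1

    def indices(mask):
        return frozenset().union(*(sets[i] for i in range(n) if mask >> i & 1))

    def kept(mask):
        # an intermediate keeps the indices still needed outside it or in the output
        return indices(mask) & (indices(full & ~mask) | set(output))

    best = {1 << i: (0, i) for i in range(n)}  # mask -> (cost, contraction tree)
    for k in range(2, n + 1):
        for members in combinations(range(n), k):
            mask = sum(1 << i for i in members)
            sub = (mask - 1) & mask
            while sub:  # every split of `mask` into `sub` and `other`
                other = mask ^ sub
                if sub < other and sub in best and other in best:
                    left, right = kept(sub), kept(other)
                    if left & right:  # connected: no outer product
                        union = left | right
                        summed = union - kept(mask)
                        step = prod(sizes[ix] for ix in union) * (2 if summed else 1)
                        cost = best[sub][0] + best[other][0] + step
                        if mask not in best or cost < best[mask][0]:
                            best[mask] = (cost, (best[sub][1], best[other][1]))
                sub = (sub - 1) & mask
    return best.get(full)


sizes = {"a": 12, "b": 11, "c": 6, "d": 12}
print(dp_optimal(["abc", "dc", "ac"], "bd", sizes))  # (3168, (1, (0, 2)))
```

The result is a contraction tree rather than a linear path. (1, (0, 2)) means "contract inputs 0 and 2, then contract the result with input 1", which is the same optimal choice (cost 3168) as the exhaustive search.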
Contractions of disconnected subgraphs can be optimized independently, which still results in an optimal contraction path. However, the computational complexity of finding the contraction path is drastically reduced: if the subgraphs consist of n1, n2, … inputs, the computational complexity is reduced from O(exp(n1 + n2 + ...)) to O(exp(n1) + exp(n2) + ...).
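The decomposition step can be sketched with a union-find over shared summation indices, where summation indices are those not in the output. The function connected_components is a hypothetical illustration and not part of opt_einsum's documented API.

```python
def connected_components(inputs, output):
    """Hypothetical helper: group input positions linked by shared summation indices."""
    n = len(inputs)
    parent = list(range(n))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    first_seen = {}  # summation index -> first input that uses it
    for i, term in enumerate(inputs):
        for ix in set(term) - set(output):
            if ix in first_seen:
                parent[find(i)] = find(first_seen[ix])
            else:
                first_seen[ix] = i

    groups = {}
    for i in range(n):
        groups.setdefault(find(i), []).append(i)
    return sorted(groups.values())


print(connected_components(["ab", "bc", "de", "ef", "g"], "ag"))
# [[0, 1], [2, 3], [4]]
```

Each group can then be optimized on its own. For example, three independent groups of sizes n1, n2 and n3 cost O(exp(n1) + exp(n2) + exp(n3)) to search instead of O(exp(n1 + n2 + n3)).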
The DP approach will only perform pair contractions and by default will never compute intermediate outer products, as it is shown in reference [1] that this always results in an asymptotically optimal contraction path.
A major optimization for DP is the cost capping strategy: the DP optimization only stores a contraction for a subset of inputs if the total cost of this contraction is smaller than the cost cap. The cost cap is initialized with the minimal possible cost, i.e. the product of all output dimensions, and is iteratively increased by multiplying it with the smallest dimension until a contraction path including all inputs is found.
Note that the worst case scaling of DP is exponential in the number of inputs. Nevertheless, if the contraction graph is not completely random, but exhibits a certain kind of structure, it can be used for large contraction graphs and is guaranteed to find an asymptotically optimal contraction path. For this reason it is the most frequently used contraction path optimizer in the field of tensor network states.
More specifically, the search is performed over connected subgraphs, of which planar and tree-like graphs, for example, have far fewer. As a rough guide, if the graph is planar, expressions with many tens of tensors are tractable, whereas if the graph is tree-like, expressions with many hundreds of tensors are tractable.
[1] Robert N. C. Pfeifer, Jutho Haegeman, and Frank Verstraete, Phys. Rev. E 90, 033315 (2014). https://arxiv.org/abs/1304.6112
Customizing the Dynamic Programming Path¶
The optimize='dp' approach has sensible defaults but can be customized with the DynamicProgramming object.
import opt_einsum as oe
optimizer = oe.DynamicProgramming(
minimize='size', # optimize for largest intermediate tensor size
search_outer=True, # search through outer products as well
cost_cap=False, # don't use cost-capping strategy
)
oe.contract(eq, *arrays, optimize=optimizer)
Warning
Note that searching outer products will most likely drastically slow down the optimizer on all but the smallest examples.
Custom Path Optimizers¶
If you want to implement or just experiment with custom contraction paths, then you can easily do so by subclassing the PathOptimizer object. For example, imagine we want to test the path that just blindly contracts the first pair of tensors again and again. We would implement this as:
import opt_einsum as oe
class MyOptimizer(oe.paths.PathOptimizer):

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        # always contract the first two operands in the current list
        return [(0, 1)] * (len(inputs) - 1)
Once defined we can use this as:
import numpy as np
# set-up a random contraction
eq, shapes = oe.helpers.rand_equation(10, 3, seed=42)
arrays = list(map(np.ones, shapes))
# set-up our optimizer and use it
optimizer = MyOptimizer()
path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer)
print(path)
# [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
print(path_info.speedup)
# 133.21363671496357
Note that though we still get a considerable speedup over einsum, this is of course not a good strategy to take in general.
Custom Random Optimizers¶
If your custom path optimizer is inherently random, then you can reuse all the machinery of the random-greedy approach. Namely:
A max-repeats or max-time approach
Minimization with respect to total flops or largest intermediate size
Parallelization using a pool-executor
This is done by subclassing the RandomOptimizer object and implementing a setup method. Here’s an example where we just randomly select any path (again, although we get a considerable speedup over einsum, this is not a good strategy to take in general):
import numpy as np
import opt_einsum as oe
from opt_einsum.path_random import ssa_path_compute_cost

class MyRandomOptimizer(oe.path_random.RandomOptimizer):

    @staticmethod
    def random_path(r, n, inputs, output, size_dict):
        """Picks a completely random contraction order.
        """
        np.random.seed(r)
        ssa_path = []
        remaining = set(range(n))
        while len(remaining) > 1:
            # pick two live tensors; their result gets the next SSA id
            i, j = np.random.choice(list(remaining), size=2, replace=False)
            remaining.add(n + len(ssa_path))
            remaining.remove(i)
            remaining.remove(j)
            ssa_path.append((i, j))
        cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
        return ssa_path, cost, size

    def setup(self, inputs, output, size_dict):
        """Prepares the function and arguments to repeatedly call.
        """
        n = len(inputs)
        trial_fn = self.random_path
        trial_args = (n, inputs, output, size_dict)
        return trial_fn, trial_args
Which we can now instantiate using various other options:
optimizer = MyRandomOptimizer(max_repeats=1000, max_time=10,
                              parallel=True, minimize='size')
path, path_info = oe.contract_path(eq, *arrays, optimize=optimizer)
print(path)
# [(3, 4), (1, 3), (0, 3), (3, 5), (3, 4), (3, 4), (1, 0), (0, 1), (0, 1)]
print(path_info.speedup)
# 712829.9451056132
There are a few things to note here:
The core function (MyRandomOptimizer.random_path here) should take a trial number r as its first argument.
It should return an ssa_path (see opt_einsum.paths.ssa_to_linear and opt_einsum.paths.linear_to_ssa) as well as a flops-cost and max-size.
The setup method prepares this function, as well as any input to it, so that the trials will look roughly like [trial_fn(r, *trial_args) for r in range(max_repeats)]. If you need to parse the standard arguments (into a network for example), it thus only needs to be done once per optimization.
More details about RandomOptimizer options can be found in The Random-Greedy Path section.
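The two path formats differ only in how tensors are referred to. In an SSA path, every tensor (input or intermediate) keeps a unique id: inputs are 0..n-1 and each contraction creates the next id. In a linear path, entries are positions in a list from which contracted operands are removed and to which the result is appended. The pure-Python sketch below is intended to mirror what we understand opt_einsum.paths.ssa_to_linear to do, rather than to reproduce its code.

```python
def ssa_to_linear(ssa_path):
    """Hypothetical sketch: convert an SSA path to a linear (position-based) path."""
    num_ids = 1 + max(max(pair) for pair in ssa_path)
    position = list(range(num_ids))  # current list position of each id
    linear = []
    for pair in ssa_path:
        linear.append(tuple(position[i] for i in pair))
        # removing tensor `i` shifts every later id one place to the left;
        # shifting `i` itself is harmless since it is never referenced again
        for i in pair:
            for k in range(i, num_ids):
                position[k] -= 1
    return linear


# three inputs: contract 0 with 2 (creating id 3), then 1 with 3
print(ssa_to_linear([(0, 2), (1, 3)]))  # [(0, 2), (0, 1)]
```

SSA ids never change, which makes them convenient inside random_path. The linear form is what contract_path reports, as in the printed paths above.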
Large Expressions with Greedy¶
Using the greedy method allows the contraction of hundreds of tensors. Here’s an example from quantum physics of computing the inner product between two ‘Matrix Product States’.
Graphically, if we represent each tensor as an O, give it the same number of ‘legs’ as it has indices, and join those legs when that index is summed with another tensor, we get an expression for n particles that looks like:
O-O-O-O-O-O- -O-O-O-O-O-O
| | | | | | ... | | | | | |
O-O-O-O-O-O- -O-O-O-O-O-O
0 1 2 3 4 5 ........... n-2 n-1
The meaning of this is not that important other than it’s a large, useful contraction. For n=100 it involves 200 different tensors and about 300 unique indices. With this many indices it can be useful to generate them with the function get_symbol().
Let’s set up the required einsum string:
>>> import numpy as np
>>> import opt_einsum as oe
>>> n = 100
>>> phys_dim = 3
>>> bond_dim = 10
>>> # start with first site
... # O--
... # |
... # O--
>>> einsum_str = "ab,ac,"
>>> for i in range(1, n - 1):
...     # set the upper left/right, middle and lower left/right indices
...     # --O--
...     #   |
...     # --O--
...     j = 3 * i
...     ul, ur, m, ll, lr = (oe.get_symbol(i)
...                          for i in (j - 1, j + 2, j, j - 2, j + 1))
...     einsum_str += "{}{}{},{}{}{},".format(m, ul, ur, m, ll, lr)
>>> # finish with last site
... # --O
... # |
... # --O
>>> i = n - 1
>>> j = 3 * i
>>> ul, m, ll, = (oe.get_symbol(i) for i in (j - 1, j, j - 2))
>>> einsum_str += "{}{},{}{}".format(m, ul, m, ll)
Generate the shapes:
>>> def gen_shapes():
...     yield (phys_dim, bond_dim)
...     yield (phys_dim, bond_dim)
...     for i in range(1, n - 1):
...         yield (phys_dim, bond_dim, bond_dim)
...         yield (phys_dim, bond_dim, bond_dim)
...     yield (phys_dim, bond_dim)
...     yield (phys_dim, bond_dim)
>>> shapes = tuple(gen_shapes())
Let’s time how long it takes to generate the expression ('greedy' is used by default, and we turn off the memory_limit):
%timeit expr = oe.contract_expression(einsum_str, *shapes, memory_limit=-1)
76.2 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
This is pretty manageable, though we might want to think about splitting the expression up if we go a lot bigger. Importantly, we can then use this repeatedly with any set of matching arrays:
>>> expr = oe.contract_expression(einsum_str, *shapes, memory_limit=-1)
>>> arrays = [np.random.randn(*shp) / 4 for shp in shapes]
>>> expr(*arrays)
array(23.23628116)
>>> arrays = [np.random.randn(*shp) / 4 for shp in shapes]
>>> expr(*arrays)
array(-12.21091879)
And if we really want we can generate the full contraction path info:
>>> print(oe.contract_path(einsum_str, *arrays, memory_limit=-1)[1])
Complete contraction: ab,ac,dcf,dbe,gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƲƵ,ƳƱƴ,ƶƵ,ƶƴ->
Naive scaling: 298
Optimized scaling: 5
Naive FLOP count: 1.031e+248
Optimized FLOP count: 1.168e+06
Theoretical speedup: 88264689284468460017580864156865782413140936705854966013600065426858041248009637246968036807489558012989638169986640870276510490846199301907401763236976204166215471281505344088317454144870323271826022036197984172898402324699098341524952317952.000
Largest intermediate: 3.000e+02 elements
--------------------------------------------------------------------------------
scaling BLAS current remaining
--------------------------------------------------------------------------------
4 TDOT dbe,ab->ade ac,dcf,gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƲƵ,ƳƱƴ,ƶƵ,ƶƴ,ade->
4 TDOT dcf,ac->adf gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƲƵ,ƳƱƴ,ƶƵ,ƶƴ,ade,adf->
4 GEMM ƶƵ,ƳƲƵ->ƳƶƲ gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ƳƱƴ,ƶƴ,ade,adf,ƳƶƲ->
4 GEMM ƶƴ,ƳƱƴ->ƳƶƱ gfi,geh,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,ade,adf,ƳƶƲ,ƳƶƱ->
5 TDOT ade,geh->adgh gfi,jil,jhk,mlo,mkn,por,pnq,sru,sqt,vux,vtw,yxA,ywz,BAD,BzC,EDG,ECF,HGJ,HFI,KJM,KIL,NMP,NLO,QPS,QOR,TSV,TRU,WVY,WUX,ZYÂ,ZXÁ,ÃÂÅ,ÃÁÄ,ÆÅÈ,ÆÄÇ,ÉÈË,ÉÇÊ,ÌËÎ,ÌÊÍ,ÏÎÑ,ÏÍÐ,ÒÑÔ,ÒÐÓ,ÕÔ×,ÕÓÖ,Ø×Ú,ØÖÙ,ÛÚÝ,ÛÙÜ,ÞÝà,ÞÜß,áàã,áßâ,äãæ,äâå,çæé,çåè,êéì,êèë,íìï,íëî,ðïò,ðîñ,óòõ,óñô,öõø,öô÷,ùøû,ù÷ú,üûþ,üúý,ÿþā,ÿýĀ,ĂāĄ,ĂĀă,ąĄć,ąăĆ,ĈćĊ,ĈĆĉ,ċĊč,ċĉČ,ĎčĐ,ĎČď,đĐē,đďĒ,ĔēĖ,ĔĒĕ,ėĖę,ėĕĘ,ĚęĜ,ĚĘě,ĝĜğ,ĝěĞ,ĠğĢ,ĠĞġ,ģĢĥ,ģġĤ,ĦĥĨ,ĦĤħ,ĩĨī,ĩħĪ,ĬīĮ,ĬĪĭ,įĮı,įĭİ,IJıĴ,IJİij,ĵĴķ,ĵijĶ,ĸķĺ,ĸĶĹ,ĻĺĽ,ĻĹļ,ľĽŀ,ľļĿ,ŁŀŃ,ŁĿł,ńŃņ,ńłŅ,Ňņʼn,ŇŅň,ŊʼnŌ,Ŋňŋ,ōŌŏ,ōŋŎ,ŐŏŒ,ŐŎő,œŒŕ,œőŔ,ŖŕŘ,ŖŔŗ,řŘś,řŗŚ,ŜśŞ,ŜŚŝ,şŞš,şŝŠ,ŢšŤ,ŢŠţ,ťŤŧ,ťţŦ,ŨŧŪ,ŨŦũ,ūŪŭ,ūũŬ,ŮŭŰ,ŮŬů,űŰų,űůŲ,ŴųŶ,ŴŲŵ,ŷŶŹ,ŷŵŸ,źŹż,źŸŻ,Žżſ,ŽŻž,ƀſƂ,ƀžƁ,ƃƂƅ,ƃƁƄ,Ɔƅƈ,ƆƄƇ,ƉƈƋ,ƉƇƊ,ƌƋƎ,ƌƊƍ,ƏƎƑ,ƏƍƐ,ƒƑƔ,ƒƐƓ,ƕƔƗ,ƕƓƖ,ƘƗƚ,ƘƖƙ,ƛƚƝ,ƛƙƜ,ƞƝƠ,ƞƜƟ,ơƠƣ,ơƟƢ,ƤƣƦ,ƤƢƥ,ƧƦƩ,Ƨƥƨ,ƪƩƬ,ƪƨƫ,ƭƬƯ,ƭƫƮ,ưƯƲ,ưƮƱ,adf,ƳƶƲ,ƳƶƱ,adgh->
...
4 TDOT Ğğ,ĠğĢ->ĠĞĢ ĠĞġ,ģĢĥ,ģġĤ,Ĥĥ,ĠĞĢ->
4 GEMM ĠĞĢ,ĠĞġ->ġĢ ģĢĥ,ģġĤ,Ĥĥ,ġĢ->
4 GEMM Ĥĥ,ģĢĥ->ģĢĤ ģġĤ,ġĢ,ģĢĤ->
4 TDOT ģĢĤ,ģġĤ->ġĢ ġĢ,ġĢ->
2 DOT ġĢ,ġĢ-> ->
Here we can see that the speedup over a naive einsum is roughly 9 × 10^241, not bad!
Reusing Intermediaries with Dask¶
Dask provides a computational framework where arrays and the computations on them are built up into a ‘task graph’ before computation. Since opt_einsum is compatible with dask arrays, multiple contractions can be built into the same task graph, which then automatically reuses any shared arrays and contractions.
For example, imagine the two expressions:
>>> contraction1 = 'ab,dca,eb,cde'
>>> contraction2 = 'ab,cda,eb,cde'
>>> sizes = {l: 10 for l in 'abcde'}
The contraction 'ab,eb' is shared between them and need only be done once.
First, let’s set up some numpy arrays:
>>> terms1, terms2 = contraction1.split(','), contraction2.split(',')
>>> terms = set((*terms1, *terms2))
>>> terms
{'ab', 'cda', 'cde', 'dca', 'eb'}
>>> import numpy as np
>>> np_arrays = {s: np.random.randn(*(sizes[c] for c in s)) for s in terms}
>>> # filter the arrays needed for each expression
>>> np_ops1 = [np_arrays[s] for s in terms1]
>>> np_ops2 = [np_arrays[s] for s in terms2]
Typically we would compute these expressions separately:
>>> oe.contract(contraction1, *np_ops1)
array(114.78314052)
>>> oe.contract(contraction2, *np_ops2)
array(-75.55902751)
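Doing this reuse by hand makes it concrete. Because b appears only in 'ab' and 'eb', that pair can be contracted to an 'ae' intermediate once and fed to both expressions. The minimal NumPy sketch below uses its own random arrays, so its numbers will not match the outputs above.

```python
import numpy as np

rng = np.random.default_rng(0)
dims = {index: 10 for index in "abcde"}
ops = {term: rng.standard_normal([dims[c] for c in term])
       for term in ("ab", "dca", "cda", "eb", "cde")}

# compute the shared pair contraction once (b only appears in 'ab' and 'eb')
shared_ae = np.einsum("ab,eb->ae", ops["ab"], ops["eb"])

# reuse it in both expressions
r1 = np.einsum("ae,dca,cde->", shared_ae, ops["dca"], ops["cde"])
r2 = np.einsum("ae,cda,cde->", shared_ae, ops["cda"], ops["cde"])

# compare with evaluating each full expression directly
full1 = np.einsum("ab,dca,eb,cde->", ops["ab"], ops["dca"], ops["eb"], ops["cde"])
full2 = np.einsum("ab,cda,eb,cde->", ops["ab"], ops["cda"], ops["eb"], ops["cde"])
print(np.allclose(r1, full1), np.allclose(r2, full2))  # True True
```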
However, if we use dask arrays we can combine the two operations, so let’s set those up:
>>> import dask.array as da
>>> da_arrays = {s: da.from_array(np_arrays[s], chunks=1000, name=s) for s in terms}
>>> da_arrays
{'ab': dask.array<ab, shape=(10, 10), dtype=float64, chunksize=(10, 10)>,
'cda': dask.array<cda, shape=(10, 10, 10), dtype=float64, chunksize=(10, 10, 10)>,
'cde': dask.array<cde, shape=(10, 10, 10), dtype=float64, chunksize=(10, 10, 10)>,
'dca': dask.array<dca, shape=(10, 10, 10), dtype=float64, chunksize=(10, 10, 10)>,
'eb': dask.array<eb, shape=(10, 10), dtype=float64, chunksize=(10, 10)>}
>>> da_ops1 = [da_arrays[s] for s in terms1]
>>> da_ops2 = [da_arrays[s] for s in terms2]
Note chunks is a required argument relating to how the arrays are stored (see array-creation). Now we can perform the contraction:
>>> # these won't be immediately evaluated
>>> dy1 = oe.contract(contraction1, *da_ops1, backend='dask')
>>> dy2 = oe.contract(contraction2, *da_ops2, backend='dask')
>>> # wrap them in delayed to combine them into the same computation
>>> from dask import delayed
>>> dy = delayed([dy1, dy2])
>>> dy
Delayed('list-3af82335-b75e-47d6-b800-68490fc865fd')
As suggested by the name Delayed, we have a placeholder for the result so far. When we want to perform the computation we can call:
>>> dy.compute()
[114.78314052155015, -75.55902750513113]
The above matches the canonical numpy result. The computation can even be handled by various schedulers - see scheduling. Finally, to check we are reusing intermediaries, we can view the task graph generated for the computation:
>>> dy.visualize(optimize_graph=True)
Note
For sharing intermediates with other backends see Sharing Intermediates. Dask graphs are particularly useful for reusing intermediates beyond just contractions and can allow additional parallelization.
Function Reference¶
Evaluates the Einstein summation convention on the operands.
Find a contraction order ‘path’, without performing the contraction.
Generate a reusable expression for a given contraction with specific shapes, which can, for example, be cached.
Helper class for storing an explicit
A printable object to contain information about a contraction path.
Computes all possible pair contractions in a depth-first recursive manner, sieving results based on
Finds the path by a three stage algorithm:
Get the symbol corresponding to int
Context in which contract intermediate results are shared.
Base class for different path optimizers to inherit from.
Explores possible pair contractions in a depth-first recursive manner like the
Base class for running any random path finder that benefits from repeated calling, possibly in a parallel fashion.
Finds the optimal path of pairwise contractions without intermediate outer products based on a dynamic programming approach presented in Phys.
Changelog¶
3.3.0 / 2020-07-19¶
Adds an object backend for optimized contractions on arbitrary Python objects.
New Features¶
(GH#145) Adds an object-based backend so that contract(backend='object') can be used on arbitrary objects such as SymPy symbols.
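An object backend is possible because a contraction only needs multiplication and addition of elements. The dependency-free sketch below (object_einsum and its helpers are hypothetical names, not opt_einsum's API) evaluates explicit 'inputs->output' subscripts over nested lists using nothing but * and +, so it works for any element type that supports them; fractions.Fraction stands in here for objects such as SymPy symbols. Unlike contract(backend='object'), it does no path optimization and simply loops over every index.

```python
from fractions import Fraction
from itertools import product
from typing import Any, Dict, List


def get(tensor: Any, idx: List[int]) -> Any:
    """Index into a nested list with a sequence of integers."""
    for i in idx:
        tensor = tensor[i]
    return tensor


def shape(tensor: Any, ndim: int) -> List[int]:
    """Read the shape of a rectangular nested list."""
    dims = []
    for _ in range(ndim):
        dims.append(len(tensor))
        tensor = tensor[0]
    return dims


def object_einsum(subscripts: str, *operands: Any) -> Any:
    """Naive einsum that touches elements only through * and +."""
    inputs, output = subscripts.replace(' ', '').split('->')
    terms = inputs.split(',')
    sizes: Dict[str, int] = {}
    for term, op in zip(terms, operands):
        sizes.update(zip(term, shape(op, len(term))))
    summed = sorted(set(''.join(terms)) - set(output))

    def element(fixed: Dict[str, int]) -> Any:
        # Sum, over all summed indices, the product of one element per operand.
        total = None
        for vals in product(*(range(sizes[s]) for s in summed)):
            env = {**fixed, **dict(zip(summed, vals))}
            value = None
            for term, op in zip(terms, operands):
                x = get(op, [env[s] for s in term])
                value = x if value is None else value * x
            total = value if total is None else total + value
        return total

    def build(pos: int, fixed: Dict[str, int]) -> Any:
        # Recurse over output indices, producing nested lists.
        if pos == len(output):
            return element(fixed)
        s = output[pos]
        return [build(pos + 1, {**fixed, s: v}) for v in range(sizes[s])]

    return build(0, {})


a = [[Fraction(1, 2), Fraction(1, 3)], [Fraction(1, 4), Fraction(1, 5)]]
b = [[Fraction(2), Fraction(0)], [Fraction(0), Fraction(3)]]
ab = object_einsum('ij,jk->ik', a, b)  # exact rational matrix product
trace = object_einsum('ii->', a)       # Fraction(7, 10)
```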
Enhancements¶
(GH#140) Better error messages when the requested contract backend cannot be found.
(GH#141) Adds a check with RandomOptimizers to ensure the objects are not accidentally reused for different contractions.
(GH#149) Limits the remaining category for the contract_path output to only show up to 20 tensors to prevent issues with the quadratically scaling memory requirements and the number of print lines for large contractions.
3.1.0 / 2019-09-30¶
Adds a new dynamic programming algorithm to the suite of paths.
3.0.0 / 2019-08-10¶
This release makes opt_einsum backend agnostic while adding support for additional backends such as Jax and Autograd. Support for Python 2.7 has been dropped and Python 3.5 becomes the new minimum version; a Python deprecation policy equivalent to NumPy’s has been adopted.
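A common way to stay backend agnostic is to look at which library an operand's type comes from and dispatch to that library's functions. The sketch below shows the idea with hypothetical names (infer_backend_name, auto_backend, load_backend). Deciding from the first operand only, and importing the backend module by name, are assumptions of this illustration rather than a description of opt_einsum's internals.

```python
import importlib
from fractions import Fraction
from types import ModuleType
from typing import Any


def infer_backend_name(x: Any) -> str:
    """Name of the top-level package that defines type(x).

    For example numpy.ndarray gives 'numpy' and plain floats give 'builtins'.
    """
    return type(x).__module__.split('.')[0]


def auto_backend(*operands: Any) -> str:
    """Hypothetical backend='auto' rule: trust the first operand."""
    return infer_backend_name(operands[0])


def load_backend(name: str) -> ModuleType:
    """Import the backend module so functions can be looked up on it."""
    return importlib.import_module(name)


name = auto_backend(Fraction(1, 2), Fraction(3, 4))
module = load_backend(name)  # the standard-library 'fractions' module
print(name, module.Fraction(1, 2) + module.Fraction(1, 4))  # fractions 3/4
```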
New Features¶
(GH#78) A new random-optimizer has been implemented which uses Boltzmann weighting to explore alternative near-minimum paths using greedy-like schemes. This provides fairly large path performance enhancements with a linear path-finding time overhead.
(GH#78) A new PathOptimizer class has been implemented to provide a framework for building new optimizers. For example, custom cost functions can now be provided in the greedy formalism to build custom optimizers without a large amount of additional code.
(GH#81) The backend="auto" keyword has been implemented for contract, allowing automatic detection of the correct backend based on the tensors provided in the contraction.
(GH#88) Autograd and Jax support have been implemented.
(GH#96) Deprecates Python 2 functionality and makes devops improvements.
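The Boltzmann weighting mentioned for the random optimizer can be sketched in a few lines. In the hypothetical helper below (boltzmann_choice is not opt_einsum's function, and its details are assumptions), a greedy-like search would pass in the costs of its candidate pair contractions. Instead of always taking the cheapest, it samples candidate k with probability proportional to exp(-(cost_k - cost_min) / T). Repeating such a randomized greedy run many times and keeping the best path found is the general idea behind exploring near-minimum paths.

```python
import math
import random
from typing import Sequence


def boltzmann_choice(costs: Sequence[float], temperature: float,
                     rng: random.Random) -> int:
    """Return the index of a candidate, favoring low costs.

    Candidate k gets weight exp(-(costs[k] - min(costs)) / temperature).
    Shifting by the minimum keeps the exponentials in range, and a
    temperature of 0 falls back to the plain greedy (cheapest) choice.
    """
    if temperature <= 0:
        return min(range(len(costs)), key=costs.__getitem__)
    lowest = min(costs)
    weights = [math.exp(-(c - lowest) / temperature) for c in costs]
    return rng.choices(range(len(costs)), weights=weights, k=1)[0]


rng = random.Random(0)
costs = [10.0, 11.0, 50.0]
counts = [0, 0, 0]
for _ in range(1000):
    counts[boltzmann_choice(costs, temperature=1.0, rng=rng)] += 1
# The cheapest candidate wins most often, the near-tie is still explored,
# and the far more expensive candidate is essentially never picked.
print(counts)
```

Raising the temperature flattens the weights and explores more aggressively; lowering it approaches the deterministic greedy choice.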
2.3.0 / 2018-12-01¶
This release primarily focuses on expanding the suite of available path technologies to provide better optimization characteristics for 4-20 tensors while decreasing the time to find paths for 50-200+ tensors. See Path Overview for more information.
New Features¶
(GH#60) A new greedy implementation has been added which is up to two orders of magnitude faster for 200 tensors.
(GH#73) Adds a new branch path that uses greedy ideas to prune the optimal exploration space to provide a better path than greedy at sub-optimal cost.
(GH#73) Adds a new auto keyword to the opt_einsum.contract() path option. This keyword automatically chooses the best path technology that takes under 1ms to execute.
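The greedy strategy, together with the pluggable cost functions later exposed through the PathOptimizer framework, can be illustrated with the hypothetical sketch below: at each step it scores every remaining pair with a cost function and contracts the lowest-scoring one. The default score here (size of the new intermediate minus the sizes of its two inputs) is an assumption in the spirit of freeing as much memory as possible; greedy_path, memory_removed, and flops are illustrative names, not opt_einsum's API.

```python
from itertools import combinations
from math import prod
from typing import Callable, Dict, FrozenSet, List, Tuple

Term = FrozenSet[str]
CostFn = Callable[[Term, Term, Term, Dict[str, int]], float]


def size_of(term: Term, sizes: Dict[str, int]) -> int:
    return prod(sizes[ix] for ix in term)


def memory_removed(a: Term, b: Term, new: Term, sizes: Dict[str, int]) -> float:
    """Hypothetical default: how much the intermediate grows memory use."""
    return size_of(new, sizes) - size_of(a, sizes) - size_of(b, sizes)


def greedy_path(terms: List[Term], output: Term, sizes: Dict[str, int],
                cost_fn: CostFn = memory_removed) -> List[Tuple[int, int]]:
    """Repeatedly contract the pair that cost_fn scores lowest."""
    terms = list(terms)
    path = []
    while len(terms) > 1:
        best = None
        for i, j in combinations(range(len(terms)), 2):
            rest = [t for k, t in enumerate(terms) if k not in (i, j)]
            new = (terms[i] | terms[j]) & output.union(*rest)
            score = cost_fn(terms[i], terms[j], new, sizes)
            if best is None or score < best[0]:
                best = (score, i, j, new)
        _, i, j, new = best
        # Remove the contracted pair and append its result to the end.
        terms = [t for k, t in enumerate(terms) if k not in (i, j)] + [new]
        path.append((i, j))
    return path


def flops(a: Term, b: Term, new: Term, sizes: Dict[str, int]) -> float:
    """An alternative cost: size of the pair's combined index space."""
    return size_of(a | b, sizes)


sizes = {'i': 10, 'j': 100, 'k': 10, 'l': 100}
terms = [frozenset('ij'), frozenset('jk'), frozenset('kl')]
print(greedy_path(terms, frozenset('il'), sizes))  # [(0, 1), (0, 1)]
# A custom cost function plugs in without touching the search loop.
print(greedy_path(terms, frozenset('il'), sizes, cost_fn=flops))
```

Greedy looks only one step ahead, so it is fast but can miss the best order; branch-style searches spend more time to avoid some of those misses.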
Enhancements¶
(GH#61) The opt_einsum.contract() path keyword has been changed to optimize to more closely match NumPy. path will be deprecated in the future.
(GH#61) opt_einsum.contract_path() now returns an opt_einsum.contract.PathInfo() object that can be queried for the scaling, flops, and intermediates of the path. The print representation of this object is identical to before.
(GH#61) The default memory_limit is now unlimited, based on community feedback.
(GH#66) The Torch backend will now use tensordot when using a version of Torch which includes this functionality.
(GH#68) Indices can now be any hashable object when provided in the “Interleaved Input” syntax.
(GH#74) Allows the default transpose operation to be overridden to take advantage of more advanced tensor transpose libraries.
(GH#73) The optimal path is now significantly faster.
(GH#81) A documentation pass for v3.0.
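The “Interleaved Input” syntax passes each operand followed by a list of index labels, optionally ending with the output labels, as in contract(a, [0, 1], b, [1, 2], [0, 2]). One way to support arbitrary hashable labels is to map each distinct label to an einsum letter, which is what the sketch below does. to_subscripts is a hypothetical helper, it only covers 52 labels (a-z, A-Z), and it is not how opt_einsum performs the conversion.

```python
from typing import Any, Dict, Hashable, List, Tuple

LETTERS = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'


def to_subscripts(*args: Any) -> Tuple[str, List[Any]]:
    """Convert (op0, labels0, op1, labels1, ..., [output]) to einsum form.

    Each distinct hashable label gets the next unused letter in order of
    first appearance. An odd number of arguments means the last one is
    the list of output labels.
    """
    symbols: Dict[Hashable, str] = {}

    def sym(label: Hashable) -> str:
        if label not in symbols:
            symbols[label] = LETTERS[len(symbols)]  # IndexError past 52 labels
        return symbols[label]

    has_output = len(args) % 2 == 1
    pairs = args[:-1] if has_output else args
    operands = list(pairs[0::2])
    subscripts = ','.join(''.join(sym(ix) for ix in labels)
                          for labels in pairs[1::2])
    if has_output:
        subscripts += '->' + ''.join(sym(ix) for ix in args[-1])
    return subscripts, operands


# Labels can be strings, tuples, or any other hashable objects;
# the strings 'A' and 'B' stand in for real arrays.
expr, ops = to_subscripts('A', ['bond', ('site', 1)],
                          'B', [('site', 1), 'phys'],
                          ['bond', 'phys'])
print(expr, ops)  # ab,bc->ac ['A', 'B']
```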
Bug fixes¶
(GH#72) Fixes the “Interleaved Input” syntax and adds documentation.
2.1.3 / 2018-8-23¶
Bug fixes¶
Fixes unicode issue for large numbers of tensors in Python 2.7.
Fixes unicode install bug in README.md.
2.1.0 / 2018-8-15¶
opt_einsum continues to improve its support for additional backends beyond NumPy with PyTorch.
We have also published the opt_einsum package in the Journal of Open Source Software. If you use this package in your work, please consider citing us!
New features¶
PyTorch backend support
Tensorflow eager-mode execution backend support
Enhancements¶
Intermediate tensordot-like expressions are now ordered to avoid transposes.
CI now uses conda backend to better support GPU and tensor libraries.
Now accepts arbitrary unicode indices rather than a subset.
New auto path option which switches between optimal and greedy at four tensors.
Bug fixes¶
Fixed issue where broadcast indices were incorrectly locked out of tensordot-like evaluations even after their dimension was broadcast.
2.0.1 / 2018-6-28¶
New Features¶
Allows unlimited Unicode indices.
Adds a Journal of Open-Source Software paper.
Minor documentation improvements.
2.0.0 / 2018-5-17¶
opt_einsum is a powerful tensor contraction order optimizer for NumPy and related ecosystems.
New Features¶
Expressions can be precompiled so that the expression optimization need not happen multiple times.
The greedy order optimization algorithm has been tuned to be able to handle hundreds of tensors in several seconds.
Input indices can now be unicode so that expressions can have many thousands of indices.
GPU and distributed computing backends such as Dask, TensorFlow, CuPy, Theano, and Sparse have been added.
Bug Fixes¶
An error affecting cases where opt_einsum mistook broadcasting operations for matrix multiply has been fixed.
Most error messages are now more expressive.
1.0.0 / 2016-10-14¶
Einsum is a very powerful function for contracting tensors of arbitrary dimension and index. However, it is only optimized to contract two terms at a time, resulting in non-optimal scaling for contractions with many terms. Opt_einsum aims to fix this by optimizing the contraction order, which can lead to arbitrarily large speedups at the cost of additional intermediate tensors.
Opt_einsum has also been incorporated into the np.einsum function as of NumPy v1.12.
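A quick count shows why evaluating many terms in one loop nest scales badly. For the matrix chain 'ij,jk,kl->il' with every dimension equal to N, a single loop over all four indices performs about N^4 multiply-adds, while contracting two terms at a time performs about 2N^3. The snippet below uses this simplified iteration count as its cost model, an assumption that ignores memory traffic and BLAS efficiency.

```python
from math import prod

N = 100
sizes = {'i': N, 'j': N, 'k': N, 'l': N}

# One loop nest over i, j, k, l for 'ij,jk,kl->il': N**4 iterations.
single_loop = prod(sizes[c] for c in 'ijkl')

# Pairwise: 'ij,jk->ik' (loops over i, j, k), then 'ik,kl->il' (i, k, l).
pairwise = prod(sizes[c] for c in 'ijk') + prod(sizes[c] for c in 'ikl')

print(single_loop, pairwise, single_loop // pairwise)  # 100000000 2000000 50
```

The gap widens as N grows, since the ratio is N/2.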
New Features¶
Tensor contraction order optimizer.
opt_einsum.contract() as a drop-in replacement for numpy.einsum().