pockit.base.fastfunc
JITed, vectorized functions with automatic differentiation to compute value, gradient, and hessian.
FastFunc takes a Sympy expression and a list of Sympy symbols (the function arguments) and generates JITed, vectorized functions for the value, gradient, and hessian. The gradient and hessian are generated by automatic differentiation and stored in sparse format, i.e., only non-zero elements are computed.
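To show which entries a sparse gradient and a lower-triangular hessian keep, the sketch below finds the structurally non-zero derivatives of a small expression. It is only an illustration: pockit generates derivatives by automatic differentiation, while this sketch uses plain sympy.diff (symbolic differentiation). The helper name sparsity_pattern is hypothetical and is not part of pockit.

```python
import sympy as sp


def sparsity_pattern(expr, args):
    """Return (G_index, H_rows, H_cols) for expr: indices of structurally
    non-zero gradient entries and of non-zero lower-triangular hessian entries."""
    grad = [sp.diff(expr, a) for a in args]
    G_index = [i for i, g in enumerate(grad) if g != 0]
    H_rows, H_cols = [], []
    for i in range(len(args)):
        for j in range(i + 1):  # lower triangle, diagonal included
            if sp.diff(grad[i], args[j]) != 0:
                H_rows.append(i)
                H_cols.append(j)
    return G_index, H_rows, H_cols


a1, a2, a3 = sp.symbols("a_1 a_2 a_3")
f = a1**2 * a2 + sp.sin(a1)  # a_3 does not appear, so its entries are dropped
print(sparsity_pattern(f, [a1, a2, a3]))  # ([0, 1], [0, 1], [0, 0])
```

Only 2 of the 3 gradient entries and 2 of the 6 lower-triangular hessian entries survive. For functions with many arguments this is where the sparse format saves work.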
Suppose the input function is f(a_1, a_2, ..., a_n), with its n arguments (a_1, a_2, ..., a_n) passed as the second argument, args. The generated functions F, G, and H take two arguments, x and k, where k is an integer (the number of points) and x is a 1D array of length n * k.
The first k elements of x are the values of a_1 at k different points, the next k elements are the values of a_2, and so on. The return value of F is a 1D array of length k, whose i-th element is the value of f(a_1, a_2, ..., a_n) at the i-th point. The return value of G is a 2D array of shape (len(G_index), k), where G_index contains the indices of the non-zero elements of the gradient.
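The input layout and the gradient format can be illustrated with plain NumPy. The sketch assumes an argument-major layout: x holds the k values of a_1, then the k values of a_2, and so on, for a total length of n * k. It does not call pockit. F_ref and G_out are hand-written stand-ins for what F(x, k) and G(x, k) are documented to return for f = a_1**2 * a_2 + sin(a_1). Scattering G_out into a dense array is one way a caller might use G_index; FastFunc does not do this itself.

```python
import numpy as np

n, k = 3, 4                      # 3 arguments, 4 evaluation points
a1 = np.array([1.0, 2.0, 3.0, 4.0])
a2 = np.array([0.5, 0.5, 0.5, 0.5])
a3 = np.array([9.0, 9.0, 9.0, 9.0])

# Pack argument-major: all k values of a_1, then all k values of a_2, ...
x = np.concatenate([a1, a2, a3])

# An (n, k) view: row i holds the values of the (i+1)-th argument at the k points
X = x.reshape(n, k)

# Reference value of f = a_1**2 * a_2 + sin(a_1) at each point (length k)
F_ref = X[0] ** 2 * X[1] + np.sin(X[0])

# Stand-in for G(x, k): only the non-zero gradient rows, here G_index = [0, 1]
G_index = np.array([0, 1])
G_out = np.stack([2 * X[0] * X[1] + np.cos(X[0]),   # df/da_1
                  X[0] ** 2])                        # df/da_2

# Scatter into a dense (n, k) gradient; the row for a_3 stays zero
grad = np.zeros((n, k))
grad[G_index] = G_out
```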
The return value of H is a 2D array of shape (len(H_index_row), k), where H_index_row and H_index_col contain the row and column indices of the non-zero elements in the lower triangular part of the hessian.
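Because only the lower triangle is returned, a caller that needs the full symmetric hessian has to mirror the entries. The sketch below does this with NumPy. The helper dense_hessian and the H_out values (hand-written for f = a_1**2 * a_2 + sin(a_1) with n = 3) are hypothetical and do not come from pockit.

```python
import numpy as np


def dense_hessian(H_out, H_index_row, H_index_col, n):
    """Expand lower-triangular sparse hessian values of shape
    (len(H_index_row), k) into a dense symmetric array of shape (n, n, k)."""
    k = H_out.shape[1]
    H = np.zeros((n, n, k))
    for v, i, j in zip(H_out, H_index_row, H_index_col):
        H[i, j] = v
        H[j, i] = v  # mirror into the upper triangle (no-op on the diagonal)
    return H


# Hand-written stand-in for H(x, k) at k = 2 points
a1 = np.array([1.0, 2.0])
a2 = np.array([0.5, 0.5])
H_index_row, H_index_col = [0, 1], [0, 0]
H_out = np.stack([2 * a2 - np.sin(a1),   # d2f/da_1^2
                  2 * a1])               # d2f/(da_2 da_1)
H = dense_hessian(H_out, H_index_row, H_index_col, n=3)
```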
If simplify is True, every symbolic expression is simplified with sympy.simplify() before being compiled. This slows down compilation.
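The trade-off behind the simplify flag is extra symbolic work in exchange for a possibly smaller expression. The minimal illustration below uses plain SymPy and says nothing about pockit internals. It shows sympy.simplify reducing an expression that compiles to more operations than necessary.

```python
import sympy as sp

a1, a2 = sp.symbols("a_1 a_2")
expr = sp.sin(a1)**2 + sp.cos(a1)**2 + a2 * (a1 + 1) - a2 * a1

simplified = sp.simplify(expr)  # the slow step that simplify=True adds
print(simplified)
print(sp.count_ops(expr), "->", sp.count_ops(simplified))
```

Whether the saved operations outweigh the extra compile time depends on the expression. For expressions that are already compact, simplify mostly adds cost.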
If parallel is True, the parallel flag is passed to the Numba JIT compiler, which then generates parallel code for multicore CPUs. This slows down compilation and can sometimes slow down execution as well.
If fastmath is True, the fastmath flag is passed to the Numba JIT compiler. See the Numba and LLVM documentation for details.
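In Numba, these two options correspond to the parallel and fastmath keyword arguments of numba.njit. A loop over points can then use numba.prange. The standalone kernel below shows what such flags apply to. It is a hypothetical example (f = a_1**2 + ... + a_n**2 on the argument-major layout described above), not code generated by FastFunc. If Numba is not installed, it falls back to an ordinary Python function so the sketch still runs.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback without Numba: no JIT compilation, plain serial loop
    prange = range

    def njit(**flags):
        return lambda func: func


@njit(parallel=True, fastmath=True)
def sum_of_squares(x, k):
    # x holds n * k values, argument-major; returns one value per point
    n = x.size // k
    out = np.empty(k)
    for i in prange(k):  # with parallel=True, iterations may run on several cores
        acc = 0.0
        for j in range(n):
            acc += x[j * k + i] ** 2
        out[i] = acc
    return out
```

For small k, thread start-up and the longer compilation can outweigh any gain from parallel=True. This matches the caveat above.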
If cache is a path to a file, the FastFunc object does the following:
1. If the file does not exist, the generated functions are written to the file so they can be loaded later.
2. If the file exists and begins with a hash string (auto-generated by FastFunc), that hash is compared with the hash of the current function to decide whether to load the file directly or overwrite it.
3. If the file exists and does not begin with a hash string, it is treated as a user-provided file and loaded directly.
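The three cache cases can be sketched in plain Python as follows. The documentation does not specify FastFunc's file format or hash function, so several parts of this sketch are assumptions: the header line format (HASH_PREFIX), the use of SHA-256, and treating the cached content as a source string. The helper load_or_write is hypothetical.

```python
import hashlib
import os
import tempfile

HASH_PREFIX = "# hash: "  # hypothetical header format, not pockit's actual one


def _write(path, digest, source):
    with open(path, "w") as fh:
        fh.write(HASH_PREFIX + digest + "\n" + source)


def load_or_write(path, source):
    """Return (content, action) following the three cache cases."""
    digest = hashlib.sha256(source.encode()).hexdigest()
    if not os.path.exists(path):          # case 1: no file yet, write it
        _write(path, digest, source)
        return source, "written"
    with open(path) as fh:
        content = fh.read()
    first, _, rest = content.partition("\n")
    if first.startswith(HASH_PREFIX):     # case 2: auto-generated file
        if first[len(HASH_PREFIX):] == digest:
            return rest, "loaded"         # hash matches, reuse the file
        _write(path, digest, source)      # stale hash, overwrite
        return source, "overwritten"
    return content, "user"                # case 3: user-provided, load as-is


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "f_cache.py")
    print(load_or_write(path, "A = 1\n")[1])  # first call writes the file
    print(load_or_write(path, "A = 1\n")[1])  # same hash, so the file is loaded
```

Case 3 means a hand-edited cache file without the header is never overwritten. This lets users supply their own implementation.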
Arguments:
- function: sympy.Expr of the function.
- args: List of sympy.Symbol objects for the function arguments.
- simplify: Whether to use sympy.simplify() to simplify expressions before compilation.
- parallel: Whether to use Numba parallel mode.
- fastmath: Whether to use Numba fastmath mode.
- cache: Path to a file to cache the generated functions.
Attributes:
- F: Vectorized function to compute the value.
- G: Vectorized function to compute the gradient.
- H: Vectorized function to compute the hessian.
- G_index: Indices of non-zero elements of the gradient.