Nikola Stoyanov

Understanding Adjoint Optimization

Writing up my PhD is a pain, mainly because of the detailed explanations that need to be provided. A technique I used throughout my project is solving parameter optimization problems with the adjoint method. I learned the technique and used it, but never really got it until reading an amazing paper. This post represents my notes and Julia implementation of the work detailed in https://math.mit.edu/~stevenj/18.336/adjoint.pdf by Steven G. Johnson, which is by far the best description of adjoint methods I have read. I am really grateful that I typed these notes up some time ago, which makes my writing life easier... I hope they help someone implement similar problems in Julia.

Problem

This work implements the solution to the optimization problem for the Schrödinger eigenequation described in https://math.mit.edu/~stevenj/18.336/adjoint.pdf by Steven G. Johnson.

Boilerplate

Let's set up the Julia environment. I will use Optim.jl to solve the optimization problem and IterativeSolvers.jl, which provides the conjugate gradient method for solving the matrix equations.

In [1]:
using Pkg; Pkg.activate(".")
Out[1]:
Activating environment at `~/git/Personal_Website/content/post/2020-06-14-adjoints/Project.toml`

In [2]:
using Optim, Plots, LineSearches, SparseArrays, LinearAlgebra, IterativeSolvers

The Adjoint Method

I recommend you read Steven G. Johnson's paper. The descriptions here are taken from that paper and summarized.

Given a set of \(M\) equations parametrized by \(P\) parameters, we optimize a function \(g(x, p)\) of the solution \(x\). For gradient-based optimization we need the gradient \(\frac{dg}{dp}\). The adjoint method rearranges the computation so that its cost does not depend on the number of parameters \(P\), which is great when the parameter space is large.

The rearrangement works as follows. Given \(Ax=b\) with \(A\) and \(b\) parametrized by \(p\):

\[ \begin{equation} \frac{dg}{dp} = g_p + g_x x_p \end{equation} \]

Here \(x\) is a vector of length \(M\), \(g_p\) is a Jacobian of size \(1 \times P\), \(g_x\) a Jacobian of size \(1 \times M\), and \(x_p\) a Jacobian of size \(M \times P\). The difficulty comes from plugging in \(x_p\).

To get \(x_p\) we differentiate \(Ax=b\) with respect to each parameter \(p_i\) to get:

\[ \begin{equation} x_{p_i} = A^{-1} \left( b_{p_i} - A_{p_i} x \right) \end{equation} \]

Plugging Equation 2 in Equation 1 we get:

\[ \begin{equation} g_x x_p = g_x A^{-1} (b_p - A_p x) = [g_x A^{-1}] (b_p - A_p x) \end{equation} \]

With the terms \(g_x A^{-1}\) of size \(1 \times M\) and \(b_p - A_p x\) of size \(M \times P\). The difficulty is that evaluating this right-to-left means a solve for every parameter, for a computational cost of \(\mathcal{O}(M^2 P)\).

To make the problem cheaper we rearrange the order of operations: define \(\lambda^T = g_x A^{-1}\), i.e. solve the single adjoint system \(A^T \lambda = g_x^T\), then substitute Equation 2 into Equation 1 to get:

\[ \begin{equation} \frac{dg}{dp} = g_p - \lambda^T (A_p x - b_p) \end{equation} \]

Where the remaining multiplications now cost only \(\Theta(MP)\), plus the single extra solve for \(\lambda\).
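To make Equation 4 concrete, here is a minimal sketch for a small dense system, checking the adjoint gradient against finite differences. The diagonal parametrization \(A(p) = A_0 + \mathrm{Diagonal}(p)\) and the linear cost \(g(x) = c^T x\) are illustrative assumptions of mine, not from the paper.

using LinearAlgebra

# Sketch: adjoint gradient of g(x) = c'x subject to (A0 + Diagonal(p)) x = b.
# With this parametrization A_{p_i} = e_i e_i' and b_p = 0, so Equation 4
# reduces to dg/dp_i = -λ_i x_i.
M = 5
A0 = Matrix(4.0I, M, M) + 0.1 * randn(M, M)
b, c, p0 = ones(M), randn(M), rand(M)

g(p) = dot(c, (A0 + Diagonal(p)) \ b)

A = A0 + Diagonal(p0)
x = A \ b
λ = A' \ c            # λ' = g_x A⁻¹: one extra solve, A'λ = g_x'
grad_adj = -λ .* x    # Equation 4 with g_p = 0 and b_p = 0

# Finite-difference check on every component.
h = 1e-6
grad_fd = [(g(p0 + h * (1:M .== i)) - g(p0)) / h for i in 1:M]
maximum(abs.(grad_adj - grad_fd))   # ~1e-6 or smaller

The single solve for \(\lambda\) replaces the \(P\) solves that evaluating \(x_p\) directly would need.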

Eigenequation

The equation for the problem in the paper is an eigenequation, so we also need to follow the section that derives this case. Equations of this type are given by:

\[ \begin{equation} Ax = \alpha x \end{equation} \]

In block form we get:

\[ \begin{equation} \hat{x} = \begin{bmatrix} x \\ \alpha \end{bmatrix} \end{equation} \]

And for \(f = Ax - \alpha x\), augmented with the normalization constraint \(x^T x = 1\) so that the system determines both \(x\) and \(\alpha\):

\[ \begin{equation} \hat{f} = \begin{bmatrix} f \\ x^T x - 1 \end{bmatrix} \end{equation} \]

Using the same substitution for \(\lambda\) we get:

\[ \begin{equation} \hat{\lambda} = \begin{bmatrix} \lambda \\ \beta \end{bmatrix} \end{equation} \]

Putting this together:

\[ \begin{equation} (A - \alpha I) \lambda = g_x^T - 2\beta x \end{equation} \]

\[ \begin{equation} -x^T \lambda = g_\alpha \end{equation} \]
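One step that is easy to miss: \((A - \alpha I)\) is singular, with null space spanned by \(x\), so Equation 9 is only solvable when its right-hand side is orthogonal to \(x\). With \(x^T x = 1\) that condition fixes \(\beta\):

\[ x^T \left( g_x^T - 2\beta x \right) = 0 \quad \Rightarrow \quad \beta = \tfrac{1}{2} g_x x \]

In the implementation below this shows up as projecting the right-hand side (and the solution) onto the complement of \(x\), rather than computing \(\beta\) explicitly.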

Schrödinger eigenequation

The Schrödinger eigenequation is given by:

\[ \begin{equation} \left(-\frac{d^2}{dx^2} + V(x)\right) \psi(x) = E\psi(x) \end{equation} \]

With periodic boundary conditions \(\psi(x+2) = \psi(x)\) over the domain \(x \in [-1, 1]\).

In the direct (forward) problem we are given \(V(x)\) and solve for \(\psi\) and \(E\); in parameter optimization (the inverse problem) we look for the \(V(x)\) that minimizes some cost. In this case the cost is:

\[ \begin{equation} g = \int_{-1}^{1} |\psi(x) - \psi_0 (x)|^2 dx \end{equation} \]
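On the grid used below this integral becomes a plain sum (my notation, not the paper's):

\[ g \approx \sum_{i=1}^{N} \left( \psi_i - \psi_{0,i} \right)^2 \, \Delta x \]

The forward routine later drops the constant \(\Delta x\) factor; that only rescales the cost and does not move the minimizer.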

To solve the optimization problem we will: (1) precalculate target data \(\psi_0\) from a chosen function, (2) solve the eigenproblem for the current \(V\) and evaluate the cost, (3) compute the gradient of the cost with the adjoint method, and (4) feed both to an optimizer and iterate. In other cases the precalculated target data can come from experimental measurements, recorded data, etc.

Implementation

We set up the domain and the finite-difference discretisation.

In [3]:
const m = 0.02
const x = [i for i in -1:m:1]
const N = length(x)
const dx = x[2] - x[1]
Out[3]:
0.020000000000000018

We pick a function to generate the target data. This will later be used to compute the cost function; in a real inverse problem it is not known, i.e. it is the thing we are trying to back-calculate. In my case such data comes from physical tests using thermocouples (very different from the example here, but solvable with the same approach).

In [4]:
# Target solution, normalize and pick sign.
Ψ0 = 1.0 .+ sin.(π .* x .+ cos.(3 * π .* x))
Ψ0 = Ψ0 / sqrt(transpose(Ψ0) * Ψ0)
if(sum(Ψ0) < 0)
    Ψ0 = -Ψ0
end

In real problems we might be able to encode prior knowledge about the solution, but here we just kick-start from an initial guess of \(0\).

In [5]:
const V0 = zeros(N);

Discretise the space using a centred-difference scheme and set the periodic boundary conditions.

In [6]:
Mesh = spdiagm(-1 => [1.0 for i in 1:N-1],
               0 => [2.0 for i in 1:N],
               1 => [1.0 for i in 1:N-1])

# Periodic boundary conditions.
Mesh[1, N] = 1.0
Mesh[N, 1] = 1.0

# Build mesh.
Mesh = -Mesh / dx^2;
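A detail worth flagging (my observation, not from the paper): negating the whole \((1, 2, 1)\) stencil gives

\[ -\frac{\psi_{i-1} + 2\psi_i + \psi_{i+1}}{dx^2} = \frac{-\psi_{i-1} + 2\psi_i - \psi_{i+1}}{dx^2} - \frac{4}{dx^2}\,\psi_i \]

i.e. the standard centred stencil for \(-d^2/dx^2\) minus a constant diagonal shift. The shift moves every eigenvalue by the same amount but leaves the eigenvectors, and therefore \(\psi\) and the cost, unchanged.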

Next we need to set up the optimization problem. For this we pass the parameters defined above and create a data structure to store the current calculation state.

In [7]:
# Build constants.
const p = [N, dx, Ψ0, Mesh]

mutable struct Schrodinger{T1, T2, T3}
    A::T1
    E::T2
    Ψ::T3
end

schr = Schrodinger(Mesh, 0.0, zeros(N));

The finite-difference discretisation with an initial guess lets us solve the eigenproblem, here with a dense eigensolver, and compute the cost function as the least-squares error between the generated target data and the prediction from the current solution state.

In [8]:
function schrodinger_fd(V, schr, p)
    N, _, Ψ0, Mesh = p

    A = Mesh + Diagonal(V)

    # Smallest eigenvalue and its eigenvector (eigen sorts ascending).
    F = eigen(Matrix(A))
    E = F.values[1]
    Ψ = F.vectors[:, 1]

    # Pick sign.
    if sum(Ψ) < 0
        Ψ = -Ψ
    end

    schr.A = A
    schr.E = E
    schr.Ψ = Ψ

    # Least-squares error against the target.
    err = 0.0
    for i in 1:N
        err += (Ψ[i] - Ψ0[i])^2
    end

    return err
end
Out[8]:
schrodinger_fd (generic function with 1 method)

Finally, we add the derived equations for the adjoint and pass the solution state to calculate the gradient. The adjoint system is solved with the conjugate gradient method. Since \((A - EI)\) is singular along \(\Psi\), both the right-hand side and the solution are projected onto the complement of \(\Psi\); the projection plays the role of the \(\beta\) term in Equation 9.

In [9]:
function schrodinger_fd_adj(gp, V, schr, p)
    N, dx, Ψ0, _ = p

    A = schr.A
    E = schr.E
    Ψ = schr.Ψ

    # g_Ψ^T for the dx-weighted least-squares cost.
    gΨ = Ψ - Ψ0
    g = transpose(gΨ) * gΨ * dx
    gΨ = gΨ * 2 * dx

    # (A - EI) is singular along Ψ: project the right-hand side and the
    # solution onto the complement of Ψ before and after the CG solve.
    P(Ψx) = Ψx - Ψ * (transpose(Ψ) * Ψx)
    λ = cg(A - Diagonal([E for i in 1:N]), P(gΨ))
    λ = P(λ)

    # dg/dV_i = -λ_i Ψ_i, since A_{V_i} = e_i e_i^T.
    copyto!(gp, -real(conj(λ) .* Ψ))
end
Out[9]:
schrodinger_fd_adj (generic function with 1 method)

We can now fill in the initial state with the guess for \(V_0\).

In [10]:
schrodinger_fd(V0, schr, p)
Out[10]:
0.36358289201172456
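Before optimizing, it is worth sanity-checking the adjoint gradient against finite differences; the snippet below is a quick sketch of mine, not from the original post. Note that schrodinger_fd returns the unweighted sum of squares while schrodinger_fd_adj differentiates the \(dx\)-weighted cost, so the adjoint gradient should match \(dx\) times the finite difference.

# Compare the adjoint gradient with central differences on a few components.
gp = zeros(N)
schrodinger_fd_adj(gp, V0, schr, p)   # uses the state just computed at V0

h = 1e-6
for i in (1, 25, 50)                  # a few arbitrary components
    Vp = copy(V0); Vp[i] += h
    Vm = copy(V0); Vm[i] -= h
    fd = (schrodinger_fd(Vp, schr, p) - schrodinger_fd(Vm, schr, p)) / (2h)
    println(gp[i], "  vs  ", dx * fd)
end
schrodinger_fd(V0, schr, p);          # restore the cached state at V0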

With the initial state filled in, we can pass everything to Optim.jl and optimize using the conjugate gradient method with the strong Wolfe line search of Nocedal and Wright.

In [11]:
res = optimize(V0 -> schrodinger_fd(V0, schr, p),
              (gp, V) -> schrodinger_fd_adj(gp, V, schr, p),
              V0,
              ConjugateGradient(;alphaguess = LineSearches.InitialStatic(),
                                linesearch = LineSearches.StrongWolfe()),
              Optim.Options(iterations = 1500))

show(res)
V = Optim.minimizer(res)
Out[11]:
 * Status: failure (reached maximum number of iterations)

 * Candidate solution
    Minimizer: [-6.97e+01, -6.31e+01, -5.21e+01,  ...]
    Minimum:   1.073967e-02

 * Found with
    Algorithm:     Conjugate Gradient
    Initial Point: [0.00e+00, 0.00e+00, 0.00e+00,  ...]

 * Convergence measures
    |x - x'|               = 2.25e-02 ≰ 0.0e+00
    |x - x'|/|x'|          = 2.76e-04 ≰ 0.0e+00
    |f(x) - f(x')|         = 5.17e-06 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 4.81e-04 ≰ 0.0e+00
    |g(x)|                 = 2.50e-07 ≰ 1.0e-08

 * Work counters
    Seconds run:   205  (vs limit Inf)
    Iterations:    1500
    f(x) calls:    25492
    ∇f(x) calls:   25492

Out[11]:
101-element Array{Float64,1}:
 -69.7179927132526  
 -63.069608186066425
 -52.087451574006884
 -37.36573473180763 
 -20.180400235224578
  -2.229225733813316
  14.749615511053337
  29.331853859440123
  40.62554608422707 
  48.3225249611126  
  52.61571536510694 
  54.042947983512384
  53.31503924672539 
   ⋮                
  17.178647984621353
   8.021381994966156
  -2.573762022324184
 -14.023136038593602
 -25.664072903254834
 -36.846554029251884
 -47.03051365926664 
 -55.83742898698404 
 -63.019981543247575
 -68.35860212436283 
 -71.54666184631233 
 -72.15224515470479 

The run stopped at the iteration limit rather than meeting the gradient tolerance, but the cost is already small, so we recover the final state.

In [12]:
# Calculate the Ψ for the optimized V.
schrodinger_fd(V, schr, p)
Ψ = schr.Ψ
Out[12]:
101-element Array{Float64,1}:
 0.14833164726134337 
 0.13919394825641024 
 0.1278699952817301  
 0.11509936367379768 
 0.10170432021101757 
 0.0884566618539783  
 0.07597235110988376 
 0.06465962208332779 
 0.05472117218785397 
 0.04619297069858297 
 0.03899745133110734 
 0.032993990421552606
 0.028017912623404584
 ⋮                   
 0.09614588450255365 
 0.10240636155599434 
 0.10997045784342774 
 0.11846840215967322 
 0.1274298017992534  
 0.13629635337111087 
 0.1444518047974951  
 0.15126517003159562 
 0.15614027501154157 
 0.1585660575136774  
 0.15816585545769815 
 0.1547451044185629  

Finally, we create a plot and see how well we did.

In [13]:
scatter(x, Ψ, label = "\\Psi_i",
        markersize = 3,
        markerstrokecolor = :blue,
        markercolor = :white)

plot!(x, Ψ0, label = "\\Psi_0",
      color = :red)

plot!(x, V / 1000, label = "V/1000",
      linestyle = :dash,
      color = :black)

plot!(xlims = (-1, 1),
      xticks = -1:0.2:1,
      ylims = (-0.15, 0.2),
      yticks = -0.15:0.05:0.2,
      grid = false,
      legend = :bottomleft,
      fmt = :svg,
      framestyle = :box)
Out[13]:

[Figure: the optimized \(\Psi_i\) (scatter) against the target \(\Psi_0\) (red line), with the optimized potential plotted as \(V/1000\) (dashed).]

That's pretty good, considering it took only two to three minutes on my ThinkPad laptop to optimize a 101-dimensional case, and the function is not trivial!