<p>A significant portion of processes can be described by differential equations: be it the evolution of physical systems, a patient’s medical condition, fundamental properties of markets, and so on. Such data is sequential and continuous in its nature, meaning that observations are merely realizations of some continuously changing state.</p>
<p>There is also another type of sequential data that is discrete – NLP data, for example: its state changes discretely, from one symbol to another, from one word to another.</p>
<p>Today both these types are normally processed using recurrent neural networks. They are, however, essentially different in their nature, and it seems that they should be treated differently.</p>
<p>At the last NIPS conference a very interesting <a href="https://arxiv.org/abs/1806.07366">paper</a> was presented that attempts to tackle this problem. Its authors propose a very promising approach, which they call <strong>Neural Ordinary Differential Equations</strong>.</p>
<p>Here I tried to reproduce and summarize the results of the original paper, making it a little easier to familiarize yourself with the idea. I believe this new architecture may soon join convolutional and recurrent networks in the toolbox of any data scientist.</p>
<p><img src="/assets/node/backprop.png" alt="backprop" width="600px" class="align-center" /></p>
<!--more-->
<p>Imagine the following problem: there is a process governed by an unknown ODE, and there are some (noisy) observations along its trajectory:</p>
<script type="math/tex; mode=display">\frac{dz}{dt} = f(z(t), t) \tag{1}</script>
<script type="math/tex; mode=display">\{(z_0, t_0),(z_1, t_1),...,(z_M, t_M)\} - \text{observations}</script>
<p>Is it possible to find an approximation <script type="math/tex">\widehat{f}(z, t, \theta)</script> of the dynamics function <script type="math/tex">f(z, t)</script>?</p>
<p>First, consider a somewhat simpler task: there are only 2 observations, one at the beginning and one at the end of the trajectory, <script type="math/tex">(z_0, t_0), (z_1, t_1)</script>. One evolves the system from <script type="math/tex">z_0, t_0</script> for time <script type="math/tex">t_1 - t_0</script> with some parameterized dynamics function, using any ODE initial value solver. This yields some new state <script type="math/tex">\hat{z_1}, t_1</script>, which one compares with the observation <script type="math/tex">z_1</script>, trying to minimize the difference by varying the parameters <script type="math/tex">\theta</script>.</p>
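<p>Before moving on, here is a toy sketch of this naive procedure: differentiating straight through the steps of an explicit Euler solver with plain autograd. This is only my own illustration (the dynamics, loss and all names below are made up for the example); it also shows the cost that the adjoint method described later avoids, namely that every intermediate solver state has to be kept in memory for backpropagation.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Toy sketch (illustration only, not code from the paper): fit theta by
# backpropagating directly through the steps of an explicit Euler solver.
import torch

z0, t0 = torch.tensor([1.0]), 0.0    # observation at the start of the trajectory
z1, t1 = torch.tensor([2.0]), 1.0    # observation at the end of the trajectory

theta = torch.zeros(1, requires_grad=True)   # parameters of the dynamics guess
f_hat = lambda z, t, theta: theta * z        # parameterized dynamics: dz/dt = theta * z

optimizer = torch.optim.Adam([theta], lr=1e-1)
n_steps = 100
h = (t1 - t0) / n_steps

for it in range(500):
    z, t = z0, t0
    for _ in range(n_steps):         # Euler steps, all kept in the autograd graph
        z = z + h * f_hat(z, t, theta)
        t = t + h
    loss = (z - z1).pow(2).sum()     # compare the final state with the observation
    optimizer.zero_grad()
    loss.backward()                  # backprop through every solver step
    optimizer.step()

# theta converges towards log(2) ~ 0.693, since z(t1) = z(t0) * exp(theta * (t1 - t0))
</code></pre></div></div>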
<p>Or, more formally, consider optimizing the following loss function <script type="math/tex">L(\hat{z_1})</script>:</p>
<script type="math/tex; mode=display">L(z(t_1)) = L \Big( \int_{t_0}^{t_1} f(z(t), t, \theta)dt \Big) = L \big( \text{ODESolve}(z(t_0), f, t_0, t_1, \theta) \big) \tag{2}</script>
<p style="text-align: center">Figure 1: Continuous backpropagation of the gradient requires solving the augmented ODE backwards in time. <br /> Arrows represent adjusting backpropagated gradients with gradients from observations. <br />
Figure from the original paper</p>
<p>In case you don’t want to dig into the maths, the figure above illustrates what is going on. The black trajectory represents solving the ODE during forward propagation; the red arrows represent solving the adjoint ODE during backpropagation.</p>
<p>To optimize <script type="math/tex">L</script> one needs to compute its gradients wrt. all its parameters: <script type="math/tex">z(t_0), t_0, t_1, \theta</script>. To do this, let us first determine how the loss depends on the state at every moment of time <script type="math/tex">z(t)</script>:</p>
<script type="math/tex; mode=display">a(t) = -\frac{\partial L}{\partial z(t)} \tag{3}</script>
<p><script type="math/tex">a(t)</script> is called the <em>adjoint</em>; its dynamics is given by another ODE, which can be thought of as an instantaneous analog of the chain rule:</p>
<script type="math/tex; mode=display">\frac{d a(t)}{d t} = -a(t) \frac{\partial f(z(t), t, \theta)}{\partial z} \tag{4}</script>
<p>The actual derivation of this formula can be found in the appendix of the original paper.</p>
<p>All vectors here are considered row vectors, whereas the original paper uses both column and row representations.</p>
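<p>For intuition, here is an informal sketch of where formula (4) comes from (not a substitute for the derivation in the appendix). Over a small step <script type="math/tex">\varepsilon</script> the chain rule gives <script type="math/tex">a(t) = a(t + \varepsilon) \frac{\partial z(t + \varepsilon)}{\partial z(t)}</script>, and since <script type="math/tex">z(t + \varepsilon) = z(t) + \varepsilon f(z(t), t, \theta) + O(\varepsilon^2)</script>, one gets</p>
<script type="math/tex; mode=display">\frac{d a(t)}{d t} = \lim_{\varepsilon \to 0^+} \frac{a(t + \varepsilon) - a(t)}{\varepsilon} = \lim_{\varepsilon \to 0^+} \frac{a(t + \varepsilon) - a(t + \varepsilon) \big( I + \varepsilon \frac{\partial f(z(t), t, \theta)}{\partial z} + O(\varepsilon^2) \big)}{\varepsilon} = -a(t) \frac{\partial f(z(t), t, \theta)}{\partial z}</script>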
<p>One can then compute</p>
<script type="math/tex; mode=display">\frac{\partial L}{\partial z(t_0)} = \int_{t_1}^{t_0} a(t) \frac{\partial f(z(t), t, \theta)}{\partial z} dt \tag{5}</script>
<p>To compute the gradients wrt. <script type="math/tex">t</script> and <script type="math/tex">\theta</script>, one can think of them as part of an augmented state:</p>
<script type="math/tex; mode=display">\frac{d}{dt} \begin{bmatrix} z \\ \theta \\ t \end{bmatrix} (t) = f_{\text{aug}}([z, \theta, t]) := \begin{bmatrix} f([z, \theta, t ]) \\ 0 \\ 1 \end{bmatrix} \tag{6}</script>
<p>The adjoint state of this augmented state is then</p>
<script type="math/tex; mode=display">a_{\text{aug}} := \begin{bmatrix} a \\ a_{\theta} \\ a_t \end{bmatrix}, a_{\theta}(t) := \frac{\partial L}{\partial \theta(t)}, a_t(t) := \frac{\partial L}{\partial t(t)} \tag{7}</script>
<p>The gradient of the augmented dynamics is</p>
<script type="math/tex; mode=display">% <![CDATA[
\frac{\partial f_{\text{aug}}}{\partial [z, \theta, t]} = \begin{bmatrix}
\frac{\partial f}{\partial z} & \frac{\partial f}{\partial \theta} & \frac{\partial f}{\partial t} \\
0 & 0 & 0 \\
0 & 0 & 0
\end{bmatrix} \tag{8} %]]></script>
<p>The adjoint state ODE from formula (4) is then</p>
<script type="math/tex; mode=display">% <![CDATA[
\frac{d a_{\text{aug}}}{dt} = - \begin{bmatrix} a\frac{\partial f}{\partial z} & a\frac{\partial f}{\partial \theta} & a\frac{\partial f}{\partial t}\end{bmatrix} \tag{9} %]]></script>
<p>By solving this augmented adjoint ODE initial value problem backwards in time, one gets</p>
<script type="math/tex; mode=display">\frac{\partial L}{\partial z(t_0)} = \int_{t_1}^{t_0} a(t) \frac{\partial f(z(t), t, \theta)}{\partial z} dt \tag{10}</script>
<script type="math/tex; mode=display">\frac{\partial L}{\partial \theta} = \int_{t_1}^{t_0} a(t) \frac{\partial f(z(t), t, \theta)}{\partial \theta} dt \tag{11}</script>
<script type="math/tex; mode=display">\frac{\partial L}{\partial t_0} = \int_{t_1}^{t_0} a(t) \frac{\partial f(z(t), t, \theta)}{\partial t} dt \tag{12}</script>
<p>which, together with</p>
<script type="math/tex; mode=display">\frac{\partial L}{\partial t_1} = -a(t_1) f(z(t_1), t_1, \theta) \tag{13}</script>
<p>completes the set of gradients wrt. all the inputs of ODESolve.</p>
<p>The gradients (10), (11), (12) and (13) can all be calculated in a single call of ODESolve with the augmented state dynamics (9).</p>
<p><img src="/assets/node/pseudocode.png" alt="pseudocode" width="800px" class="align-center" /></p>
<div align="center">Figure from the original paper</div>
<p>The algorithm above describes the backpropagation of gradients for the ODE initial value problem. It lies at the heart of Neural ODEs.</p>
<p>In case there are many observations along the trajectory, one solves the augmented adjoint ODE backwards between consecutive observations, adjusting the backpropagated gradients with the direct gradients at the observation times, as shown above in <em>figure 1</em>.</p>
<h1 id="implementation">Implementation</h1>
<p>The code below is my own implementation of <strong>Neural ODE</strong>, written solely for a better understanding of what’s going on. It is, however, very close to what is actually implemented in the authors’ <a href="https://github.com/rtqichen/torchdiffeq">repository</a>. This notebook collects all the code necessary for understanding in one place and is somewhat more heavily commented. For actual usage and experiments I suggest using the authors’ original implementation.</p>
<p>Below is the code if you are interested.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import math
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm_notebook as tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.color_palette("bright")
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable

use_cuda = torch.cuda.is_available()
</code></pre></div></div>
<p>First, implement an ordinary differential equation initial value solver. For the sake of simplicity it’ll be the simplest Euler solver here; however, any explicit or implicit method would do.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def ode_solve(z0, t0, t1, f):
    """
    Simplest Euler ODE initial value solver
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        z = z + h * f(z, t)
        t = t + h
    return z
</code></pre></div></div>
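<p>For reference, a classic fourth-order Runge-Kutta step can be dropped in with exactly the same interface (this variant is my own addition and is not used in the rest of the post, which sticks to the Euler solver above):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def rk4_ode_solve(z0, t0, t1, f):
    """
    Classic RK4 ODE initial value solver with the same interface as ode_solve
    """
    h_max = 0.05
    n_steps = math.ceil((abs(t1 - t0)/h_max).max().item())

    h = (t1 - t0)/n_steps
    t = t0
    z = z0

    for i_step in range(n_steps):
        k1 = f(z, t)
        k2 = f(z + h * k1 / 2, t + h / 2)
        k3 = f(z + h * k2 / 2, t + h / 2)
        k4 = f(z + h * k3, t + h)
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return z
</code></pre></div></div>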
<p>We also implement a superclass for parameterized dynamics functions in the form of a neural network, with a couple of useful methods.</p>
<p>First, one needs to be able to flatten all the parameters that the function depends on.</p>
<p>Second, one needs to implement a method that computes the augmented dynamics. The augmented dynamics depends on the gradients of the function wrt. its inputs and parameters; in order not to specify them by hand for every new architecture, we will use the <strong>torch.autograd.grad</strong> method.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class ODEF(nn.Module):
    def forward_with_grad(self, z, t, grad_outputs):
        """Compute f and a df/dz, a df/dp, a df/dt"""
        batch_size = z.shape[0]

        out = self.forward(z, t)

        a = grad_outputs
        adfdz, adfdt, *adfdp = torch.autograd.grad(
            (out,), (z, t) + tuple(self.parameters()), grad_outputs=(a),
            allow_unused=True, retain_graph=True
        )
        # grad method automatically sums gradients for batch items, we have to expand them back
        if adfdp is not None:
            adfdp = torch.cat([p_grad.flatten() for p_grad in adfdp]).unsqueeze(0)
            adfdp = adfdp.expand(batch_size, -1) / batch_size
        if adfdt is not None:
            adfdt = adfdt.expand(batch_size, 1) / batch_size
        return out, adfdz, adfdt, adfdp

    def flatten_parameters(self):
        p_shapes = []
        flat_parameters = []
        for p in self.parameters():
            p_shapes.append(p.size())
            flat_parameters.append(p.flatten())
        return torch.cat(flat_parameters)
</code></pre></div></div>
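<p>As a quick sanity check of the returned shapes, one can call <strong>forward_with_grad</strong> on a toy dynamics (this snippet and the toy class in it are my own, they are not part of the original notebook):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical toy dynamics, used only to exercise ODEF.forward_with_grad
class ToyODEF(ODEF):
    def __init__(self):
        super(ToyODEF, self).__init__()
        self.lin = nn.Linear(2, 2)

    def forward(self, z, t):
        return self.lin(z) * torch.sigmoid(t)   # depends on both z and t

f = ToyODEF()
z = torch.randn(3, 2).requires_grad_(True)   # batch of 3 two-dimensional states
t = torch.zeros(3, 1).requires_grad_(True)   # matching time tensor
a = torch.ones(3, 2)                         # stand-in for the adjoint dL/dz

out, adfdz, adfdt, adfdp = f.forward_with_grad(z, t, grad_outputs=a)
print(out.shape, adfdz.shape, adfdt.shape, adfdp.shape)
# torch.Size([3, 2]) torch.Size([3, 2]) torch.Size([3, 1]) torch.Size([3, 6])
# 6 = number of parameters of nn.Linear(2, 2): a 2x2 weight matrix plus a 2-dimensional bias
</code></pre></div></div>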
<p>The code below encapsulates the forward and backward passes of the <em>Neural ODE</em>. We have to keep it separate from the main <strong><em>torch.nn.Module</em></strong> because a custom backward function cannot be implemented inside a Module, but it can be implemented inside a <strong><em>torch.autograd.Function</em></strong>. So this is just a little workaround.</p>
<p>This function underlies the whole Neural ODE method.</p>
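<p>If you have not met <strong><em>torch.autograd.Function</em></strong> before, the pattern looks like this (a toy example of my own, unrelated to ODEs): <em>forward</em> produces the output, while <em>backward</em> receives the gradient of the loss wrt. that output and must return a gradient for each input of <em>forward</em>.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Toy custom autograd.Function (illustration only): y = 2 * x with a hand-written backward
class TimesTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        # dL/dx = dL/dy * dy/dx = grad_output * 2
        return 2 * grad_output

x = torch.tensor([3.0], requires_grad=True)
y = TimesTwo.apply(x).sum()
y.backward()
print(x.grad)   # tensor([2.])
</code></pre></div></div>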
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class ODEAdjoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z0, t, flat_parameters, func):
        assert isinstance(func, ODEF)
        bs, *z_shape = z0.size()
        time_len = t.size(0)

        with torch.no_grad():
            z = torch.zeros(time_len, bs, *z_shape).to(z0)
            z[0] = z0
            for i_t in range(time_len - 1):
                z0 = ode_solve(z0, t[i_t], t[i_t+1], func)
                z[i_t+1] = z0

        ctx.func = func
        ctx.save_for_backward(t, z.clone(), flat_parameters)
        return z

    @staticmethod
    def backward(ctx, dLdz):
        """
        dLdz shape: time_len, batch_size, *z_shape
        """
        func = ctx.func
        t, z, flat_parameters = ctx.saved_tensors
        time_len, bs, *z_shape = z.size()
        n_dim = np.prod(z_shape)
        n_params = flat_parameters.size(0)

        # Dynamics of augmented system to be calculated backwards in time
        def augmented_dynamics(aug_z_i, t_i):
            """
            tensors here are temporal slices
            t_i - is tensor with size: bs, 1
            aug_z_i - is tensor with size: bs, n_dim*2 + n_params + 1
            """
            z_i, a = aug_z_i[:, :n_dim], aug_z_i[:, n_dim:2*n_dim]  # ignore parameters and time

            # Unflatten z and a
            z_i = z_i.view(bs, *z_shape)
            a = a.view(bs, *z_shape)
            with torch.set_grad_enabled(True):
                t_i = t_i.detach().requires_grad_(True)
                z_i = z_i.detach().requires_grad_(True)
                func_eval, adfdz, adfdt, adfdp = func.forward_with_grad(z_i, t_i, grad_outputs=a)  # bs, *z_shape
                adfdz = adfdz.to(z_i) if adfdz is not None else torch.zeros(bs, *z_shape).to(z_i)
                adfdp = adfdp.to(z_i) if adfdp is not None else torch.zeros(bs, n_params).to(z_i)
                adfdt = adfdt.to(z_i) if adfdt is not None else torch.zeros(bs, 1).to(z_i)

            # Flatten f and adfdz
            func_eval = func_eval.view(bs, n_dim)
            adfdz = adfdz.view(bs, n_dim)
            return torch.cat((func_eval, -adfdz, -adfdp, -adfdt), dim=1)

        dLdz = dLdz.view(time_len, bs, n_dim)  # flatten dLdz for convenience
        with torch.no_grad():
            ## Create placeholders for output gradients
            # Prev computed backwards adjoints to be adjusted by direct gradients
            adj_z = torch.zeros(bs, n_dim).to(dLdz)
            adj_p = torch.zeros(bs, n_params).to(dLdz)
            # In contrast to z and p we need to return gradients for all times
            adj_t = torch.zeros(time_len, bs, 1).to(dLdz)

            for i_t in range(time_len-1, 0, -1):
                z_i = z[i_t]
                t_i = t[i_t]
                f_i = func(z_i, t_i).view(bs, n_dim)

                # Compute direct gradients
                dLdz_i = dLdz[i_t]
                dLdt_i = torch.bmm(torch.transpose(dLdz_i.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

                # Adjusting adjoints with direct gradients
                adj_z += dLdz_i
                adj_t[i_t] = adj_t[i_t] - dLdt_i

                # Pack augmented variable
                aug_z = torch.cat((z_i.view(bs, n_dim), adj_z, torch.zeros(bs, n_params).to(z), adj_t[i_t]), dim=-1)

                # Solve augmented system backwards
                aug_ans = ode_solve(aug_z, t_i, t[i_t-1], augmented_dynamics)

                # Unpack solved backwards augmented system
                adj_z[:] = aug_ans[:, n_dim:2*n_dim]
                adj_p[:] += aug_ans[:, 2*n_dim:2*n_dim + n_params]
                adj_t[i_t-1] = aug_ans[:, 2*n_dim + n_params:]

                del aug_z, aug_ans

            ## Adjust 0 time adjoint with direct gradients
            # Compute direct gradients
            dLdz_0 = dLdz[0]
            dLdt_0 = torch.bmm(torch.transpose(dLdz_0.unsqueeze(-1), 1, 2), f_i.unsqueeze(-1))[:, 0]

            # Adjust adjoints
            adj_z += dLdz_0
            adj_t[0] = adj_t[0] - dLdt_0
        return adj_z.view(bs, *z_shape), adj_t, adj_p, None
</code></pre></div></div>
<p>Finally, we wrap the ODE adjoint function in an <strong>nn.Module</strong> for convenience.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class NeuralODE(nn.Module):
    def __init__(self, func):
        super(NeuralODE, self).__init__()
        assert isinstance(func, ODEF)
        self.func = func

    def forward(self, z0, t=Tensor([0., 1.]), return_whole_sequence=False):
        t = t.to(z0)
        z = ODEAdjoint.apply(z0, t, self.func.flatten_parameters(), self.func)
        if return_whole_sequence:
            return z
        else:
            return z[-1]
</code></pre></div></div>
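<p>A minimal smoke test of the wrapper (my own snippet, not from the original notebook): integrating the constant dynamics <script type="math/tex">dz/dt = 1</script> from <script type="math/tex">t = 0</script> to <script type="math/tex">t = 1</script> should move the state by one.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical smoke test for NeuralODE (not part of the original code)
class ConstODEF(ODEF):
    def __init__(self):
        super(ConstODEF, self).__init__()
        self.w = nn.Parameter(torch.ones(1))

    def forward(self, z, t):
        return self.w.expand_as(z)   # dz/dt = w, with w initialized to 1

ode = NeuralODE(ConstODEF())
z0 = torch.zeros(1, 1)   # batch of one scalar state
z1 = ode(z0)             # integrates from t=0 to t=1 by default
print(z1)                # approximately tensor([[1.]]); Euler is exact for constant dynamics
</code></pre></div></div>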
<h1 id="application">Application</h1>
<h2 id="learning-true-dynamics-function-proof-of-concept"><em>Learning true dynamics function (proof of concept)</em></h2>
<p>As a proof of concept we will now test whether a Neural ODE can indeed recover the true dynamics function from sampled data.</p>
<p>To test this we will specify an ODE, evolve it, sample points along its trajectory, and then try to restore the dynamics from those points.</p>
<p>First, we’ll test a simple linear ODE whose dynamics is given by a matrix:</p>
<script type="math/tex; mode=display">% <![CDATA[
\frac{dz}{dt} = \begin{bmatrix}-0.1 & -1.0\\1.0 & -0.1\end{bmatrix} z %]]></script>
<p>The trained function here is also a simple matrix.</p>
<p><img src="/assets/node/linear_learning.gif" alt="learning gif" class="align-center" /></p>
<p>Next, a slightly more sophisticated dynamics (no gif, as its learning process is not so satisfying to watch :)).<br />
The trained function here is a small MLP.</p>
<p><img src="/assets/node/comp_result.png" alt="complicated result" class="align-center" /></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class LinearODEF(ODEF):
    def __init__(self, W):
        super(LinearODEF, self).__init__()
        self.lin = nn.Linear(2, 2, bias=False)
        self.lin.weight = nn.Parameter(W)

    def forward(self, x, t):
        return self.lin(x)
</code></pre></div></div>
<p>The dynamics here is simply given by a fixed matrix:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class SpiralFunctionExample(LinearODEF):
    def __init__(self):
        super(SpiralFunctionExample, self).__init__(Tensor([[-0.1, -1.], [1., -0.1]]))
</code></pre></div></div>
<p>The initial random linear dynamics function to be optimized:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class RandomLinearODEF(LinearODEF):
    def __init__(self):
        super(RandomLinearODEF, self).__init__(torch.randn(2, 2)/2.)
</code></pre></div></div>
<p>A more sophisticated dynamics for generating trajectories:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class TestODEF(ODEF):
    def __init__(self, A, B, x0):
        super(TestODEF, self).__init__()
        self.A = nn.Linear(2, 2, bias=False)
        self.A.weight = nn.Parameter(A)
        self.B = nn.Linear(2, 2, bias=False)
        self.B.weight = nn.Parameter(B)
        self.x0 = nn.Parameter(x0)

    def forward(self, x, t):
        xTx0 = torch.sum(x*self.x0, dim=1)
        dxdt = torch.sigmoid(xTx0) * self.A(x - self.x0) + torch.sigmoid(-xTx0) * self.B(x + self.x0)
        return dxdt
</code></pre></div></div>
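<p>Written out, the dynamics implemented above is</p>
<script type="math/tex; mode=display">\frac{dx}{dt} = \sigma(x \cdot x_0) \, A (x - x_0) + \sigma(-x \cdot x_0) \, B (x + x_0)</script>
<p>i.e. it smoothly switches between two linear regimes depending on which side of the hyperplane <script type="math/tex">x \cdot x_0 = 0</script> the state is on.</p>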
<p>The dynamics function to be optimized is an MLP:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>class NNODEF(ODEF):
    def __init__(self, in_dim, hid_dim, time_invariant=False):
        super(NNODEF, self).__init__()
        self.time_invariant = time_invariant

        if time_invariant:
            self.lin1 = nn.Linear(in_dim, hid_dim)
        else:
            self.lin1 = nn.Linear(in_dim+1, hid_dim)
        self.lin2 = nn.Linear(hid_dim, hid_dim)
        self.lin3 = nn.Linear(hid_dim, in_dim)
        self.elu = nn.ELU(inplace=True)

    def forward(self, x, t):
        if not self.time_invariant:
            x = torch.cat((x, t), dim=-1)

        h = self.elu(self.lin1(x))
        h = self.elu(self.lin2(h))
        out = self.lin3(h)
        return out
</code></pre></div></div>
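<p>Written out, the time-dependent version of this network computes</p>
<script type="math/tex; mode=display">\widehat{f}(z, t, \theta) = W_3 \, \mathrm{elu}\big(W_2 \, \mathrm{elu}(W_1 [z, t] + b_1) + b_2\big) + b_3</script>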
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def to_np(x):
    return x.detach().cpu().numpy()
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>def plot_trajectories(obs=None, times=None, trajs=None, save=None, figsize=(16, 8)):
    plt.figure(figsize=figsize)
    if obs is not None:
        if times is None:
            times = [None] * len(obs)
        for o, t in zip(obs, times):
            o, t = to_np(o), to_np(t)
            for b_i in range(o.shape[1]):
                plt.scatter(o[:, b_i, 0], o[:, b_i, 1], c=t[:, b_i, 0], cmap=cm.plasma)

    if trajs is not None:
        for z in trajs:
            z = to_np(z)
            plt.plot(z[:, 0, 0], z[:, 0, 1], lw=1.5)
    if save is not None:
        plt.savefig(save)
</code></pre></div></div>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">conduct_experiment</span><span class="p">(</span><span class="n">ode_true</span><span class="p">,</span> <span class="n">ode_trained</span><span class="p">,</span> <span class="n">n_steps</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">plot_freq</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span>
<span class="c"># Create data</span>
<span class="n">z0</span> <span class="o">=</span> <span class="n">Variable</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">([[</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">]]))</span>
<span class="n">t_max</span> <span class="o">=</span> <span class="mf">6.29</span><span class="o">*</span><span class="mi">5</span>
<span class="n">n_points</span> <span class="o">=</span> <span class="mi">200</span>
<span class="n">index_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">n_points</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="nb">int</span><span class="p">)</span>
<span class="n">index_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">index_np</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]])</span>
<span class="n">times_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">t_max</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">n_points</span><span class="p">)</span>
<span class="n">times_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">times_np</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]])</span>
<span class="n">times</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">times_np</span><span class="p">[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">z0</span><span class="p">)</span>
<span class="n">obs</span> <span class="o">=</span> <span class="n">ode_true</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">times</span><span class="p">,</span> <span class="n">return_whole_sequence</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">obs</span> <span class="o">=</span> <span class="n">obs</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">obs</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.01</span>
<span class="c"># Get trajectory of random timespan</span>
<span class="n">min_delta_time</span> <span class="o">=</span> <span class="mf">1.0</span>
<span class="n">max_delta_time</span> <span class="o">=</span> <span class="mf">5.0</span>
<span class="n">max_points_num</span> <span class="o">=</span> <span class="mi">32</span>
<span class="k">def</span> <span class="nf">create_batch</span><span class="p">():</span>
<span class="n">t0</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">t_max</span> <span class="o">-</span> <span class="n">max_delta_time</span><span class="p">)</span>
<span class="n">t1</span> <span class="o">=</span> <span class="n">t0</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">min_delta_time</span><span class="p">,</span> <span class="n">max_delta_time</span><span class="p">)</span>
<span class="n">idx</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">index_np</span><span class="p">[(</span><span class="n">times_np</span> <span class="o">></span> <span class="n">t0</span><span class="p">)</span> <span class="o">&</span> <span class="p">(</span><span class="n">times_np</span> <span class="o"><</span> <span class="n">t1</span><span class="p">)])[:</span><span class="n">max_points_num</span><span class="p">])</span>
<span class="n">obs_</span> <span class="o">=</span> <span class="n">obs</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="n">ts_</span> <span class="o">=</span> <span class="n">times</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
<span class="k">return</span> <span class="n">obs_</span><span class="p">,</span> <span class="n">ts_</span>
<span class="c"># Train Neural ODE</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">ode_trained</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_steps</span><span class="p">):</span>
<span class="n">obs_</span><span class="p">,</span> <span class="n">ts_</span> <span class="o">=</span> <span class="n">create_batch</span><span class="p">()</span>
<span class="n">z_</span> <span class="o">=</span> <span class="n">ode_trained</span><span class="p">(</span><span class="n">obs_</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">ts_</span><span class="p">,</span> <span class="n">return_whole_sequence</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">z_</span><span class="p">,</span> <span class="n">obs_</span><span class="o">.</span><span class="n">detach</span><span class="p">())</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span><span class="n">retain_graph</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="n">plot_freq</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="n">z_p</span> <span class="o">=</span> <span class="n">ode_trained</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">times</span><span class="p">,</span> <span class="n">return_whole_sequence</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">plot_trajectories</span><span class="p">(</span><span class="n">obs</span><span class="o">=</span><span class="p">[</span><span class="n">obs</span><span class="p">],</span> <span class="n">times</span><span class="o">=</span><span class="p">[</span><span class="n">times</span><span class="p">],</span> <span class="n">trajs</span><span class="o">=</span><span class="p">[</span><span class="n">z_p</span><span class="p">],</span> <span class="n">save</span><span class="o">=</span><span class="n">f</span><span class="s">"assets/imgs/{name}/{i}.png"</span><span class="p">)</span>
<span class="n">clear_output</span><span class="p">(</span><span class="n">wait</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">ode_true</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">SpiralFunctionExample</span><span class="p">())</span>
<span class="n">ode_trained</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">RandomLinearODEF</span><span class="p">())</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">conduct_experiment</span><span class="p">(</span><span class="n">ode_true</span><span class="p">,</span> <span class="n">ode_trained</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="s">"linear"</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">func</span> <span class="o">=</span> <span class="n">TestODEF</span><span class="p">(</span><span class="n">Tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.1</span><span class="p">]]),</span> <span class="n">Tensor</span><span class="p">([[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">]]),</span> <span class="n">Tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">]]))</span>
<span class="n">ode_true</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">NNODEF</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="n">time_invariant</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">ode_trained</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">conduct_experiment</span><span class="p">(</span><span class="n">ode_true</span><span class="p">,</span> <span class="n">ode_trained</span><span class="p">,</span> <span class="mi">3000</span><span class="p">,</span> <span class="s">"comp"</span><span class="p">,</span> <span class="n">plot_freq</span><span class="o">=</span><span class="mi">30</span><span class="p">)</span>
</code></pre></div></div>
<p>As one can see, Neural ODEs are pretty successful at approximating the dynamics. Now let’s check whether they can be used in a slightly more complicated task (MNIST, ha-ha).</p>
<h2 id="neural-ode-inspired-by-resnets">Neural ODE inspired by ResNets</h2>
<p>In residual networks, the hidden state changes according to the formula</p>
<script type="math/tex; mode=display">h_{t+1} = h_{t} + f(h_{t}, \theta_{t})</script>
<p>where <script type="math/tex">t \in \{0...T\}</script> is the residual block index and <script type="math/tex">f</script> is the function learned by the layers inside the block.</p>
<p>In the limit of infinitely many residual blocks with correspondingly smaller steps, the hidden-unit dynamics become continuous and are described by an ordinary differential equation, just like the one above.</p>
<script type="math/tex; mode=display">\frac{dh(t)}{dt} = f(h(t), t, \theta)</script>
<p>Starting from the input layer <script type="math/tex">h(0)</script>, one can define the output layer <script type="math/tex">h(T)</script> to be the solution to this ODE initial value problem at some time T.</p>
<p>Now one can treat <script type="math/tex">\theta</script> as parameters shared among all infinitesimally small residual blocks.</p>
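<p>To make this connection concrete, here is a minimal illustrative sketch (not part of the classifier below) contrasting a residual update with an explicit Euler step of the same dynamics function; the toy dimensions and step size are assumptions:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch
import torch.nn as nn

# Toy dynamics function f(h); in a ResNet it would be the layers inside one block
f = nn.Sequential(nn.Linear(8, 8), nn.Tanh(), nn.Linear(8, 8))

h = torch.randn(1, 8)

# One residual block: h_{t+1} = h_t + f(h_t)
h_res = h + f(h)

# One explicit Euler step of dh/dt = f(h) with step size dt
dt = 0.1
h_euler = h + dt * f(h)
# As dt shrinks and the number of steps grows, the iteration approaches the ODE solution.
</code></pre></div></div>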
<h3 id="testing-neural-ode-architecture-on-mnist">Testing Neural ODE architecture on MNIST</h3>
<p>In this section we test the ability of Neural ODEs to serve as a component in more conventional architectures.
In particular, we will use a Neural ODE in place of residual blocks in an MNIST classifier.</p>
<p><img src="/assets/node/mnist_example.png" alt="mnist_example" width="400px" class="align-center" /></p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">norm</span><span class="p">(</span><span class="n">dim</span><span class="p">):</span>
<span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">conv3x3</span><span class="p">(</span><span class="n">in_feats</span><span class="p">,</span> <span class="n">out_feats</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span>
<span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_feats</span><span class="p">,</span> <span class="n">out_feats</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">add_time</span><span class="p">(</span><span class="n">in_tensor</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
<span class="n">bs</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">in_tensor</span><span class="o">.</span><span class="n">shape</span>
<span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">in_tensor</span><span class="p">,</span> <span class="n">t</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="n">bs</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)),</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ConvODEF</span><span class="p">(</span><span class="n">ODEF</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">ConvODEF</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv1</span> <span class="o">=</span> <span class="n">conv3x3</span><span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">conv2</span> <span class="o">=</span> <span class="n">conv3x3</span><span class="p">(</span><span class="n">dim</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
<span class="n">xt</span> <span class="o">=</span> <span class="n">add_time</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv1</span><span class="p">(</span><span class="n">xt</span><span class="p">)))</span>
<span class="n">ht</span> <span class="o">=</span> <span class="n">add_time</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">dxdt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">conv2</span><span class="p">(</span><span class="n">ht</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">dxdt</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ContinuousNeuralMNISTClassifier</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ode</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">ContinuousNeuralMNISTClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">downsampling</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">norm</span><span class="p">(</span><span class="mi">64</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="n">norm</span><span class="p">(</span><span class="mi">64</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">),</span>
<span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">feature</span> <span class="o">=</span> <span class="n">ode</span>
<span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">norm</span><span class="p">(</span><span class="mi">64</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">avg_pool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">AdaptiveAvgPool2d</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsampling</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">avg_pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">shape</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">prod</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:]))</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span>
<span class="n">out</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">out</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">func</span> <span class="o">=</span> <span class="n">ConvODEF</span><span class="p">(</span><span class="mi">64</span><span class="p">)</span>
<span class="n">ode</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">ContinuousNeuralMNISTClassifier</span><span class="p">(</span><span class="n">ode</span><span class="p">)</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torchvision</span>
<span class="n">img_std</span> <span class="o">=</span> <span class="mf">0.3081</span>
<span class="n">img_mean</span> <span class="o">=</span> <span class="mf">0.1307</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">train_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s">"data/mnist"</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">transform</span><span class="o">=</span><span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="n">img_mean</span><span class="p">,),</span> <span class="p">(</span><span class="n">img_std</span><span class="p">,))</span>
<span class="p">])</span>
<span class="p">),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span>
<span class="p">)</span>
<span class="n">test_loader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span>
<span class="n">torchvision</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">MNIST</span><span class="p">(</span><span class="s">"data/mnist"</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">download</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span>
<span class="n">transform</span><span class="o">=</span><span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Compose</span><span class="p">([</span>
<span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">ToTensor</span><span class="p">(),</span>
<span class="n">torchvision</span><span class="o">.</span><span class="n">transforms</span><span class="o">.</span><span class="n">Normalize</span><span class="p">((</span><span class="n">img_mean</span><span class="p">,),</span> <span class="p">(</span><span class="n">img_std</span><span class="p">,))</span>
<span class="p">])</span>
<span class="p">),</span>
<span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">True</span>
<span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">())</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">epoch</span><span class="p">):</span>
<span class="n">num_items</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">train_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="n">f</span><span class="s">"Training Epoch {epoch}..."</span><span class="p">)</span>
<span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">train_loader</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">train_loader</span><span class="p">)):</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="n">train_losses</span> <span class="o">+=</span> <span class="p">[</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()]</span>
<span class="n">num_items</span> <span class="o">+=</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="s">'Train loss: {:.5f}'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">train_losses</span><span class="p">)))</span>
<span class="k">return</span> <span class="n">train_losses</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">test</span><span class="p">():</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">num_items</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">model</span><span class="o">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="n">criterion</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="n">f</span><span class="s">"Testing..."</span><span class="p">)</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
<span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">test_loader</span><span class="p">),</span> <span class="n">total</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">test_loader</span><span class="p">)):</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">target</span> <span class="o">=</span> <span class="n">target</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">accuracy</span> <span class="o">+=</span> <span class="n">torch</span><span class="o">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o">==</span> <span class="n">target</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="n">num_items</span> <span class="o">+=</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">accuracy</span> <span class="o">*</span> <span class="mi">100</span> <span class="o">/</span> <span class="n">num_items</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Test Accuracy: {:.3f}</span><span class="si">%</span><span class="s">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">accuracy</span><span class="p">))</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">n_epochs</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">test</span><span class="p">()</span>
<span class="n">train_losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_epochs</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
<span class="n">train_losses</span> <span class="o">+=</span> <span class="n">train</span><span class="p">(</span><span class="n">epoch</span><span class="p">)</span>
<span class="n">test</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">history</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">({</span><span class="s">"loss"</span><span class="p">:</span> <span class="n">train_losses</span><span class="p">})</span>
<span class="n">history</span><span class="p">[</span><span class="s">"cum_data"</span><span class="p">]</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">index</span> <span class="o">*</span> <span class="n">batch_size</span>
<span class="n">history</span><span class="p">[</span><span class="s">"smooth_loss"</span><span class="p">]</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">ewm</span><span class="p">(</span><span class="n">halflife</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span>
<span class="n">history</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="s">"cum_data"</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="s">"smooth_loss"</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">title</span><span class="o">=</span><span class="s">"train error"</span><span class="p">)</span>
</code></pre></div></div>
<div class="highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Testing...
100% 79/79 [00:01<00:00, 45.69it/s]
Test Accuracy: 9.740%
Training Epoch 1...
100% 1875/1875 [01:15<00:00, 24.69it/s]
Train loss: 0.20137
Testing...
100% 79/79 [00:01<00:00, 46.64it/s]
Test Accuracy: 98.680%
Training Epoch 2...
100% 1875/1875 [01:17<00:00, 24.32it/s]
Train loss: 0.05059
Testing...
100% 79/79 [00:01<00:00, 46.11it/s]
Test Accuracy: 97.760%
Training Epoch 3...
100% 1875/1875 [01:16<00:00, 24.63it/s]
Train loss: 0.03808
Testing...
100% 79/79 [00:01<00:00, 45.65it/s]
Test Accuracy: 99.000%
Training Epoch 4...
100% 1875/1875 [01:17<00:00, 24.28it/s]
Train loss: 0.02894
Testing...
100% 79/79 [00:01<00:00, 45.42it/s]
Test Accuracy: 99.130%
Training Epoch 5...
100% 1875/1875 [01:16<00:00, 24.67it/s]
Train loss: 0.02424
Testing...
100% 79/79 [00:01<00:00, 45.89it/s]
Test Accuracy: 99.170%
</code></pre></div></div>
<p><img src="/assets/node/train_error.png" alt="train error" class="align-center" /></p>
<p>After a very rough training procedure of only 5 epochs and about 6 minutes of training, the model already reaches a test error below 1%, which shows that the Neural ODE architecture works very well as a component in more conventional nets.</p>
<p>In their paper, the authors also compare this classifier to a simple one-layer MLP, to a ResNet with a similar architecture, and to the same ODE architecture but with gradients propagated directly through the ODE solver (without the adjoint method), which they call RK-Net.</p>
<p><img src="/assets/node/methods_compare.png" alt=""Methods comparison"" class="align-center" /></p>
<div align="center">Figure from original paper</div>
<p>According to them, a one-layer MLP with roughly the same number of parameters as the Neural ODE-Net has a much higher test error, a ResNet with roughly the same error has far more parameters, and the RK-Net with direct backpropagation through the ODE solver has a slightly higher error and linearly growing memory usage.</p>
<p>In their paper, the authors use an adaptive-step Runge-Kutta solver instead of the simple Euler method. They also examine some ODE-Net characteristics.</p>
<p><img src="/assets/node/ode_solver_attrs.png" alt=""Node attrs"" class="align-center" /></p>
<div align="center">ODE-Net characteristics (NFE Forward - number of function evaluations during forward pass)</div>
<div align="center">Figure from original paper</div>
<ul>
  <li>(a) Changing the numerical error tolerance changes the number of solver steps per forward pass.</li>
  <li>(b) Time spent by the forward call is proportional to the number of function evaluations.</li>
  <li>(c) The number of backward evaluations is roughly half the number of forward evaluations, which suggests that the adjoint method is more computationally efficient than direct backpropagation through the ODE solver.</li>
  <li>(d) As the ODE-Net is trained for longer, it demands more and more function evaluations, presumably adapting to the increasing complexity of the learned dynamics.</li>
</ul>
<h2 id="generative-latent-function-time-series-model">Generative latent function time-series model</h2>
<p>Neural ODEs seem to be well suited for continuous sequential data, even when the continuous trajectory lives in some unknown latent space.</p>
<p>In this section we will experiment with generating continuous sequential data using a Neural ODE and explore its latent space a bit.
The authors also compare this approach to generating the same sequential data with recurrent neural networks.</p>
<p>The approach here is slightly different from the corresponding example in the authors’ repository; the one here uses a more diverse set of trajectories.</p>
<h3 id="data">Data</h3>
<p>The training data consists of random spirals, half of which are clockwise and half counter-clockwise. Random sub-trajectories of 100 time steps are sampled from these spirals and passed through the RNN encoder in reverse order, yielding a latent starting state. This state is then evolved to produce a trajectory in the latent space, which is mapped onto a trajectory in the data space and compared with the actual observations. In this way, the model learns to generate trajectories that look like the data.</p>
<p><img src="/assets/node/spirals_examples.png" alt="image.png" class="align-center" /></p>
<div align="center">Examples of spirals in the dataset</div>
<h3 id="vae-as-a-generative-model">VAE as a generative model</h3>
<p>A generative model defined through the following sampling procedure:</p>
<script type="math/tex; mode=display">z_{t_0} \sim \mathcal{N}(0, I)</script>
<script type="math/tex; mode=display">z_{t_1}, z_{t_2},...,z_{t_M} = \text{ODESolve}(z_{t_0}, f, \theta_f, t_0,...,t_M)</script>
<script type="math/tex; mode=display">\text{each } x_{t_i} \sim p(x \mid z_{t_i};\theta_x)</script>
<p>This model can be trained using the variational autoencoder approach:</p>
<ol>
<li>Run the RNN encoder through the time series backwards in time to infer the parameters <script type="math/tex">\mu_{z_{t_0}}</script>, <script type="math/tex">\sigma_{z_{t_0}}</script> of the variational posterior and sample from it:</li>
</ol>
<script type="math/tex; mode=display">z_{t_0} \sim q \left( z_{t_0} \mid x_{t_0},...,x_{t_M}; t_0,...,t_M; \theta_q \right) = \mathcal{N} \left(z_{t_0} \mid \mu_{z_{t_0}} \sigma_{z_{t_0}} \right)</script>
<ol>
<li>Obtain the latent trajectory</li>
</ol>
<script type="math/tex; mode=display">z_{t_1}, z_{t_2},...,z_{t_N} = \text{ODESolve}(z_{t_0}, f, \theta_f, t_0,...,t_N), \text{ where } \frac{d z}{d t} = f(z, t; \theta_f)</script>
<ol>
<li>
<p>Map the latent trajectory onto the data space using another neural network: <script type="math/tex">\hat{x_{t_i}}(z_{t_i}, t_i; \theta_x)</script></p>
</li>
<li>
<p>Maximize the Evidence Lower BOund (ELBO) estimate for the sampled trajectory</p>
</li>
</ol>
<script type="math/tex; mode=display">\text{ELBO} \approx N \Big( \sum_{i=0}^{M} \log p(x_{t_i} \mid z_{t_i}(z_{t_0}; \theta_f); \theta_x) + KL \left( q( z_{t_0} \mid x_{t_0},...,x_{t_M}; t_0,...,t_M; \theta_q) \parallel \mathcal{N}(0, I) \right) \Big)</script>
<p>And in the case of a Gaussian observation model <script type="math/tex">p(x \mid z_{t_i};\theta_x)</script> with a known noise level <script type="math/tex">\sigma_x</script>:</p>
<script type="math/tex; mode=display">\text{ELBO} \approx -N \Big( \sum_{i=1}^{M}\frac{(x_i - \hat{x_i} )^2}{\sigma_x^2} - \log \sigma_{z_{t_0}}^2 + \mu_{z_{t_0}}^2 + \sigma_{z_{t_0}}^2 \Big) + C</script>
<p>The computation graph of the latent ODE model can be depicted as follows</p>
<p><img src="/assets/node/vae_model.png" alt="vae_model" class="align-center" /></p>
<div align="center">Figure from the original paper</div>
<p>One can then test how well this model extrapolates a trajectory given only observations of its initial part.</p>
<h3 id="defining-the-models">Defining the models</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">RNNEncoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RNNEncoder</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">hidden_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">rnn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GRU</span><span class="p">(</span><span class="n">input_dim</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hid2lat</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">,</span> <span class="mi">2</span><span class="o">*</span><span class="n">latent_dim</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
<span class="c"># Concatenate time to input</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">t</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">t</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span>
<span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mf">0.</span>
<span class="n">xt</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">),</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">_</span><span class="p">,</span> <span class="n">h0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn</span><span class="p">(</span><span class="n">xt</span><span class="o">.</span><span class="n">flip</span><span class="p">((</span><span class="mi">0</span><span class="p">,)))</span> <span class="c"># Reversed</span>
<span class="c"># Compute latent dimension</span>
<span class="n">z0</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hid2lat</span><span class="p">(</span><span class="n">h0</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">z0_mean</span> <span class="o">=</span> <span class="n">z0</span><span class="p">[:,</span> <span class="p">:</span><span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">]</span>
<span class="n">z0_log_var</span> <span class="o">=</span> <span class="n">z0</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">:]</span>
<span class="k">return</span> <span class="n">z0_mean</span><span class="p">,</span> <span class="n">z0_log_var</span>
</code></pre></div></div>
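<p>The encoder returns the mean and log-variance of the variational posterior; a latent initial state can then be sampled with the reparameterization trick. A minimal sketch with illustrative sizes and dummy inputs (not necessarily the exact code used in the full model):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Dummy batch of trajectories: T time steps, batch examples, 2-dimensional data
T, batch = 100, 8
x = torch.randn(T, batch, 2)
t = torch.linspace(0, 5, T).reshape(T, 1, 1).expand(T, batch, 1)

encoder = RNNEncoder(input_dim=2, hidden_dim=64, latent_dim=6)   # illustrative sizes
z0_mean, z0_log_var = encoder(x, t)

# Reparameterization trick: z0 ~ N(z0_mean, exp(z0_log_var))
eps = torch.randn_like(z0_mean)
z0 = z0_mean + eps * torch.exp(0.5 * z0_log_var)
</code></pre></div></div>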
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">NeuralODEDecoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">NeuralODEDecoder</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">hidden_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">NNODEF</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">time_invariant</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">ode</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">l2h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">h2o</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_dim</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">z0</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
<span class="n">zs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ode</span><span class="p">(</span><span class="n">z0</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">return_whole_sequence</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">hs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">l2h</span><span class="p">(</span><span class="n">zs</span><span class="p">)</span>
<span class="n">xs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">h2o</span><span class="p">(</span><span class="n">hs</span><span class="p">)</span>
<span class="k">return</span> <span class="n">xs</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">class</span> <span class="nc">ODEVAE</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">ODEVAE</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">hidden_dim</span> <span class="o">=</span> <span class="n">hidden_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span>
<span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">RNNEncoder</span><span class="p">(</span><span class="n">output_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">NeuralODEDecoder</span><span class="p">(</span><span class="n">output_dim</span><span class="p">,</span> <span class="n">hidden_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">MAP</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
<span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="k">if</span> <span class="n">MAP</span><span class="p">:</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">z_mean</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">z_mean</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">z_mean</span><span class="p">)</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">z_log_var</span><span class="p">)</span>
<span class="n">x_p</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x_p</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span>
<span class="k">def</span> <span class="nf">generate_with_seed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">seed_x</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
<span class="n">seed_t_len</span> <span class="o">=</span> <span class="n">seed_x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">seed_x</span><span class="p">,</span> <span class="n">t</span><span class="p">[:</span><span class="n">seed_t_len</span><span class="p">])</span>
<span class="n">x_p</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z_mean</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="k">return</span> <span class="n">x_p</span>
</code></pre></div></div>
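<p>To recap, the ODEVAE above implements the latent ODE generative model from the original paper: the initial latent state is sampled from a standard normal prior, the whole latent trajectory is obtained by solving the ODE, and each observation is emitted independently from the corresponding latent state (here through the decoder network, with Gaussian noise of fixed variance):</p>
<script type="math/tex; mode=display">z_{t_0} \sim \mathcal{N}(0, I), \qquad z_{t_1}, \dots, z_{t_M} = \text{ODESolve}(z_{t_0}, f, \theta_f, t_0, \dots, t_M), \qquad x_{t_i} \sim p\big(x \mid z_{t_i}, \theta_x\big)</script>
<p>The RNN encoder plays the role of the approximate posterior <script type="math/tex">q(z_{t_0} \mid x_{t_0}, \dots, x_{t_M}; t_0, \dots, t_M)</script>.</p>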
<h3 id="generating-dataset">Generating dataset</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t_max</span> <span class="o">=</span> <span class="mf">6.29</span><span class="o">*</span><span class="mi">5</span>
<span class="n">n_points</span> <span class="o">=</span> <span class="mi">200</span>
<span class="n">noise_std</span> <span class="o">=</span> <span class="mf">0.02</span>
<span class="n">num_spirals</span> <span class="o">=</span> <span class="mi">1000</span>
<span class="n">index_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">n_points</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="nb">int</span><span class="p">)</span>
<span class="n">index_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">index_np</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]])</span>
<span class="n">times_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">t_max</span><span class="p">,</span> <span class="n">num</span><span class="o">=</span><span class="n">n_points</span><span class="p">)</span>
<span class="n">times_np</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">([</span><span class="n">times_np</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]]</span> <span class="o">*</span> <span class="n">num_spirals</span><span class="p">)</span>
<span class="n">times</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">times_np</span><span class="p">[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">])</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
<span class="c"># Generate random spirals parameters</span>
<span class="n">normal01</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributions</span><span class="o">.</span><span class="n">Normal</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span>
<span class="n">x0</span> <span class="o">=</span> <span class="n">Variable</span><span class="p">(</span><span class="n">normal01</span><span class="o">.</span><span class="n">sample</span><span class="p">((</span><span class="n">num_spirals</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="o">*</span> <span class="mf">2.0</span>
<span class="n">W11</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.1</span> <span class="o">*</span> <span class="n">normal01</span><span class="o">.</span><span class="n">sample</span><span class="p">((</span><span class="n">num_spirals</span><span class="p">,))</span><span class="o">.</span><span class="nb">abs</span><span class="p">()</span> <span class="o">-</span> <span class="mf">0.05</span>
<span class="n">W22</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.1</span> <span class="o">*</span> <span class="n">normal01</span><span class="o">.</span><span class="n">sample</span><span class="p">((</span><span class="n">num_spirals</span><span class="p">,))</span><span class="o">.</span><span class="nb">abs</span><span class="p">()</span> <span class="o">-</span> <span class="mf">0.05</span>
<span class="n">W21</span> <span class="o">=</span> <span class="o">-</span><span class="mf">1.0</span> <span class="o">*</span> <span class="n">normal01</span><span class="o">.</span><span class="n">sample</span><span class="p">((</span><span class="n">num_spirals</span><span class="p">,))</span><span class="o">.</span><span class="nb">abs</span><span class="p">()</span>
<span class="n">W12</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">normal01</span><span class="o">.</span><span class="n">sample</span><span class="p">((</span><span class="n">num_spirals</span><span class="p">,))</span><span class="o">.</span><span class="nb">abs</span><span class="p">()</span>
<span class="n">xs_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_spirals</span><span class="p">):</span>
<span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="c"># Make it counter-clockwise</span>
<span class="n">W21</span><span class="p">,</span> <span class="n">W12</span> <span class="o">=</span> <span class="n">W12</span><span class="p">,</span> <span class="n">W21</span>
<span class="n">func</span> <span class="o">=</span> <span class="n">LinearODEF</span><span class="p">(</span><span class="n">Tensor</span><span class="p">([[</span><span class="n">W11</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">W12</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="p">[</span><span class="n">W21</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">W22</span><span class="p">[</span><span class="n">i</span><span class="p">]]]))</span>
<span class="n">ode</span> <span class="o">=</span> <span class="n">NeuralODE</span><span class="p">(</span><span class="n">func</span><span class="p">)</span>
<span class="n">xs</span> <span class="o">=</span> <span class="n">ode</span><span class="p">(</span><span class="n">x0</span><span class="p">[</span><span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">times</span><span class="p">[:,</span> <span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="n">return_whole_sequence</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">xs_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">xs</span><span class="p">)</span>
<span class="n">orig_trajs</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">xs_list</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span>
<span class="n">samp_trajs</span> <span class="o">=</span> <span class="n">orig_trajs</span> <span class="o">+</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn_like</span><span class="p">(</span><span class="n">orig_trajs</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise_std</span>
<span class="n">samp_ts</span> <span class="o">=</span> <span class="n">times</span>
</code></pre></div></div>
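<p>A short aside on why this produces spirals (a sketch of the reasoning, not part of the original code): each trajectory follows the linear ODE</p>
<script type="math/tex; mode=display">\frac{dx}{dt} = W x</script>
<p>where the diagonal entries <script type="math/tex">W_{11}, W_{22}</script> are slightly negative and the off-diagonal entries <script type="math/tex">W_{12}, W_{21}</script> have opposite signs, so the eigenvalues of <script type="math/tex">W</script> are complex with a small negative real part: trajectories rotate while slowly decaying towards the origin, i.e. they spiral inwards. Swapping <script type="math/tex">W_{12}</script> and <script type="math/tex">W_{21}</script> flips the direction of rotation, which is how the counter-clockwise spirals are obtained.</p>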
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="p">):</span>
<span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">samp_trajs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">samp_trajs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">samp_ts</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">cmap</span><span class="o">=</span><span class="n">cm</span><span class="o">.</span><span class="n">plasma</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy.random</span> <span class="k">as</span> <span class="n">npr</span>
<span class="k">def</span> <span class="nf">gen_batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n_sample</span><span class="o">=</span><span class="mi">100</span><span class="p">):</span>
<span class="n">n_batches</span> <span class="o">=</span> <span class="n">samp_trajs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="n">batch_size</span>
<span class="n">time_len</span> <span class="o">=</span> <span class="n">samp_trajs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">n_sample</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">n_sample</span><span class="p">,</span> <span class="n">time_len</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_batches</span><span class="p">):</span>
<span class="k">if</span> <span class="n">n_sample</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span>
<span class="n">t0_idx</span> <span class="o">=</span> <span class="n">npr</span><span class="o">.</span><span class="n">multinomial</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">[</span><span class="mf">1.</span> <span class="o">/</span> <span class="p">(</span><span class="n">time_len</span> <span class="o">-</span> <span class="n">n_sample</span><span class="p">)]</span> <span class="o">*</span> <span class="p">(</span><span class="n">time_len</span> <span class="o">-</span> <span class="n">n_sample</span><span class="p">))</span>
<span class="n">t0_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">t0_idx</span><span class="p">)</span>
<span class="n">tM_idx</span> <span class="o">=</span> <span class="n">t0_idx</span> <span class="o">+</span> <span class="n">n_sample</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">t0_idx</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">tM_idx</span> <span class="o">=</span> <span class="n">time_len</span>
<span class="n">frm</span><span class="p">,</span> <span class="n">to</span> <span class="o">=</span> <span class="n">batch_size</span><span class="o">*</span><span class="n">i</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">*</span><span class="p">(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span>
<span class="k">yield</span> <span class="n">samp_trajs</span><span class="p">[</span><span class="n">t0_idx</span><span class="p">:</span><span class="n">tM_idx</span><span class="p">,</span> <span class="n">frm</span><span class="p">:</span><span class="n">to</span><span class="p">],</span> <span class="n">samp_ts</span><span class="p">[</span><span class="n">t0_idx</span><span class="p">:</span><span class="n">tM_idx</span><span class="p">,</span> <span class="n">frm</span><span class="p">:</span><span class="n">to</span><span class="p">]</span>
</code></pre></div></div>
<h3 id="training">Training</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">vae</span> <span class="o">=</span> <span class="n">ODEVAE</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<span class="n">vae</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">vae</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">optim</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">vae</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">betas</span><span class="o">=</span><span class="p">(</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.999</span><span class="p">),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span>
</code></pre></div></div>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">preload</span> <span class="o">=</span> <span class="bp">False</span>
<span class="n">n_epochs</span> <span class="o">=</span> <span class="mi">20000</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">plot_traj_idx</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">plot_traj</span> <span class="o">=</span> <span class="n">orig_trajs</span><span class="p">[:,</span> <span class="n">plot_traj_idx</span><span class="p">:</span><span class="n">plot_traj_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
<span class="n">plot_obs</span> <span class="o">=</span> <span class="n">samp_trajs</span><span class="p">[:,</span> <span class="n">plot_traj_idx</span><span class="p">:</span><span class="n">plot_traj_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
<span class="n">plot_ts</span> <span class="o">=</span> <span class="n">samp_ts</span><span class="p">[:,</span> <span class="n">plot_traj_idx</span><span class="p">:</span><span class="n">plot_traj_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">]</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">plot_traj</span> <span class="o">=</span> <span class="n">plot_traj</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">plot_obs</span> <span class="o">=</span> <span class="n">plot_obs</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">plot_ts</span> <span class="o">=</span> <span class="n">plot_ts</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="k">if</span> <span class="n">preload</span><span class="p">:</span>
<span class="n">vae</span><span class="o">.</span><span class="n">load_state_dict</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">"models/vae_spirals.sd"</span><span class="p">))</span>
<span class="k">for</span> <span class="n">epoch_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_epochs</span><span class="p">):</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">train_iter</span> <span class="o">=</span> <span class="n">gen_batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>
<span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="n">train_iter</span><span class="p">:</span>
<span class="n">optim</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">x</span><span class="p">,</span> <span class="n">t</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">t</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">max_len</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">([</span><span class="mi">30</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">])</span>
<span class="n">permutation</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="n">t</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">permutation</span><span class="p">)</span>
<span class="n">permutation</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">permutation</span><span class="p">[:</span><span class="n">max_len</span><span class="p">])</span>
<span class="n">x</span><span class="p">,</span> <span class="n">t</span> <span class="o">=</span> <span class="n">x</span><span class="p">[</span><span class="n">permutation</span><span class="p">],</span> <span class="n">t</span><span class="p">[</span><span class="n">permutation</span><span class="p">]</span>
<span class="n">x_p</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span> <span class="o">=</span> <span class="n">vae</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">kl_loss</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">torch</span><span class="o">.</span><span class="nb">sum</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">z_log_var</span> <span class="o">-</span> <span class="n">z_mean</span><span class="o">**</span><span class="mi">2</span> <span class="o">-</span> <span class="n">torch</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z_log_var</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">((</span><span class="n">x</span><span class="o">-</span><span class="n">x_p</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="nb">sum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="nb">sum</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">/</span> <span class="n">noise_std</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="n">kl_loss</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">/=</span> <span class="n">max_len</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optim</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="k">print</span><span class="p">(</span><span class="n">f</span><span class="s">"Epoch {epoch_idx}"</span><span class="p">)</span>
<span class="n">frm</span><span class="p">,</span> <span class="n">to</span><span class="p">,</span> <span class="n">to_seed</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">200</span><span class="p">,</span> <span class="mi">50</span>
<span class="n">seed_trajs</span> <span class="o">=</span> <span class="n">samp_trajs</span><span class="p">[</span><span class="n">frm</span><span class="p">:</span><span class="n">to_seed</span><span class="p">]</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">samp_ts</span><span class="p">[</span><span class="n">frm</span><span class="p">:</span><span class="n">to</span><span class="p">]</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">seed_trajs</span> <span class="o">=</span> <span class="n">seed_trajs</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">ts</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">samp_trajs_p</span> <span class="o">=</span> <span class="n">to_np</span><span class="p">(</span><span class="n">vae</span><span class="o">.</span><span class="n">generate_with_seed</span><span class="p">(</span><span class="n">seed_trajs</span><span class="p">,</span> <span class="n">ts</span><span class="p">))</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="p">):</span>
<span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">to_np</span><span class="p">(</span><span class="n">seed_trajs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]),</span> <span class="n">to_np</span><span class="p">(</span><span class="n">seed_trajs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]),</span> <span class="n">c</span><span class="o">=</span><span class="n">to_np</span><span class="p">(</span><span class="n">ts</span><span class="p">[</span><span class="n">frm</span><span class="p">:</span><span class="n">to_seed</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]),</span> <span class="n">cmap</span><span class="o">=</span><span class="n">cm</span><span class="o">.</span><span class="n">plasma</span><span class="p">)</span>
<span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">to_np</span><span class="p">(</span><span class="n">orig_trajs</span><span class="p">[</span><span class="n">frm</span><span class="p">:</span><span class="n">to</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]),</span> <span class="n">to_np</span><span class="p">(</span><span class="n">orig_trajs</span><span class="p">[</span><span class="n">frm</span><span class="p">:</span><span class="n">to</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
<span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">samp_trajs_p</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">samp_trajs_p</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">losses</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">median</span><span class="p">(</span><span class="n">losses</span><span class="p">))</span>
<span class="n">clear_output</span><span class="p">(</span><span class="n">wait</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
</code></pre></div></div>
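<p>For reference (a sketch of the correspondence, assuming the standard VAE derivation), the quantity minimized above is the negative ELBO with a Gaussian observation model of fixed variance <script type="math/tex">\sigma^2</script> (with <script type="math/tex">\sigma</script> equal to the noise_std used to generate the data) and a standard normal prior on the initial latent state:</p>
<script type="math/tex; mode=display">-\text{ELBO} = \frac{1}{2\sigma^2} \sum_{i} \| x_{t_i} - \hat{x}_{t_i} \|^2 - \frac{1}{2} \sum_{j} \big( 1 + \log \sigma_{z,j}^2 - \mu_{z,j}^2 - \sigma_{z,j}^2 \big) + \text{const}</script>
<p>The extra division by max_len simply normalizes the loss by the number of time points in the randomly subsampled trajectory.</p>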
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">spiral_0_idx</span> <span class="o">=</span> <span class="mi">3</span>
<span class="n">spiral_1_idx</span> <span class="o">=</span> <span class="mi">6</span>
<span class="n">homotopy_p</span> <span class="o">=</span> <span class="n">Tensor</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mi">10</span><span class="p">)[:,</span> <span class="bp">None</span><span class="p">])</span>
<span class="n">vae</span> <span class="o">=</span> <span class="n">vae</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">homotopy_p</span> <span class="o">=</span> <span class="n">homotopy_p</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">vae</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">spiral_0</span> <span class="o">=</span> <span class="n">orig_trajs</span><span class="p">[:,</span> <span class="n">spiral_0_idx</span><span class="p">:</span><span class="n">spiral_0_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">spiral_1</span> <span class="o">=</span> <span class="n">orig_trajs</span><span class="p">[:,</span> <span class="n">spiral_1_idx</span><span class="p">:</span><span class="n">spiral_1_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">ts_0</span> <span class="o">=</span> <span class="n">samp_ts</span><span class="p">[:,</span> <span class="n">spiral_0_idx</span><span class="p">:</span><span class="n">spiral_0_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span>
<span class="n">ts_1</span> <span class="o">=</span> <span class="n">samp_ts</span><span class="p">[:,</span> <span class="n">spiral_1_idx</span><span class="p">:</span><span class="n">spiral_1_idx</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span>
<span class="k">if</span> <span class="n">use_cuda</span><span class="p">:</span>
<span class="n">spiral_0</span><span class="p">,</span> <span class="n">ts_0</span> <span class="o">=</span> <span class="n">spiral_0</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">ts_0</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">spiral_1</span><span class="p">,</span> <span class="n">ts_1</span> <span class="o">=</span> <span class="n">spiral_1</span><span class="o">.</span><span class="n">cuda</span><span class="p">(),</span> <span class="n">ts_1</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">z_cw</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">spiral_0</span><span class="p">,</span> <span class="n">ts_0</span><span class="p">)</span>
<span class="n">z_cc</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">spiral_1</span><span class="p">,</span> <span class="n">ts_1</span><span class="p">)</span>
<span class="n">homotopy_z</span> <span class="o">=</span> <span class="n">z_cw</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">homotopy_p</span><span class="p">)</span> <span class="o">+</span> <span class="n">z_cc</span> <span class="o">*</span> <span class="n">homotopy_p</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">6</span><span class="o">*</span><span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">,</span> <span class="mi">200</span><span class="p">))</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="bp">None</span><span class="p">]</span><span class="o">.</span><span class="n">expand</span><span class="p">(</span><span class="mi">200</span><span class="p">,</span> <span class="mi">10</span><span class="p">)[:,</span> <span class="p">:,</span> <span class="bp">None</span><span class="p">]</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="o">.</span><span class="n">cuda</span><span class="p">()</span> <span class="k">if</span> <span class="n">use_cuda</span> <span class="k">else</span> <span class="n">t</span>
<span class="n">hom_gen_trajs</span> <span class="o">=</span> <span class="n">vae</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">homotopy_z</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>
<span class="n">axes</span> <span class="o">=</span> <span class="n">axes</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="p">):</span>
<span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">to_np</span><span class="p">(</span><span class="n">hom_gen_trajs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]),</span> <span class="n">to_np</span><span class="p">(</span><span class="n">hom_gen_trajs</span><span class="p">[:,</span> <span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]))</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">vae</span><span class="o">.</span><span class="n">state_dict</span><span class="p">(),</span> <span class="s">"models/vae_spirals.sd"</span><span class="p">)</span>
</code></pre></div></div>
<p>This is what I got after a night of training:</p>
<p><img src="/assets/node/spirals_reconstructed.png" alt="spiral reconstruction with seed" class="align-center" /></p>
<div align="center">Dots are noisy observations of the original trajectories (blue), <br /> yellow are reconstructed and interpolated trajectories using dots as inputs. <br /> Color of the dots represents time. </div>
<p>Reconstructions of some examples are not very good. Perhaps the model is not expressive enough, or it has not been trained long enough. Still, the results look quite plausible.</p>
<p>Now let’s have a look at what happens if we interpolate between the latent variable of a clockwise trajectory and that of a counter-clockwise one.</p>
<p><img src="/assets/node/spirals_homotopy.png" alt="homotopy" class="align-center" /></p>
<p>The authors also compare trajectories reconstructed from observations of the initial part of the time interval, using a Neural ODE versus a simple RNN.</p>
<p><img src="/assets/node/ode_rnn_comp.png" alt="ode_rnn_comp" class="align-center" /></p>
<div align="center">Figure from the original paper</div>
<h2 id="continuous-normalizing-flows">Continuous normalizing flows</h2>
<p>The original paper also contributes a lot to the topic of Normalizing Flows. Normalizing flows are used when one needs to sample from a complicated distribution obtained by a change of variables from some simple distribution (e.g. a Gaussian), while still being able to evaluate the probability density of each sample.<br />
The authors show that a continuous change of variables is much more computationally efficient and interpretable than previous methods.</p>
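<p>For reference, the key result behind this (Theorem 1 of the original paper) is the instantaneous change of variables formula: if the state evolves as <script type="math/tex">\frac{dz}{dt} = f(z(t), t)</script>, then the log-density evolves as</p>
<script type="math/tex; mode=display">\frac{\partial \log p(z(t))}{\partial t} = -\text{tr}\left( \frac{\partial f}{\partial z(t)} \right)</script>
<p>so instead of the log-determinant of the Jacobian required by discrete normalizing flows, one only needs its trace, and the log-density is obtained by integrating this quantity along the trajectory together with the state itself.</p>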
<p>Normalizing flows are very useful in such models as <em>Variational AutoEncoders</em>, <em>Bayesian Neural Networks</em> and other things in Bayesian setting.</p>
<p>This topic, however, is beyond the scope of the present notebook, and those interested are referred to the original paper.</p>
<p>To tease you a bit:</p>
<p><img src="/assets/node/CNF_NF_comp.png" alt="CNF_NF_comp" class="align-center" /></p>
<div align="center">Visualizing the transformation from noise (simple distribution) to data (complicated distribution) for two datasets; <br /> X-axis represents density and samples transformation with "time" (for CNF) and "depth" (for NF) <br />Figure from the original paper</div>
<p>This concludes my little investigation of <strong>Neural ODEs</strong>. Hope you found it useful!</p>
<h1 id="useful-links">Useful links</h1>
<ul>
<li><a href="https://arxiv.org/abs/1806.07366">Original paper</a></li>
<li><a href="https://github.com/rtqichen/torchdiffeq">Authors’ PyTorch implementation</a></li>
<li><a href="https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf">Variational Inference</a></li>
<li><a href="https://habr.com/en/post/331552/">My article on VAE (Russian)</a></li>
<li><a href="https://www.jeremyjordan.me/variational-autoencoders/">VAE explained</a></li>
<li><a href="http://akosiorek.github.io/ml/2018/04/03/norm_flows.html">More on Normalizing Flows</a></li>
<li><a href="https://arxiv.org/abs/1505.05770">Variational Inference with Normalizing Flows Paper</a></li>
</ul>Mikhail Surtsukovmsurtsukov@gmail.comA significant portion of processes can be described by differential equations: let it be evolution of physical systems, medical conditions of a patient, fundamental properties of markets, etc. Such data is sequential and continuous in its nature, meaning that observations are merely realizations of some continuously changing state. There is also another type of sequential data that is discrete – NLP data, for example: its state changes discretely, from one symbol to another, from one word to another. Today both these types are normally processed using recurrent neural networks. They are, however, essentially different in their nature, and it seems that they should be treated differently. At the last NIPS conference a very interesting paper was presented that attempts to tackle this problem. Authors propose a very promising approach, which they call Neural Ordinary Differential Equations. Here I tried to reproduce and summarize the results of original paper, making it a little easier to familiarize yourself with the idea. As I believe, this new architecture may soon be, among convolutional and recurrent networks, in a toolbox of any data scientist.Generative models collection2018-12-25T00:00:00+00:002018-12-25T00:00:00+00:00https://msurtsukov.github.io/GANS<p>PyTorch implementations of various generative models to be trained and evaluated on <strong>CelebA</strong> dataset. The models are: <em>Deep Convolutional GAN, Least Squares GAN,
Wasserstein GAN, Wasserstein GAN Gradient Penalty, Information Maximizing GAN, Boundary Equilibrium GAN, Variational AutoEncoder and Variational AutoEncoder GAN</em>.
All models have network architectures and implementations that are as close to each other as possible, with only the deviations required by their respective papers.</p>
<!--more-->
<p>For now all models except InfoGAN are <strong>conditional</strong> on attributes, with the attribute vector concatenated to the latent variable for the generator and to the channels for the discriminator. However, the conditioning can easily be removed if desired.</p>
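<p>To illustrate (a minimal sketch with made-up function and variable names, not the repository’s actual API): conditioning amounts to concatenating the attribute vector to the latent code for the generator, and broadcasting it over the spatial dimensions as extra channels for the discriminator.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

# Hypothetical modules: g_net maps (latent_dim + attr_dim) inputs to an image,
# d_net consumes an image with (3 + attr_dim) channels.
def generator_forward(g_net, z, attrs):
    # attrs: (batch, attr_dim) binary attribute vector
    return g_net(torch.cat([z, attrs], dim=1))

def discriminator_forward(d_net, images, attrs):
    b, a = attrs.shape
    h, w = images.shape[2], images.shape[3]
    attr_maps = attrs.view(b, a, 1, 1).expand(b, a, h, w)
    return d_net(torch.cat([images, attr_maps], dim=1))
</code></pre></div></div>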
<p>Most of the code is shared between the models, so adding a model, changing model architectures, or using a different dataset all require little effort.
Since most GANs differ only in how the generator and discriminator losses are computed, adding a model may only require inheriting from the GAN superclass and providing the losses.</p>
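<p>For example (again a hedged sketch; the actual base class and method names in the repository may differ), adding the LSGAN variant with the parameters listed below could look roughly like this:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code># Hypothetical GAN base class exposing discriminator outputs on real and fake
# batches; only the losses change between variants (LSGAN with a=-1, b=1, c=0).
class LSGAN(GAN):
    def discriminator_loss(self, d_real, d_fake):
        return 0.5 * ((d_real - 1.0) ** 2).mean() + 0.5 * ((d_fake + 1.0) ** 2).mean()

    def generator_loss(self, d_fake):
        return 0.5 * (d_fake ** 2).mean()
</code></pre></div></div>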
<p>For visual monitoring during training, <em>Visdom</em> is used; it only requires providing the values and images of interest.</p>
<h1 id="comments-on-models-and-results">Comments on models and results</h1>
<p>All models share the architecture of <em>DCGAN</em> with slight deviations and were trained using Adam(0.5, 0.999) with a batch size of 64 and a learning rate of 0.0001.</p>
<p>Most models were not trained long enough.</p>
<p>Latent dimension is 128 for all models and models are conditioned on 40 binary attributes such as: ‘bald’, ‘eyeglasses’, ‘male’, etc.</p>
<p>The epoch count is based on discriminator steps (the number of generator steps per epoch differs greatly between Wasserstein and non-Wasserstein GANs).</p>
<h2 id="dcgan">DCGAN</h2>
<p><a href="https://arxiv.org/abs/1511.06434">Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks</a></p>
<p>Learning rate is 0.0001 and 2 generator steps per discriminator step provided better results.</p>
<p><img src="/assets/gans/dcgan.gif" alt="dcgan" width="900px" class="align-center" /></p>
<h2 id="lsgan">LSGAN</h2>
<p><a href="https://arxiv.org/abs/1611.04076">Least Squares Generative Adversarial Networks</a></p>
<script type="math/tex; mode=display">a = -1, b = 1, c = 0</script>
<p>Learning rate is 0.0001 and 2 generator steps per discriminator step.</p>
<p>Training was slightly more stable than DCGAN’s.</p>
<p><img src="/assets/gans/lsgan.gif" alt="lsgan" width="900px" class="align-center" /></p>
<h2 id="wgan">WGAN</h2>
<p><a href="https://arxiv.org/abs/1701.07875">Wasserstein GAN</a></p>
<p>Learning rate is 0.0001 and 5 discriminator steps per generator step.</p>
<p><img src="/assets/gans/wgan.gif" alt="wgan" width="900px" class="align-center" /></p>
<h2 id="wgangp">WGANGP</h2>
<p><a href="https://arxiv.org/abs/1704.00028">Improved Training of Wasserstein GANs</a></p>
<script type="math/tex; mode=display">\lambda = 10</script>
<p>Learning rate is 0.0001 and 5 discriminator steps per generator step.</p>
<p><img src="/assets/gans/wgangp.gif" alt="wgangp" width="900px" class="align-center" /></p>
<h2 id="infogan">InfoGAN</h2>
<p><a href="https://arxiv.org/abs/1606.03657">InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets</a></p>
<script type="math/tex; mode=display">\lambda = 1</script>
<p>Learning rate is 0.0001 and 2 generator steps per discriminator step.</p>
<p>In contrast to the other models, this one was not trained with conditional attributes. Instead, the attributes were used as binomial latent variables: their empirical distribution served as the prior for the fake images’ attribute codes, while their actual meaning had to be discovered by the model itself.</p>
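<p>Concretely (a rough sketch with hypothetical names; see the InfoGAN paper for the full objective): for binary codes the mutual-information term reduces to a binary cross-entropy between the sampled codes and the predictions of an auxiliary head on generated images, minimized with respect to both the generator and that head.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch.nn.functional as F

# q_head and d_features are hypothetical: a Bernoulli prediction head on top
# of the shared discriminator body.
def info_loss(d_features, q_head, fake_images, sampled_codes, lam=1.0):
    logits = q_head(d_features(fake_images))   # (batch, 40)
    return lam * F.binary_cross_entropy_with_logits(logits, sampled_codes)
</code></pre></div></div>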
<p><img src="/assets/gans/infogan.gif" alt="infogan" width="900px" class="align-center" /></p>
<h3 id="todo">TODO</h3>
<p>Check whether attributes really got their true interpretation.</p>
<h2 id="began">BEGAN</h2>
<p><a href="https://arxiv.org/abs/1703.10717">BEGAN: Boundary Equilibrium Generative Adversarial Networks</a></p>
<script type="math/tex; mode=display">\gamma=0.5, \lambda=0.01</script>
<p>Learning rate is 0.0001 and 2 generator steps per discriminator step.</p>
<p>The skip-connections proposed by the authors of the article were not introduced here, as this would require heavily modifying the network architectures.
The model experienced heavy mode collapse; the authors suggest reducing the learning rate to avoid it, but this was not tried.</p>
<p><img src="/assets/gans/began.gif" alt="began" width="900px" class="align-center" /></p>
<h2 id="vae">VAE</h2>
<p><a href="https://arxiv.org/abs/1312.6114">Auto-Encoding Variational Bayes</a></p>
<p>Learning rate is 0.0001.</p>
<p><img src="/assets/gans/vae.gif" alt="vae" width="900px" class="align-center" /></p>
<h2 id="vaegan">VAEGAN</h2>
<p><a href="https://arxiv.org/abs/1512.09300">Autoencoding beyond pixels using a learned similarity metric</a></p>
<script type="math/tex; mode=display">\gamma=0.01</script>
<p>Learning rate is 0.0001 and 2 generator steps per discriminator step.</p>
<p>The generator loss was changed to exploit the log(1-x) -> -log(x) trick. This slightly weakens the probabilistic interpretation,
but at least the model is able to train this way.</p>
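<p>For reference, the trick replaces the saturating generator objective with the non-saturating one:</p>
<script type="math/tex; mode=display">\min_G \; \mathbb{E}_{z}\big[\log(1 - D(G(z)))\big] \quad \longrightarrow \quad \max_G \; \mathbb{E}_{z}\big[\log D(G(z))\big]</script>
<p>which provides much stronger gradients early in training, when the discriminator easily rejects generated samples.</p>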
<p>(The epoch count is based on generator steps.)</p>
<h3 id="real-and-decoded">Real and decoded</h3>
<p><img src="/assets/gans/vaegan-real.png" alt="vaegan-real" width="900px" class="align-center" /></p>
<p><img src="/assets/gans/vaegan-decoded.png" alt="vaegan-decoded" width="900px" class="align-center" /></p>
<p><img src="/assets/gans/vaegan.gif" alt="vaegan" width="900px" class="align-center" /></p>
<p>Link to the <a href="https://github.com/msurtsukov/generative-models-collection">repository</a></p>Mikhail Surtsukovmsurtsukov@gmail.comPyTorch implementations of various generative models to be trained and evaluated on CelebA dataset. The models are: Deep Convolutional GAN, Least Squares GAN, Wasserstein GAN, Wasserstein GAN Gradient Penalty, Information Maximizing GAN, Boundary Equilibrium GAN, Variational AutoEncoder and Variational AutoEncoder GAN. All models have as close as possible nets architectures and implementations with necessary deviations required by their articles.AutoEncoders in Keras: VAE-GAN2017-07-01T00:00:00+00:002017-07-01T00:00:00+00:00https://msurtsukov.github.io/AE6<p>In the previous part, we created a CVAE autoencoder, whose decoder is able to generate a digit of a given label, we also tried to create pictures of numbers of other labels in the style of a given picture. It turned out pretty good, but the numbers were generated blurry.</p>
<p>In the last part, we studied how GANs work and obtained quite sharp images of digits, but the ability to encode and transfer style was lost.</p>
<p>In this part we will try to take the best from both approaches by combining variational autoencoders (VAE) and generative adversarial networks (GAN).</p>
<p>The approach, which will be described later, is based on the article [Autoencoding beyond pixels using a learned similarity metric, Larsen et al, 2016].</p>
<p><img src="/assets/ae/vaegan.png" alt="pde" width="400px" class="align-center" /></p>
<!--more-->
<p>Full Russian text is available <a href="https://habr.com/ru/post/332074/">here</a></p>
<p><a href="https://github.com/msurtsukov/ae_vae_gan">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comIn the previous part, we created a CVAE autoencoder, whose decoder is able to generate a digit of a given label, we also tried to create pictures of numbers of other labels in the style of a given picture. It turned out pretty good, but the numbers were generated blurry. In the last part, we studied how the GANs work, getting quite clear images of numbers, but the possibility of coding and transferring the style was lost. In this part we will try to take the best from both approaches by combining variational autoencoders (VAE) and generative competing networks (GAN). The approach, which will be described later, is based on the article [Autoencoding beyond pixels using a learned similarity metric, Larsen et al, 2016].AutoEncoders in Keras: VAE2017-06-24T00:00:00+00:002017-06-24T00:00:00+00:00https://msurtsukov.github.io/AE3<p>In the last part, we have already discussed what hidden variables are, looked at their distribution, and also understood that it is difficult to generate new objects from the distribution of latent variables in ordinary autoencoders. In order to be able to generate new objects, the space of latent variables must be predictable.</p>
<p>Variational Autoencoders are autoencoders that learn to map objects into a given hidden space and to sample from it. Therefore, variational autoencoders also belong to the family of generative models.</p>
<p><img src="/assets/ae/vae.png" alt="pde" width="400px" class="align-center" />
Illustration from <a href="http://ijdykeman.github.io/ml/2016/12/21/cvae.html">here</a></p>
<!--more-->
<p>Full Russian text is available <a href="https://habr.com/ru/post/331552/">here</a></p>
<p><a href="https://github.com/msurtsukov/ae_vae_gan">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comIn the last part, we have already discussed what hidden variables are, looked at their distribution, and also understood that it is difficult to generate new objects from the distribution of latent variables in ordinary autoencoders. In order to be able to generate new objects, the space of latent variables must be predictable. Variational Autoencoders are autoencoders that learn to map objects into a given hidden space and sample from it. Therefore, variational autoencoders are also referred to the family of generative models. Illustration from hereAutoEncoders in Keras: Introduction2017-06-23T00:00:00+00:002017-06-23T00:00:00+00:00https://msurtsukov.github.io/AE1<p>While diving into Deep Learning, the topic of auto-encoders caught me, especially in terms of generating new objects. In an effort to improve the quality of generation, I read various blogs and literature on the topic of generative approaches. As a result, I decided to reflect the gained experience in a small series of articles, in which I tried briefly and with examples to describe all those problem areas I had encountered myself, while at the same time introducing to Keras.</p>
<p><img src="/assets/ae/ae.png" alt="pde" width="400px" class="align-center" /></p>
<!--more-->
<p>Full Russian text is available <a href="https://habr.com/ru/post/331382/">here</a></p>
<p><a href="https://github.com/msurtsukov/ae_vae_gan">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comWhile diving into Deep Learning, the topic of auto-encoders caught me, especially in terms of generating new objects. In an effort to improve the quality of generation, I read various blogs and literature on the topic of generative approaches. As a result, I decided to reflect the gained experience in a small series of articles, in which I tried briefly and with examples to describe all those problem areas I had encountered myself, while at the same time introducing to Keras.AutoEncoders in Keras: Manifold learning and latent variables2017-06-23T00:00:00+00:002017-06-23T00:00:00+00:00https://msurtsukov.github.io/AE2<p>In order to better understand how autoencoders work, and also to subsequently generate something new from codes, it is worth understanding what codes are and how they can be interpreted.</p>
<p><img src="/assets/ae/lv.gif" alt="pde" width="400px" class="align-center" /></p>
<!--more-->
<p>Full Russian text is available <a href="https://habr.com/ru/post/331500/">here</a></p>
<p><a href="https://github.com/msurtsukov/ae_vae_gan">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comIn order to better understand how autoencoders work, and also to subsequently generate something new from codes, it is worth understanding what codes are and how they can be interpreted.AutoEncoders in Keras: GAN2017-05-30T00:00:00+00:002017-05-30T00:00:00+00:00https://msurtsukov.github.io/AE5<p>With all the advantages of VAE variational autoencoders, which we dealt with in previous posts, they have one major drawback: due to the poor way of comparing original and restored objects, the objects they generated are similar to the objects from the training set, but they are easily distinguishable from them (for example blurred).</p>
<p>This disadvantage is much less pronounced in another approach, namely generative adversarial networks (GANs).</p>
<p>(The real reason VAEs produce blurry images lies in how the likelihood is defined when comparing the original and the reconstructed object. Namely, pixel values are assumed to be independent of each other, so the likelihood factorizes into a product of per-pixel likelihoods. GANs do not make this assumption (they do not define a likelihood at all) and are therefore not restricted by it.)</p>
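<p>To make this assumption explicit (a standard formulation, not a quote from the post): for a binarized image with pixels <script type="math/tex">x_i</script> and decoder outputs <script type="math/tex">\hat{x}_i</script>, the factorized Bernoulli likelihood reads</p>
<script type="math/tex; mode=display">\log p(x|z) = \sum_i \log p(x_i|z) = \sum_i \big( x_i \log \hat{x}_i + (1 - x_i) \log(1 - \hat{x}_i) \big),</script>
<p>which is just the negative binary cross-entropy summed over pixels, a reconstruction loss commonly used in Keras VAEs. Maximizing it pushes every pixel towards its average plausible value independently of its neighbors, which is exactly what makes the reconstructions look blurred.</p>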
<p>Formally, GANs are of course not autoencoders; however, they share similarities with variational autoencoders and will also be needed in the next part, so it is worth getting acquainted with them as well.</p>
<h3 id="gan-in-brief">GAN in brief</h3>
<p>GANs were first proposed in the paper [1, Generative Adversarial Nets, Goodfellow et al., 2014] and are now being actively studied. Most state-of-the-art generative models use adversarial training in one way or another.</p>
<p>GAN scheme:</p>
<p><img src="/assets/ae/gan.png" alt="pde" width="400px" class="align-center" /></p>
<!--more-->
<p>Full Russian text is available <a href="https://habr.com/ru/post/332000/">here</a></p>
<p><a href="https://github.com/msurtsukov/ae_vae_gan">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comWith all the advantages of VAE variational autoencoders, which we dealt with in previous posts, they have one major drawback: due to the poor way of comparing original and restored objects, the objects they generated are similar to the objects from the training set, but they are easily distinguishable from them (for example blurred). This disadvantage is much less pronounced in another approach, namely, generative competing networks - GANs. (The real reason why VAEs produce blurred images is because of the way we define likelihood when comparing original and restored object. Namely, we suppose that pixel values are independent from each other (likelihood factorizes into product of likelihoods for each pixel). GANs don’t make this assumption (because we don’t define the likelihood at all), and thus are not restricted by it.) Formally, GANs, of course, do not belong to autoencoders, however there are similarities between them and variational autoencoders, they will also be useful for the next part. So it will not be superfluous to meet them too. GAN in brief GANs were first proposed in article [1, Generative Adversarial Nets, Goodfellow et al, 2014] and are now being actively studied. Most state-of-the-art generative models one way or another use adversarial. GAN scheme:AutoEncoders in Keras: Conditional VAE2017-05-26T00:00:00+00:002017-05-26T00:00:00+00:00https://msurtsukov.github.io/AE4<p>In the last part, we met variational autoencoders (VAE), implemented one on keras, and also understood how to generate images using it. The resulting model, however, had some drawbacks:</p>
<p>Not all digits were encoded equally well in the latent space: some of them were either missing entirely or very blurry. Between the regions where variants of the same digit were concentrated, there were often meaningless hieroglyph-like shapes.</p>
<p>It was difficult to generate a picture of a given digit: one had to find out which region of the latent space the images of that digit fell into and sample from somewhere in that region. Generating a digit in a given style was harder still.</p>
<p>In this part, we will see how, by only slightly complicating the model, we can overcome both of these problems and at the same time generate pictures of new digits in the style of another digit - this is probably the most interesting feature of the resulting model.</p>
<p><img src="/assets/ae/cvae.png" alt="pde" width="400px" class="align-center" /></p>
<!--more-->
<p>Full Russian text is available <a href="https://habr.com/ru/post/331552/">here</a></p>
<p><a href="https://github.com/msurtsukov/ae_vae_gan">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comIn the last part, we met variational autoencoders (VAE), implemented one on keras, and also understood how to generate images using it. The resulting model, however, had some drawbacks: Not all the numbers turned out to be well encoded in the latent space: some of the numbers were either completely absent or were very blurry. In between the areas in which the variants of the same number were concentrated, there were generally some meaningless hieroglyphs. It was difficult to generate a picture of a given digit. To do this, one had to look into what area of the latent space the images of a specific digit fell into, and to sample it from somewhere there, and even more so it was difficult to generate a digit in some given style. In this part, we will see how only by slightly complicating the model to overcome both these problems, and at the same time we will be able to generate pictures of new numbers in the style of another digit - this is probably the most interesting feature of the future model.Partially Differential Equations in Tensorflow2017-02-13T00:00:00+00:002017-02-13T00:00:00+00:00https://msurtsukov.github.io/PDE<p><img src="/assets/pde/orig.gif" alt="pde" width="400px" class="align-center" /></p>
<p>Inspired by a course on parallel computing at my university, and having just become acquainted with Tensorflow, I wrote this article out of curiosity about applying a deep-learning framework to a problem that has nothing to do with neural networks but is mathematically similar.</p>
<!--more-->
<script type="math/tex; mode=display">\frac{\partial u}{\partial t} = \sum \limits_{\alpha=1}^{2} \frac{\partial}{\partial x_\alpha} \left (k_\alpha \frac{\partial u}{\partial x_\alpha} \right ) -u, \quad x_\alpha \in [0,1] \quad (\alpha=1,2), \ t>0;</script>
<script type="math/tex; mode=display">k_\alpha = \begin{cases} 50, (x_1, x_2) \in \Delta ABC\\ 1, (x_1, x_2) \notin \Delta ABC \end{cases}</script>
<script type="math/tex; mode=display">(\alpha = 1,2), \ A(0.2,0.5), \ B(0.7,0.2), \ C(0.5,0.8);</script>
<script type="math/tex; mode=display">u(x_1, x_2, 0) = 0,\ u(0,x_2,t) = 1 - e^{-\omega t},\ u(1, x_2, t) = 0,</script>
<script type="math/tex; mode=display">u(x_1,0,t) = 1 - e^{-\omega t},\ u(0, x_2, t) = 0,\ \omega = 20.</script>
<p>Full Russian text is available <a href="https://habr.com/ru/post/321734/">here</a></p>
<p><a href="https://github.com/msurtsukov/pde">Repository with jupyter notebook</a></p>Mikhail Surtsukovmsurtsukov@gmail.comInspired by a course on parallel computing in my university and just after got acquainted with Tensorflow, I wrote this article as the result of a curiosity to apply framework for deep learning to the problem that has nothing to do with neural networks, but is mathematically similar.