diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..225c38155869eaa5f928f05de5bb7681d66a20c5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.png
+*.jpg
diff --git a/aires_node/neural_cde_driver.py b/aires_node/neural_cde_driver.py
index b9556375a71d032d62062f64c19409a1f9f86b07..389d9ace1b8ce024bc90a20eedb9cc3d02fa2ce1 100644
--- a/aires_node/neural_cde_driver.py
+++ b/aires_node/neural_cde_driver.py
@@ -41,7 +41,8 @@ class NeuralCDEDriver(object):
         self.model = model
         self.lr = kwargs.get("lr", 5e-4)
         self.optimizer = optim.AdamW(
-            self.model.parameters(), lr=self.lr, weight_decay=1e-5, amsgrad=True)
+            self.model.parameters(), lr=self.lr,
+            weight_decay=kwargs.get("weight_decay", 1e-5), amsgrad=True)
         self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer, 'min', patience=20, cooldown=40,
             min_lr=1e-6, factor=kwargs.get("lr_factor", 0.99))
@@ -103,8 +104,12 @@ class NeuralCDEDriver(object):
         # log(1) is 0
         prior_logvar_z0 = torch.zeros(qz0_logvar.shape)
-        # estimate varience at z0 from network predictions
+        # estimate variance at z0 from network predictions
-        nvk = 30
-        var_hat = torch.zeros((nvk, pred_x.shape[0], pred_x.shape[2]))
+        nvk = 96
+        if hasattr(self.model, 'num_nodes'):
+            assert len(pred_x.shape) == 4
+            var_hat = torch.zeros((nvk, pred_x.shape[0], pred_x.shape[2], pred_x.shape[3]))
+        else:
+            var_hat = torch.zeros((nvk, pred_x.shape[0], pred_x.shape[2]))
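+        # Monte Carlo estimate: draw nvk samples of z0 from the inferred
+        # posterior and propagate each through the model; the spread of the
+        # resulting predictions estimates the predictive variance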
         for vk in range(nvk):
             eps = torch.randn(qz0_mean.size()).to(self.device)
             z0_rand = eps * torch.exp(.5 * qz0_logvar[:, :]) + qz0_mean[:, :]
@@ -120,17 +125,26 @@ class NeuralCDEDriver(object):
         var_hat_nsamples[:, :, :] = var_hat
         # compute elbo loss
         pred_logvar = torch.log(var_hat_nsamples)
+        x0_loc = int(self.z0_nsamples / self.eval_thin_k)
         elbo_loss = custom_elbo_loss(pred_x, _targ_x[:, :, :],
                                      pred_logvar,
                                      qz0_mean, qz0_logvar,
                                      targ_mu_z0=prior_mu_z0,
                                      targ_logvar_z0=prior_logvar_z0,
                                      beta=self.beta_vae)
-        x0_loc = int(self.z0_nsamples / self.eval_thin_k)
         x0_loss = log_normal_pdf(
-                pred_x[:, x0_loc:x0_loc+1, :], _targ_x[:, x0_loc:x0_loc+1, :], pred_logvar[:, x0_loc:x0_loc+1, :])
-        x0_loss = -torch.mean(x0_loss.sum(-1).sum(-1), dim=0)
-        return elbo_loss + self.x0_loss_scale*x0_loss
+                pred_x[:, x0_loc:x0_loc+1, :],
+                _targ_x[:, x0_loc:x0_loc+1, :],
+                pred_logvar[:, x0_loc:x0_loc+1, :])
+        xp_loss = log_normal_pdf(
+                pred_x[:, 0:1, :],
+                _targ_x[:, 0:1, :],
+                pred_logvar[:, 0:1, :])
+        x0_loss = -torch.mean(x0_loss.sum(tuple(range(1, len(_targ_x.shape)))), dim=0)
+        xp_loss = -torch.mean(xp_loss.sum(tuple(range(1, len(_targ_x.shape)))), dim=0)
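+        # Anchor the forecast to its initial conditions: penalize prediction
+        # error at the forecast start (x0, index x0_loc) and at the first time
+        # sample in the window (xp, index 0).  x0_weight rescales these
+        # single-point terms by the number of time samples so they are not
+        # dwarfed by the time-summed ELBO term.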
+        x0_weight = float(_targ_x.shape[1])
+        return elbo_loss + self.x0_loss_scale*x0_weight*(0.5*x0_loss+0.5*xp_loss)
 
     def _ncde_loss(self, pred_x: torch.Tensor, targ_x: torch.Tensor, eval_idx=None):
         """
@@ -274,6 +288,8 @@ class NeuralCDEDriver(object):
             # Clean interpolant before saving or loading the model
             if hasattr(self.model, "zh_cde"):
                 self.model.zh_cde.Xf_t = None
+            if isinstance(self.model, NeuralCDE):
+                self.model.func.Xf_t = None
 
             if itr % self.checkpoint_epoch == 0:
                 # save model for rollback to best state
diff --git a/aires_node/neural_cde_encoders.py b/aires_node/neural_cde_encoders.py
index daa62acb58065c2751d560df14e5cedef5c78776..49c296ce14231a753dc86754ca63700680f48963 100644
--- a/aires_node/neural_cde_encoders.py
+++ b/aires_node/neural_cde_encoders.py
@@ -99,14 +99,16 @@ class RecognitionGRU(torch.nn.Module):
         )
         self.fc = torch.nn.Linear(nhidden, latent_dim * v_mult)
 
-    def forward(self, x, hidden_state):
+    def forward(self, x, hidden_state=None):
+        output, hidden_state = self.gru(x)
         # NOTE: this does not need to be manually unrolled, x can have
         # time as an axis.  The GRU forward() call returns the next
         # predicted output in the sequence.
         # hidden_state = self.init_hidden()
         # x should have shape (n_batch, n_times, n_features)
         # hidden_state should have (num_layers, n_batch, nhidden)
-        output, hidden_state = self.gru(x, hidden_state)
-        # output has size n_batch, n_times, n_
+        # output has shape (n_batch, n_times, nhidden)
         output = self.fc(output[:, -1, :])
         return output
@@ -143,6 +145,9 @@ class RecognitionFCN(torch.nn.Module):
             v_mult = 2
         self.final_tanh = final_tanh
         self.latent_dim = latent_dim
+        self.z0_nsamples = z0_nsamples
+        self.encoder_input_channels = encoder_input_channels
+        self.input_size = z0_nsamples * encoder_input_channels
         encoder_layer_list = []
         if encoder_depth < 0:
-            # zero-depth layer where initial latent state is simply the inital
+            # zero-depth layer where initial latent state is simply the initial
@@ -188,7 +193,7 @@ def initial_state_encoder_factory(
         return InitialLatentStateEncoderRNN(
                 input_channels, output_channels, z0_nsamples,
                 latent_dim, is_vae, num_nodes, **kwargs)
-    elif z0_fn_control == 0:
+    elif z0_fn_control >= 0:
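+        # all nonnegative control modes share the fully connected encoder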
         return InitialLatentStateEncoderFCN(
                 input_channels, output_channels, z0_nsamples,
                 latent_dim, is_vae, num_nodes, **kwargs)
@@ -212,17 +217,18 @@ class InitialLatentStateEncoderRNN(torch.nn.Module):
         self.latent_dim = latent_dim
         self.is_vae_model = is_vae
         self.num_nodes = num_nodes
-        self.encoder_type = encoder_kwargs.get("encoder_type", "gru")
-        encoder_hidden_width = encoder_kwargs.get("encoder_hidden_width", 32)
+        self.obs_dim = encoder_input_channels * max(num_nodes, 1)
+        self.encoder_type = kwargs.get("encoder_type", "gru")
+        encoder_hidden_width = kwargs.get("encoder_hidden_width", 32)
         assert z0_nsamples >= 3
         if self.encoder_type == "gru":
             self.initial = RecognitionGRU(
-                    latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
-                    nhidden=encoder_hidden_width, nbatch=self.nbatch)
+                    latent_dim=self.latent_dim * max(num_nodes, 1), obs_dim=encoder_input_channels * max(num_nodes, 1),
+                    nhidden=encoder_hidden_width, nbatch=1, estimate_var=is_vae)
         else:
             self.initial = RecognitionRNN(
-                    latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
-                    nhidden=encoder_hidden_width, nbatch=self.nbatch)
+                    latent_dim=self.latent_dim, obs_dim=encoder_input_channels * max(num_nodes, 1),
+                    nhidden=encoder_hidden_width, nbatch=1, estimate_var=is_vae)
 
     def __call__(self, y0, X_f, times, **kwargs):
         """
@@ -235,6 +241,7 @@ class InitialLatentStateEncoderRNN(torch.nn.Module):
             times: Times at which to evaluate the model.
         """
         device = y0.device
+        batch_dim = y0.shape[0]
         qz0_mean, qz0_logvar = torch.Tensor([]), torch.Tensor([])
         assert len(y0.shape) >= 3
         local_times = times - times[0]
@@ -242,21 +249,34 @@ class InitialLatentStateEncoderRNN(torch.nn.Module):
         # X0 = X_f.evaluate(local_times[0:self.z0_nsamples+1])
         # samp_trajs = torch.cat((X0[:,:-1,1:], y0), -1)
         X0 = Xf_evaluate_reshape(X_f, local_times[0:self.z0_nsamples], n_nodes=self.num_nodes)[..., 1:]
-        samp_trajs = torch.cat((X0, y0), -1)[:]
 
-        batch_dim = y0.shape[0]
-        h = self.initial.init_hidden(batch_dim)
+        # For a GRU, requires samp_trajs to have shape [n_batch, n_times, n_features]
+        samp_trajs = torch.cat((X0, y0[:, 0:self.z0_nsamples]), -1)[:, 0:self.z0_nsamples, :]
+        if self.num_nodes > 0:
+            assert samp_trajs.shape[2] == self.num_nodes
+            # flatten the features on each node into single large vector
+            samp_trajs = samp_trajs.flatten(-2, -1)
 
         # Roll time backwards, infer initial latent state from known samples
         if self.encoder_type == "gru":
             samp_trajs_rev = torch.flip(samp_trajs, dims=(1,))
-            out = self.initial.forward(samp_trajs_rev, h)
+            out = self.initial.forward(samp_trajs_rev)
         else:
+            # initialize hidden state for unrolled RNN
+            h = self.initial.init_hidden(batch_dim)
             for t in reversed(range(samp_trajs.shape[1])):
                 # batch, time, nfeatures
                 obs = samp_trajs[:, t, :]
                 out, h = self.initial.forward(obs, h)
 
+        # In graph cde, out needs shape [nbatch, n_nodes, n_latent]
+        if self.num_nodes > 0:
+            # pytorch GRU does not have a node dimension, must expand output shape
+            # to [nbatch, n_nodes, n_latent]
+            out = out.reshape((batch_dim, self.num_nodes, -1))
+            assert out.shape[1] == self.num_nodes
+
         qz0_mean, qz0_logvar = out[..., :self.latent_dim], out[..., self.latent_dim:]
         if self.is_vae_model:
             # random sampling initial latent state while training
@@ -311,10 +331,23 @@ class InitialLatentStateEncoderFCN(torch.nn.Module):
         assert len(y0.shape) >= 3
         local_times = times - times[0]
         X0 = Xf_evaluate_reshape(X_f, local_times[0:self.z0_nsamples], n_nodes=self.num_nodes)[..., 1:]
+
-        # FIXME: takes only the first time, meaning only z0_nsamples==1 is supported at this time
-        aug_X0 = torch.cat((X0, y0), -1)[:, 0, :]
+        aug_X0 = torch.cat((X0, y0[:, 0:self.z0_nsamples]), -1)[:, 0:self.z0_nsamples, :]
+        if self.num_nodes > 0:
+            assert aug_X0.shape[2] == self.num_nodes
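+            # move the node dim ahead of time, then flatten each node's
+            # (n_times, n_features) window into a single feature vector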
+            aug_X0 = aug_X0.permute((0, 2, 1, 3))
+            aug_X0 = aug_X0.flatten(2, -1)
+        else:
+            aug_X0 = aug_X0.flatten(-2, -1)
         out = self.initial(aug_X0)
 
+        # In graph cde, out needs shape [nbatch, n_nodes, n_latent]
+        if self.num_nodes > 0:
+            assert out.shape[1] == self.num_nodes
+
         qz0_mean, qz0_logvar = out[..., :self.latent_dim], out[..., self.latent_dim:]
         if self.is_vae_model:
             # random sampling initial latent state while training
diff --git a/aires_node/neural_cde_models.py b/aires_node/neural_cde_models.py
index c3a3841ed281237143edce34937fd0eae4ac9c35..9e72ac00a377aaeda08479008d2b03f710cd4831 100644
--- a/aires_node/neural_cde_models.py
+++ b/aires_node/neural_cde_models.py
@@ -40,7 +40,7 @@ class CDEFuncTemporal(torch.nn.Module):
     def __call__(self, t, h):
         # control_gradient is of shape (..., input_channels)
         dX_dt = self.Xf_t.derivative(t)
-        vector_field_f = self.func_f(h)
+        vector_field_f = self.func_f(t, h)
         out = (vector_field_f @ dX_dt.unsqueeze(-1)).squeeze(-1)
         return out
 
@@ -217,14 +217,11 @@ class NeuralCDE(torch.nn.Module):
         cdeint_kwargs: arguments passed to the cde integrator.
             Eg: 'atol' and 'rtol' control desired integration accuracy.
         z0_fn_control: latent variable initial value mode.
-            0 - uses only inital system variables to estimate
-                inital latent variable values at t0. (default)
-            1 - uses only control variable initial values to estimate
-                inital latent variable values at t0.
-            2 - uses both system variables and control variable inital values
-                to estimate inital latent variable values at t0.
+            0 - Uses fully connected network to estimate initial latent state.
+            -1 - Uses RNN model. May have issue at longer lookback windows.
+            -2 - Uses GRU model. Typically better than the RNN model.
-        z0_nsamples: number of observations used to infer the inital
-            values of the latent vars
+        z0_nsamples: number of observations used to infer the initial
+            values of the latent vars.  This is similar to lookback in other models.
         encoder_kwargs: arguments passed to the encoder
         decoder_kwargs: arguments passed to the decoder
     """
@@ -253,51 +250,24 @@ class NeuralCDE(torch.nn.Module):
         self.latent_dim = hidden_channels
         self.verbose = verbose
 
-        # Netowrk that represents the ODE in the latent space
-        self.func = CDEFunc(
+        # Network that represents the dynamics in the latent space
+        self.func = CDEFuncTemporal(CDEFunc(
                 input_channels, hidden_channels, hidden_width,
-                hidden_depth, **cde_kwargs)
-
-        # Network that maps from input space, X to latent sapace, z
-        # self.initial = torch.nn.Linear(output_channels, hidden_channels)
-        encoder_layer_list = []
-        encoder_depth = encoder_kwargs.get("encoder_depth", 0)
-        encoder_hidden_width = encoder_kwargs.get("encoder_hidden_width", 32)
-        if self.z0_fn_control in (1, -1):
-            encoder_input_channels = input_channels-1
-        elif self.z0_fn_control in (2,):
-            encoder_input_channels = input_channels-1 + output_channels
-        elif self.z0_fn_control == -2:
-            # remove time as feature channel in encoder network
-            # encoder_input_channels = input_channels + output_channels
-            encoder_input_channels = input_channels-1 + output_channels
-        else:
-            encoder_input_channels = output_channels
-        if encoder_depth < 0:
-            encoder_layer_list.append(Pad1D(z0_nsamples * encoder_input_channels, self.latent_dim))
-        else:
-            encoder_layer_list.append(torch.nn.Linear(z0_nsamples * encoder_input_channels, hidden_channels))
-            for i in range(int(encoder_depth)):
-                encoder_layer_list.append(torch.nn.Tanh())
-                encoder_layer_list.append(torch.nn.Linear(hidden_channels, self.latent_dim))
+                hidden_depth, **cde_kwargs))
 
-        # Init RNN encoder model
-        self.encoder_type = encoder_kwargs.get("encoder_type", "gru")
-        if self.z0_fn_control < 0:
-            assert z0_nsamples >= 3
-            if self.encoder_type == "gru":
-                self.initial = RecognitionGRU(
-                        latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
-                        nhidden=encoder_hidden_width, nbatch=self.nbatch)
-            else:
-                self.initial = RecognitionRNN(
-                        latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
-                        nhidden=encoder_hidden_width, nbatch=self.nbatch)
-        else:
-            self.initial = torch.nn.Sequential(*encoder_layer_list)
+        # Initial latent value estimation network
+        self.initial = initial_state_encoder_factory(
+                 input_channels,
+                 output_channels,
+                 z0_fn_control,
+                 z0_nsamples,
+                 self.latent_dim,
+                 is_vae=self.is_vae_model,
+                 num_nodes=0,
+                 **encoder_kwargs
+                )
 
-        # Network that maps from latent sapace, z to data space (output)
+        # Network that maps from latent space, z, to data space (output)
-        # self.readout = torch.nn.Linear(hidden_channels, output_channels)
         decoder_layer_list = []
         decoder_layer_list.append(torch.nn.Linear(self.latent_dim, output_channels))
         for i in range(int(decoder_kwargs.get("decoder_depth", 0))):
@@ -305,144 +275,52 @@ class NeuralCDE(torch.nn.Module):
             decoder_layer_list.append(torch.nn.Linear(output_channels, output_channels))
         self.readout = torch.nn.Sequential(*decoder_layer_list)
 
-        pytorch_total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
-        print("=== Number of trainable parameters in NeuralCDE: %d" % int(pytorch_total_params))
-
-        # init network weights and biases
-        if cde_kwargs.get("init_std", True):
-            for m in self.func.cde_nn.modules():
-                if isinstance(m, torch.nn.Linear):
-                    torch.nn.init.normal_(
-                            m.weight, mean=0., std=cde_kwargs.get("w_std", 0.01))
-                    torch.nn.init.normal_(
-                            m.bias, mean=0., std=cde_kwargs.get("b_std", 0.001))
-
-    def forward(self, z0, X_f, times, eval_idx=None):
+    def forward(self, y0, X_f, times, eval_idx=None):
         """
         Forward pass through the NCDE.  Integrates the NCDE forward in time.
-        Computes and returns results are input times.
+        Computes and returns results at the desired times.
 
         Args:
-            z0: prior observations dims - (n_batch, n_time, n_features)
+            y0: prior observations, dims (n_batch, n_time, n_features)
             X_f: :module:`torchcde.LinearInterpolation` callable
                 interpolation function which implements an evaluate method.
-            times: Times at which to evaluate the model.
+            times: Desired times at which to evaluate the model.
             eval_idx: Only evaluate at marked times.
         """
-        # Initial hidden state should be a function of the first
-        # and/or prior observation(s).
         qz0_mean, qz0_logvar = torch.Tensor([]), torch.Tensor([])
-        assert len(z0.shape) == 3
-        assert z0.shape[1] == self.z0_nsamples
+        assert len(y0.shape) == 3
+        assert y0.shape[1] == self.z0_nsamples
         local_times = times - times[0]
         assert isinstance(local_times, torch.Tensor)
-        if self.z0_fn_control == 0:
-            z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
-            z0_0 = self.initial(z0_r)
-        elif self.z0_fn_control == 1:
-            # In the torchcde examples, z0 is a funciton of X;
-            # however, the inital state of the latent vars, z
-            # may also be a function of the dependant system observables.
-            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
-            z0_0 = self.initial(X0)
-        elif self.z0_fn_control == 2:
-            # X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
-            # z_r shape = (batch dim, n_time*n_features)
-            # z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
-            # z0_a = torch.cat((X0, z0_r), 1)
-            # z0_0 = self.initial(z0_a)
-            # Use both control and sys var initial condition
-            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, :, 1:]
-            aug_X0 = torch.cat((X0, z0), 2)[:, 0, :]
-            z0_0 = self.initial(aug_X0)
-        elif self.z0_fn_control < 0:
-            # Use RNN encoder to build estimate for hidden intial conds, z0
-            # Requres z0 to be of length 3 or greater!
-            # in this case z0 is sequence of observed samples prior to
-            # forward forecast.
-            # backward in time to infer q(z_0)
-            if self.z0_fn_control == -2:
-                # concat along feature dim
-                # the 0th feature in X_f.evaluate is reserved for time-as-a-feature
-                X0 = X_f.evaluate(local_times[0:self.z0_nsamples+1])
-                # dt_trajs = X0[:,:-1,:] - X0[:,1:,0]
-                samp_trajs = torch.cat((X0[:,:-1,1:], z0), 2)
-                # samp_trajs = torch.cat((X0[:,:-1,1:], z0, dt_trajs), 2)
-            else:
-                samp_trajs = z0
-
-            # Initialize RNN encoder hidden state
-            batch_dim = z0.shape[0]
-            h = self.initial.init_hidden(batch_dim)
-
-            # Roll time backwards, infer initial latent state from known samples
-            if self.encoder_type == "gru":
-                samp_trajs_rev = torch.flip(samp_trajs, dims=(1,))
-                out = self.initial.forward(samp_trajs_rev, h)
-            else:
-                for t in reversed(range(samp_trajs.shape[1])):
-                    # batch, time, nfeatures
-                    obs = samp_trajs[:, t, :]
-                    out, h = self.initial.forward(obs, h)
-
-            qz0_mean, qz0_logvar = out[:, :self.latent_dim], out[:, self.latent_dim:]
-            device = z0.device
-            if self.training and self.is_vae_model:
-                # random sampling initial latent state while training
-                epsilon = torch.randn(qz0_mean.size()).to(device)
-            else:
-                # frozen mean initial latent state while testing/eval
-                # print("====== qz0_mean: %s " % (qz0_mean))
-                # print("====== qz0_sd: %s" % (torch.sqrt(torch.exp(qz0_logvar))) )
-                epsilon = torch.zeros(qz0_mean.size()).to(device)
-            # VAE sample initial conditions of latent var
-            z0_0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
-        else:
-            raise NotImplementedError("Invalid z0 Estimation Method")
+        self.func.Xf_t = X_f
 
-        # z0_0 has shape (n_batch, self.latent_dim)
-        assert z0_0.shape[1] == self.latent_dim
-
-        # for logging only
-        if self.verbose:
-            t_range = X_f.interval
-            print("X_f.interval: ", t_range,  "times range: ",
-                  local_times[0], local_times[-1])
-            X_f_vals = X_f.evaluate(local_times)
-            X_f_times = X_f_vals[:,:,0]
-            X_f_vars = X_f_vals[:,:,1:]
-            print('Min control time Xf_t', torch.min(X_f_times),
-                  'Max control time Xf_t', torch.max(X_f_times))
-            print('Min control vars Xf', torch.min(X_f_vars),
-                  'Max control vars Xf', torch.max(X_f_vars))
+        # Predict the initial latent state, z0
+        z0, qz0_mean, qz0_logvar = self.initial(y0, X_f, times)
+        assert z0.shape[1] == self.latent_dim
 
         # Only output at selected times.
-        # This allows adaptive stepping to be more efficient since the
-        # ODE integration does not need to stop at every support
-        # point to eval results.
         if eval_idx is not None:
             assert len(eval_idx) <= len(local_times)
             selected_times = local_times[eval_idx]
         else:
             selected_times = local_times
 
-        # Solve the CDE.
-        try:
-            z_T = torchcde.cdeint(X=X_f,
-                                  z0=z0_0,
-                                  func=self.func,
-                                  t=selected_times,
-                                  rtol=self.cdeint_kwargs.get("rtol", 1e-5),
-                                  atol=self.cdeint_kwargs.get("atol", 1e-6),
-                                  backend='torchdiffeq',
-                                  adjoint=self.cdeint_kwargs.get("adjoint", False))
-        except:
-            # z_T has shape (n_batches, time, latent_dim)
-            z_T = torch.zeros(
-                    (z0_0.shape[0], len(selected_times), self.latent_dim))
-            print("WARNING: cdeint failed")
+        # Integrate the NCDE to obtain latent state, z(t), trajectories
+        if self.cdeint_kwargs.get("adjoint", False):
+            z_T = odeint_adjoint(func=self.func, y0=z0, t=selected_times,
+                         rtol=self.cdeint_kwargs.get("rtol", 1e-5),
+                         atol=self.cdeint_kwargs.get("atol", 1e-6),
+                         adjoint_rtol=self.cdeint_kwargs.get("adjoint_rtol", 1e-3),
+                         adjoint_atol=self.cdeint_kwargs.get("adjoint_atol", 1e-4),
+                         method=self.cdeint_kwargs.get("method", 'dopri5'))
+        else:
+            z_T = odeint(func=self.func, y0=z0, t=selected_times,
+                         rtol=self.cdeint_kwargs.get("rtol", 1e-5),
+                         atol=self.cdeint_kwargs.get("atol", 1e-6),
+                         method=self.cdeint_kwargs.get("method", 'dopri5'))
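+        # NOTE: torchdiffeq's odeint returns z_T with time as the leading
+        # dim, shape (n_times, n_batch, latent_dim); permuted below to
+        # (n_batch, n_times, latent_dim) before decoding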
 
         # Decode from latent space to data space
+        z_T = torch.permute(z_T, (1, 0, 2))
         pred_y = self.readout(z_T)
         return pred_y, qz0_mean, qz0_logvar
 
@@ -450,12 +328,14 @@ class NeuralCDE(torch.nn.Module):
         """
         Save model state to file
         """
+        self.func.Xf_t = None
         torch.save(self.state_dict(), file_path)
 
     def load_model(self, file_path: str):
         """
         Load model from file
         """
+        self.func.Xf_t = None
         state_dict = torch.load(file_path)
         self.model.load_state_dict(state_dict)
 
@@ -543,7 +423,7 @@ class NeuralSTG_CDE(torch.nn.Module):
                  h0_fn_control,
                  z0_nsamples,
                  latent_dim,
-                 is_vae=is_vae_model,
+                 is_vae=False,
                  num_nodes=self.num_nodes,
                  **encoder_kwargs
                 )
@@ -600,10 +480,12 @@ class NeuralSTG_CDE(torch.nn.Module):
 
         # estimate initial latent states
         assert self.num_nodes > 0
-        h0, _, _ = self.h0_est(
+        h0, qh0_mean, qh0_logvar = self.h0_est(
                 y0, X_f, times, n_nodes=self.num_nodes)
         z0, qz0_mean, qz0_logvar = self.z0_est(
                 y0, X_f, times, n_nodes=self.num_nodes)
+
+        # If VAE, z0, h0 are samples from normal dist
         init_zh0 = (z0, h0)
 
         # Only output at selected times.
@@ -654,3 +536,232 @@ class NeuralSTG_CDE(torch.nn.Module):
         self.zh_cde.Xf_t = None
         state_dict = torch.load(file_path)
         self.model.load_state_dict(state_dict)
+
+    def _acg_adj_w(self):
+        """
+        Adaptive graph convolution (AGC) adjacency matrix and weights.
+        For debug and visualization only.
+        """
+        return self.zh_cde.func_g.normalized_agc_adjacency(), self.zh_cde.func_g.agc_weights()
+
+
+class NeuralCDE_deprecated(torch.nn.Module):
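+    """
+    Deprecated NeuralCDE implementation retained for reference and
+    comparison.  Superseded by the refactored NeuralCDE above, which
+    delegates initial latent state estimation to the encoder factory.
+    """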
+    def __init__(self,
+                 input_channels,
+                 hidden_channels,
+                 output_channels,
+                 hidden_width=100,
+                 hidden_depth=1,
+                 z0_fn_control=0,
+                 z0_nsamples=1,
+                 n_batches=1,
+                 cdeint_kwargs={},
+                 cde_kwargs={},
+                 encoder_kwargs={},
+                 decoder_kwargs={},
+                 is_vae_model=False,
+                 verbose=False,
+                 ):
+        super(NeuralCDE_deprecated, self).__init__()
+        self.z0_fn_control = z0_fn_control
+        self.cdeint_kwargs = cdeint_kwargs
+        self.nbatch = n_batches
+        self.z0_nsamples = z0_nsamples
+        self.is_vae_model = is_vae_model
+        self.latent_dim = hidden_channels
+        self.verbose = verbose
+
+        # Network that represents the ODE in the latent space
+        self.func = CDEFunc(
+                input_channels, hidden_channels, hidden_width,
+                hidden_depth, **cde_kwargs)
+
+        # Network that maps from input space, X, to latent space, z
+        # self.initial = torch.nn.Linear(output_channels, hidden_channels)
+        encoder_layer_list = []
+        encoder_depth = encoder_kwargs.get("encoder_depth", 0)
+        encoder_hidden_width = encoder_kwargs.get("encoder_hidden_width", 32)
+        if self.z0_fn_control in (1, -1):
+            encoder_input_channels = input_channels-1
+        elif self.z0_fn_control in (2,):
+            encoder_input_channels = input_channels-1 + output_channels
+        elif self.z0_fn_control == -2:
+            # remove time as feature channel in encoder network
+            # encoder_input_channels = input_channels + output_channels
+            encoder_input_channels = input_channels-1 + output_channels
+        else:
+            encoder_input_channels = output_channels
+        if encoder_depth < 0:
+            encoder_layer_list.append(Pad1D(z0_nsamples * encoder_input_channels, self.latent_dim))
+        else:
+            encoder_layer_list.append(torch.nn.Linear(z0_nsamples * encoder_input_channels, hidden_channels))
+            for i in range(int(encoder_depth)):
+                encoder_layer_list.append(torch.nn.Tanh())
+                encoder_layer_list.append(torch.nn.Linear(hidden_channels, self.latent_dim))
+
+        # Init RNN encoder model
+        self.encoder_type = encoder_kwargs.get("encoder_type", "gru")
+        if self.z0_fn_control < 0:
+            assert z0_nsamples >= 3
+            if self.encoder_type == "gru":
+                self.initial = RecognitionGRU(
+                        latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
+                        nhidden=encoder_hidden_width, nbatch=self.nbatch)
+            else:
+                self.initial = RecognitionRNN(
+                        latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
+                        nhidden=encoder_hidden_width, nbatch=self.nbatch)
+        else:
+            self.initial = torch.nn.Sequential(*encoder_layer_list)
+
+        # Network that maps from latent space, z, to data space (output)
+        # self.readout = torch.nn.Linear(hidden_channels, output_channels)
+        decoder_layer_list = []
+        decoder_layer_list.append(torch.nn.Linear(self.latent_dim, output_channels))
+        for i in range(int(decoder_kwargs.get("decoder_depth", 0))):
+            decoder_layer_list.append(torch.nn.Tanh())
+            decoder_layer_list.append(torch.nn.Linear(output_channels, output_channels))
+        self.readout = torch.nn.Sequential(*decoder_layer_list)
+
+        pytorch_total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print("=== Number of trainable parameters in NeuralCDE: %d" % int(pytorch_total_params))
+
+        # init network weights and biases
+        if cde_kwargs.get("init_std", True):
+            for m in self.func.cde_nn.modules():
+                if isinstance(m, torch.nn.Linear):
+                    torch.nn.init.normal_(
+                            m.weight, mean=0., std=cde_kwargs.get("w_std", 0.01))
+                    torch.nn.init.normal_(
+                            m.bias, mean=0., std=cde_kwargs.get("b_std", 0.001))
+
+    def forward(self, z0, X_f, times, eval_idx=None):
+        # Initial hidden state should be a function of the first
+        # and/or prior observation(s).
+        qz0_mean, qz0_logvar = torch.Tensor([]), torch.Tensor([])
+        assert len(z0.shape) == 3
+        assert z0.shape[1] == self.z0_nsamples
+        local_times = times - times[0]
+        assert isinstance(local_times, torch.Tensor)
+        if self.z0_fn_control == 0:
+            z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
+            z0_0 = self.initial(z0_r)
+        elif self.z0_fn_control == 1:
+            # In the torchcde examples, z0 is a function of X;
+            # however, the initial state of the latent vars, z,
+            # may also be a function of the dependent system observables.
+            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
+            z0_0 = self.initial(X0)
+        elif self.z0_fn_control == 2:
+            # X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
+            # z_r shape = (batch dim, n_time*n_features)
+            # z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
+            # z0_a = torch.cat((X0, z0_r), 1)
+            # z0_0 = self.initial(z0_a)
+            # Use both control and sys var initial condition
+            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, :, 1:]
+            aug_X0 = torch.cat((X0, z0), 2)[:, 0, :]
+            z0_0 = self.initial(aug_X0)
+        elif self.z0_fn_control < 0:
+            # Use RNN encoder to build estimate for hidden initial conds, z0.
+            # Requires z0 to be of length 3 or greater!  In this case z0 is
+            # the sequence of observed samples prior to the forward forecast;
+            # roll backward in time to infer q(z_0)
+            if self.z0_fn_control == -2:
+                # concat along feature dim
+                # the 0th feature in X_f.evaluate is reserved for time-as-a-feature
+                X0 = X_f.evaluate(local_times[0:self.z0_nsamples+1])
+                # dt_trajs = X0[:,:-1,:] - X0[:,1:,0]
+                samp_trajs = torch.cat((X0[:,:-1,1:], z0), 2)
+                # samp_trajs = torch.cat((X0[:,:-1,1:], z0, dt_trajs), 2)
+            else:
+                samp_trajs = z0
+
+            # Initialize RNN encoder hidden state
+            batch_dim = z0.shape[0]
+            h = self.initial.init_hidden(batch_dim)
+
+            # Roll time backwards, infer initial latent state from known samples
+            if self.encoder_type == "gru":
+                samp_trajs_rev = torch.flip(samp_trajs, dims=(1,))
+                out = self.initial.forward(samp_trajs_rev, h)
+            else:
+                for t in reversed(range(samp_trajs.shape[1])):
+                    # batch, time, nfeatures
+                    obs = samp_trajs[:, t, :]
+                    out, h = self.initial.forward(obs, h)
+
+            qz0_mean, qz0_logvar = out[:, :self.latent_dim], out[:, self.latent_dim:]
+            device = z0.device
+            if self.training and self.is_vae_model:
+                # random sampling initial latent state while training
+                epsilon = torch.randn(qz0_mean.size()).to(device)
+            else:
+                # frozen mean initial latent state while testing/eval
+                # print("====== qz0_mean: %s " % (qz0_mean))
+                # print("====== qz0_sd: %s" % (torch.sqrt(torch.exp(qz0_logvar))) )
+                epsilon = torch.zeros(qz0_mean.size()).to(device)
+            # VAE sample initial conditions of latent var
+            z0_0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
+        else:
+            raise NotImplementedError("Invalid z0 Estimation Method")
+
+        # z0_0 has shape (n_batch, self.latent_dim)
+        assert z0_0.shape[1] == self.latent_dim
+
+        # for logging only
+        if self.verbose:
+            t_range = X_f.interval
+            print("X_f.interval: ", t_range,  "times range: ",
+                  local_times[0], local_times[-1])
+            X_f_vals = X_f.evaluate(local_times)
+            X_f_times = X_f_vals[:,:,0]
+            X_f_vars = X_f_vals[:,:,1:]
+            print('Min control time Xf_t', torch.min(X_f_times),
+                  'Max control time Xf_t', torch.max(X_f_times))
+            print('Min control vars Xf', torch.min(X_f_vars),
+                  'Max control vars Xf', torch.max(X_f_vars))
+
+        # Only output at selected times.
+        # This allows adaptive stepping to be more efficient since the
+        # ODE integration does not need to stop at every support
+        # point to eval results.
+        if eval_idx is not None:
+            assert len(eval_idx) <= len(local_times)
+            selected_times = local_times[eval_idx]
+        else:
+            selected_times = local_times
+
+        # Solve the CDE.
+        try:
+            z_T = torchcde.cdeint(X=X_f,
+                                  z0=z0_0,
+                                  func=self.func,
+                                  t=selected_times,
+                                  rtol=self.cdeint_kwargs.get("rtol", 1e-5),
+                                  atol=self.cdeint_kwargs.get("atol", 1e-6),
+                                  backend='torchdiffeq',
+                                  adjoint=self.cdeint_kwargs.get("adjoint", False))
+        except:
+            # z_T has shape (n_batches, time, latent_dim)
+            z_T = torch.zeros(
+                    (z0_0.shape[0], len(selected_times), self.latent_dim))
+            print("WARNING: cdeint failed")
+
+        # Decode from latent space to data space
+        pred_y = self.readout(z_T)
+        return pred_y, qz0_mean, qz0_logvar
+
+    def save_model(self, file_path: str):
+        """
+        Save model state to file
+        """
+        torch.save(self.state_dict(), file_path)
+
+    def load_model(self, file_path: str):
+        """
+        Load model from file
+        """
+        state_dict = torch.load(file_path)
+        self.model.load_state_dict(state_dict)
diff --git a/aires_node/neural_cde_vecfields.py b/aires_node/neural_cde_vecfields.py
index 2dc78428a9550f31d82967dabf443485b26a6298..ae39be94cfcd496c1e25fdb5d6e1523c5f3d82c4 100644
--- a/aires_node/neural_cde_vecfields.py
+++ b/aires_node/neural_cde_vecfields.py
@@ -121,6 +121,9 @@ class VectorFieldAGC_g(torch.nn.Module):
         Choi, J. et al. Graph Neural Controlled Differential Equations
         for Traffic Forecasting.
 
+        Wu, Z. et al. Graph WaveNet for Deep Spatial-Temporal Graph Modeling
+        https://arxiv.org/abs/1906.00121
+
     Args:
         input_channels: number of input features
-        hidden_channels: number of input features
+        hidden_channels: number of hidden (latent) channels
@@ -171,6 +174,12 @@ class VectorFieldAGC_g(torch.nn.Module):
         z = z.tanh()
         return z
 
+    def normalized_agc_adjacency(self):
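+        # adaptive adjacency from learned node embeddings, E:
+        # A = softmax(relu(E @ E^T)), row-normalized over neighbors
+        # (cf. the Graph WaveNet reference above)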
+        return F.softmax(F.relu(torch.mm(self.node_embeddings, self.node_embeddings.transpose(0, 1))), dim=1)
+
+    def agc_weights(self):
+        return self.weights_pool
+
     def agc(self, z):
         """
         Adaptive Graph Convolution
diff --git a/aires_node/utils_cde.py b/aires_node/utils_cde.py
index 2493c76b7e84f1340fbb113d48bbb1440830b007..39951276227dd175da0f4c6dedbd642c932bbad8 100644
--- a/aires_node/utils_cde.py
+++ b/aires_node/utils_cde.py
@@ -172,7 +172,10 @@ def predict_ncde(nn_cde, test_t_s, test_z0_s, test_U_s, n_trajectories=30, devic
         z0_test = torch.unsqueeze(z0_test, 0)
     if z0_test.shape[0] < n_trajectories:
         # duplicate along batch dim
-        z0_test = z0_test.expand(n_trajectories, -1, -1)
+        if len(z0_test.shape) == 3:
+            z0_test = z0_test.expand(n_trajectories, -1, -1)
+        else:
+            # graph data carries an extra node dimension
+            z0_test = z0_test.expand(n_trajectories, -1, -1, -1)
 
 #     if len(test_U_s.shape) == 3:
 #         # ensure batch dim matches
diff --git a/aires_node/utils_common.py b/aires_node/utils_common.py
index 4b7830a24e0bb5111669270d58da8305cc6110ad..1ef1cd91734cd0e87e2ba787d77eddb1a4666fdb 100644
--- a/aires_node/utils_common.py
+++ b/aires_node/utils_common.py
@@ -2,6 +2,8 @@
 Common utilities shared between multiple modules
 """
 import numpy as np
+from sklearn.preprocessing import StandardScaler, MinMaxScaler
+import matplotlib.pyplot as plt
 import torch
 import numba
 
@@ -363,3 +365,84 @@ def n_trainable_params(model):
     print("=== Number of trainable parameters in NN: %d" % \
           int(pytorch_total_params))
     return pytorch_total_params
+
+
+def plot_loss_curve(loss_history, **kwargs):
+    """
+    Plot train/test loss curves
+    """
+    loss_hist = np.asarray(loss_history)
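+    # loss_history columns used below: 0 = epoch, 2 = train loss, 3 = test loss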
+    yl = kwargs.get("ylabel", "Loss")
+    fig_f = kwargs.get("fig_f_train", "loss_train_ncde.png")
+    curriculum_epochs = kwargs.get("curriculum_epochs", [])
+    # Plot training loss curve
+    plt.figure()
+    plt.plot(loss_hist[:, 0], loss_hist[:, 2], label="Train Loss")
+    plt.legend()
+    plt.xlabel("Epoch")
+    plt.ylabel(yl)
+    plt.grid(ls='--', alpha=0.5)
+    if len(curriculum_epochs) > 0:
+        for ce in curriculum_epochs:
+            plt.axvline(x=ce, c='k', alpha=0.8)
+    plt.savefig(fig_f, dpi=180)
+    plt.close()
+
+    # Plot test loss curve if data is available
+    if loss_hist.shape[1] > 3 and kwargs.get("plot_test_loss", True):
+        fig_f = kwargs.get("fig_f_test", "loss_test_ncde.png")
+        plt.figure()
+        plt.plot(loss_hist[:, 0], loss_hist[:, 3], label="Test Loss")
+        plt.legend()
+        plt.xlabel("Epoch")
+        plt.ylabel(yl)
+        plt.grid(ls='--', alpha=0.5)
+        if len(curriculum_epochs) > 0:
+            for ce in curriculum_epochs:
+                plt.axvline(x=ce, c='k', alpha=0.8)
+        plt.savefig(fig_f, dpi=180)
+        plt.close()
+
+
+def multicase_transform(cases_x_list, std_scale=False):
+    """
+    Perform minmax or standard scaling on data and return
+    the transformed data and the fitted transform to
+    allow inverse of this transform to be performed in post.
+
+    Args:
+        cases_x_list: list of input arrays, each with shape [n_times, n_features]
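+
+    Example (illustrative)::
+
+        cases = [np.random.rand(100, 3) for _ in range(4)]
+        scaled_cases, x_tr = multicase_transform(cases, std_scale=True)
+        # invert the scaling in post-processing
+        restored = x_tr.inverse_transform(scaled_cases[0])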
+    """
+    if std_scale:
+        X_tr = StandardScaler()
+    else:
+        X_tr = MinMaxScaler()
+    tmp_x = np.concatenate(cases_x_list)
+    X_tr.fit(tmp_x)
+    out_x = []
+    for case_x in cases_x_list:
+        out_x.append(X_tr.transform(case_x))
+    return out_x, X_tr
+
+
+def multicase_transform_graph(cases_x_list, std_scale=False):
+    """
+    Similar to multicase_transform but for graph data.
+
+    Args:
+        cases_x_list: list of input arrays, each with shape [n_times, n_nodes, n_features]
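+
+    Example (illustrative; the scaler is fit across all cases)::
+
+        cases = [np.random.rand(100, 5, 3) for _ in range(4)]  # 5 nodes, 3 feats
+        scaled_cases, x_tr = multicase_transform_graph(cases)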
+    """
+    num_nodes = cases_x_list[0].shape[-2]
+    num_feat = cases_x_list[0].shape[-1]
+    if std_scale:
+        X_tr = StandardScaler()
+    else:
+        X_tr = MinMaxScaler()
+    tmp_x = np.concatenate(cases_x_list).reshape(-1, num_nodes*num_feat)
+    X_tr.fit(tmp_x)
+    out_x = []
+    for case_x in cases_x_list:
+        tr_x = X_tr.transform(case_x.reshape(-1, num_nodes*num_feat))
+        out_x.append(tr_x.reshape(-1, num_nodes, num_feat))
+    return out_x, X_tr
diff --git a/aires_node/utils_loss.py b/aires_node/utils_loss.py
index 6ec4ab691d40419ac2335b997a5480892fea0a1d..c3ca07c0b3bfccfa7620e1351b2379a324054290 100644
--- a/aires_node/utils_loss.py
+++ b/aires_node/utils_loss.py
@@ -94,9 +94,23 @@ def custom_elbo_loss(x_pred, x_true, pred_logvar, pred_mu_z0, pred_logvar_z0, ta
-    Computes -L(p(x)) + \beta*KL(q(q_0)||p(z_0))
+    Computes -log p(x) + \beta * KL(q(z_0) || p(z_0))
     """
     # sum along time and feature dim
-    log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(-1).sum(-1)
+    log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(tuple(range(1, len(x_true.shape))))
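+    # summing over all non-batch dims (time, [node], feature) generalizes
+    # the reconstruction term to graph-shaped predictions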
     # sum along z dim (n_latent vars)
-    prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(-1)
+    prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(
+            tuple(range(1, len(pred_mu_z0.shape))))
     # avg along batch dim
     elbo_loss = torch.mean(-log_px + beta * prior_kl, dim=0)
     return elbo_loss
diff --git a/examples/pred_prey_new_ncde.py b/examples/pred_prey_new_ncde.py
index b2300685be4ded147d5ed901e5f3c6220166a7f2..a7faa6f9a1bebf68dde1c05cb9eb6c6fa8262a84 100644
--- a/examples/pred_prey_new_ncde.py
+++ b/examples/pred_prey_new_ncde.py
@@ -6,8 +6,8 @@ from sklearn.model_selection import train_test_split
 from aires_node.utils_common import BatchGenerator, n_trainable_params
 from aires_node.neural_cde_driver import NeuralCDEDriver
 from aires_node.neural_cde_models import NeuralCDE
+from aires_node.utils_common import plot_loss_curve
 from pred_prey_sys import node_lv_ctrl_sys
-from pred_prey_vae_ncde import plot_loss_curve
 import torch
 
 
diff --git a/examples/pred_prey_vae_ncde.py b/examples/pred_prey_vae_ncde.py
index 31f75d95d293ce77abc040d5df0293527b63ea1b..70d318fc34c8f51393e40bd6fea8faf19448ea76 100644
--- a/examples/pred_prey_vae_ncde.py
+++ b/examples/pred_prey_vae_ncde.py
@@ -6,7 +6,8 @@ import torch
 # internal imports
 from aires_node.utils_common import BatchGenerator, n_trainable_params
 from aires_node.neural_cde_driver import NeuralCDEDriver
-from aires_node.neural_cde_models import NeuralCDE
+from aires_node.neural_cde_models import NeuralCDE, NeuralCDE_deprecated
+from aires_node.utils_common import plot_loss_curve
 from pred_prey_sys import node_lv_ctrl_sys
 
 
@@ -37,9 +38,10 @@ def fit_vae_ncde(n_epochs=6000, **kwargs):
     cdeint_kwargs = {'rtol': 1e-3, 'atol': 1e-4}
     encoder_kwargs = {'encoder_depth': 1,
                       'encoder_hidden_width': 24,
-                      'encoder_type': "gru"}
+                      'encoder_type': "rnn"}
     decoder_kwargs = {'decoder_depth': 0}
     z0_nsamples = 12
+    # model = NeuralCDE_deprecated(input_channels=3, hidden_channels=8,
     model = NeuralCDE(input_channels=3, hidden_channels=8,
                       output_channels=2, hidden_width=60,
                       hidden_depth=2, z0_fn_control=-2,
@@ -73,7 +75,10 @@ def fit_vae_ncde(n_epochs=6000, **kwargs):
     driver.load_best_test_state()
 
     # Plot training curve
-    plot_loss_curve(driver.loss_history, curriculum_epochs=curriculum_epochs)
+    plot_loss_curve(driver.loss_history, curriculum_epochs=curriculum_epochs,
+                    ylabel="ELBO Loss",
+                    fig_f_train="loss_train_pred_prey_vae_ncde.png",
+                    fig_f_test="loss_test_pred_prey_vae_ncde.png")
 
     # Plot results on test set
     pred_y_trajs, _z0_mu, _z0_logvar = driver.eval(t_s_test, X_s_test, u_s_test, nt=30)
@@ -114,34 +119,5 @@ def plot_test_data(model, t_test, pred_y_trajs, y_test, u_test, **kwargs):
     plt.close()
 
 
-def plot_loss_curve(loss_history, **kwargs):
-    fig_f = kwargs.get("fig_f", "loss_train_pred_prey_vae_ncde.png")
-    curriculum_epochs = kwargs.get("curriculum_epochs", [])
-    plt.figure()
-    loss_hist = np.asarray(loss_history)
-    plt.plot(loss_hist[:, 0], loss_hist[:, 2], label="Train Loss")
-    plt.legend()
-    plt.xlabel("Epoch")
-    plt.ylabel("ELBO Loss")
-    plt.grid(ls='--', alpha=0.5)
-    if len(curriculum_epochs) > 0:
-        for ce in curriculum_epochs:
-            plt.axvline(x=ce, c='k', alpha=0.8)
-    plt.savefig(fig_f, dpi=180)
-    plt.close()
-    fig_f = kwargs.get("fig_f_test", "loss_test_pred_prey_vae_ncde.png")
-    plt.figure()
-    loss_hist = np.asarray(loss_history)
-    plt.plot(loss_hist[:, 0], loss_hist[:, 3], label="Test Loss")
-    plt.legend()
-    plt.xlabel("Epoch")
-    plt.ylabel("ELBO Loss")
-    plt.grid(ls='--', alpha=0.5)
-    if len(curriculum_epochs) > 0:
-        for ce in curriculum_epochs:
-            plt.axvline(x=ce, c='k', alpha=0.8)
-    plt.savefig(fig_f, dpi=180)
-    plt.close()
-
 if __name__ == "__main__":
     fit_vae_ncde(n_epochs=2000)