diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..225c38155869eaa5f928f05de5bb7681d66a20c5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.png
+*.jpg
diff --git a/aires_node/neural_cde_driver.py b/aires_node/neural_cde_driver.py
index b9556375a71d032d62062f64c19409a1f9f86b07..389d9ace1b8ce024bc90a20eedb9cc3d02fa2ce1 100644
--- a/aires_node/neural_cde_driver.py
+++ b/aires_node/neural_cde_driver.py
@@ -41,7 +41,8 @@ class NeuralCDEDriver(object):
         self.model = model
         self.lr = kwargs.get("lr", 5e-4)
         self.optimizer = optim.AdamW(
-            self.model.parameters(), lr=self.lr, weight_decay=1e-5, amsgrad=True)
+            self.model.parameters(), lr=self.lr,
+            weight_decay=kwargs.get("weight_decay", 1e-5), amsgrad=True)
         self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer, 'min', patience=20, cooldown=40,
             min_lr=1e-6, factor=kwargs.get("lr_factor", 0.99))
@@ -103,8 +104,12 @@ class NeuralCDEDriver(object):
         # log(1) is 0
         prior_logvar_z0 = torch.zeros(qz0_logvar.shape)
         # estimate varience at z0 from network predictions
-        nvk = 30
-        var_hat = torch.zeros((nvk, pred_x.shape[0], pred_x.shape[2]))
+        nvk = 96
+        if hasattr(self.model, 'num_nodes'):
+            assert len(pred_x.shape) == 4
+            var_hat = torch.zeros((nvk, pred_x.shape[0], pred_x.shape[2], pred_x.shape[3]))
+        else:
+            var_hat = torch.zeros((nvk, pred_x.shape[0], pred_x.shape[2]))
         for vk in range(nvk):
             eps = torch.randn(qz0_mean.size()).to(self.device)
             z0_rand = eps * torch.exp(.5 * qz0_logvar[:, :]) + qz0_mean[:, :]
@@ -120,17 +125,26 @@ class NeuralCDEDriver(object):
             var_hat_nsamples[:, :, :] = var_hat
         # compute elbo loss
         pred_logvar = torch.log(var_hat_nsamples)
+        x0_loc = int(self.z0_nsamples / self.eval_thin_k)
         elbo_loss = custom_elbo_loss(
             pred_x, _targ_x[:, :, :], pred_logvar, qz0_mean, qz0_logvar,
             targ_mu_z0=prior_mu_z0, targ_logvar_z0=prior_logvar_z0,
             beta=self.beta_vae)
-        x0_loc = int(self.z0_nsamples / self.eval_thin_k)
         x0_loss = log_normal_pdf(
-            pred_x[:, x0_loc:x0_loc+1, :], _targ_x[:, x0_loc:x0_loc+1, :], pred_logvar[:, x0_loc:x0_loc+1, :])
-        x0_loss = -torch.mean(x0_loss.sum(-1).sum(-1), dim=0)
-        return elbo_loss + self.x0_loss_scale*x0_loss
+            pred_x[:, x0_loc:x0_loc+1, :],
+            _targ_x[:, x0_loc:x0_loc+1, :],
+            pred_logvar[:, x0_loc:x0_loc+1, :])
+        xp_loss = log_normal_pdf(
+            pred_x[:, 0:1, :],
+            _targ_x[:, 0:1, :],
+            pred_logvar[:, 0:1, :])
+        x0_loss = -torch.mean(x0_loss.sum(tuple(range(1, len(_targ_x.shape)))), dim=0)
+        xp_loss = -torch.mean(xp_loss.sum(tuple(range(1, len(_targ_x.shape)))), dim=0)
+        # print(float(elbo_loss), float(x0_loss), float(self.x0_loss_scale*x0_loss))
+        x0_weight = float(_targ_x.shape[1])
+        return elbo_loss + self.x0_loss_scale*x0_weight*(0.5*x0_loss+0.5*xp_loss)
 
     def _ncde_loss(self, pred_x: torch.Tensor, targ_x: torch.Tensor, eval_idx=None):
         """
@@ -274,6 +288,8 @@ class NeuralCDEDriver(object):
             # Clean interpolant before saving or loading the model
             if hasattr(self.model, "zh_cde"):
                 self.model.zh_cde.Xf_t = None
+            if isinstance(self.model, NeuralCDE):
+                self.model.func.Xf_t = None
 
             if itr % self.checkpoint_epoch == 0:
                 # save model for rollback to best state
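Context for the `var_hat` change above: the driver estimates the predictive variance by Monte-Carlo sampling the encoder posterior q(z0) and measuring the spread of the decoded predictions; widening `nvk` from 30 to 96 trades compute for a lower-variance estimate. A minimal, self-contained sketch of that estimator (the `decode` callable, the `mc_predictive_var` name, and all shapes are illustrative assumptions, not this repo's API):

```python
import torch

def mc_predictive_var(qz0_mean, qz0_logvar, decode, nvk=96):
    """Estimate per-output predictive variance by sampling q(z0)."""
    preds = []
    for _ in range(nvk):
        eps = torch.randn_like(qz0_mean)
        z0 = eps * torch.exp(0.5 * qz0_logvar) + qz0_mean  # reparameterization trick
        preds.append(decode(z0))
    preds = torch.stack(preds, dim=0)   # (nvk, n_batch, n_out)
    return preds.var(dim=0)             # spread across the nvk samples

# toy usage: a linear "decoder" standing in for the NCDE rollout
W = torch.randn(8, 2)
var_hat = mc_predictive_var(torch.zeros(4, 8), torch.zeros(4, 8), lambda z: z @ W)
print(var_hat.shape)  # torch.Size([4, 2])
```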
diff --git a/aires_node/neural_cde_encoders.py b/aires_node/neural_cde_encoders.py
index daa62acb58065c2751d560df14e5cedef5c78776..49c296ce14231a753dc86754ca63700680f48963 100644
--- a/aires_node/neural_cde_encoders.py
+++ b/aires_node/neural_cde_encoders.py
@@ -99,14 +99,16 @@ class RecognitionGRU(torch.nn.Module):
         )
         self.fc = torch.nn.Linear(nhidden, latent_dim * v_mult)
 
-    def forward(self, x, hidden_state):
+    def forward(self, x, hidden_state=None):
+        # hidden_state defaults to zeros inside nn.GRU when omitted
+        output, hidden_state = self.gru(x)
         # NOTE: this does not need to be manually unrolled, x can have
         # time as an axis. The GRU forward() call returns the next
         # predicted output in the sequence.
         # hidden_state = self.init_hidden()
         # x should have shape (n_batch, n_times, n_features)
         # hidden_state should have (num_layers, n_batch, nhidden)
-        output, hidden_state = self.gru(x, hidden_state)
+        # output, hidden_state = self.gru(x, hidden_state)
         # output has size n_batch, n_times, n_
         output = self.fc(output[:, -1, :])
         return output
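The rewritten `RecognitionGRU.forward` relies on `torch.nn.GRU` zero-initializing its hidden state when none is supplied, which is why the explicit `init_hidden`/`h` plumbing can be dropped on the GRU path. A quick check of that PyTorch behavior (sizes illustrative):

```python
import torch

gru = torch.nn.GRU(input_size=3, hidden_size=5, batch_first=True)
x = torch.randn(2, 7, 3)                        # (n_batch, n_times, n_features)
out_default, h_default = gru(x)                 # h0 omitted -> zeros
out_explicit, h_explicit = gru(x, torch.zeros(1, 2, 5))
assert torch.allclose(out_default, out_explicit)
```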
""" device = y0.device + batch_dim = y0.shape[0] qz0_mean, qz0_logvar = torch.Tensor([]), torch.Tensor([]) assert len(y0.shape) >= 3 local_times = times - times[0] @@ -242,21 +249,34 @@ class InitialLatentStateEncoderRNN(torch.nn.Module): # X0 = X_f.evaluate(local_times[0:self.z0_nsamples+1]) # samp_trajs = torch.cat((X0[:,:-1,1:], y0), -1) X0 = Xf_evaluate_reshape(X_f, local_times[0:self.z0_nsamples], n_nodes=self.num_nodes)[..., 1:] - samp_trajs = torch.cat((X0, y0), -1)[:] + # samp_trajs = torch.cat((X0, y0), -1)[:] - batch_dim = y0.shape[0] - h = self.initial.init_hidden(batch_dim) + # For a GRU, requires samp_trajs to have shape [n_batch, n_times, n_features] + samp_trajs = torch.cat((X0, y0[:, 0:self.z0_nsamples]), -1)[:, 0:self.z0_nsamples, :] + if self.num_nodes > 0: + assert samp_trajs.shape[2] == self.num_nodes + # flatten the features on each node into single large vector + samp_trajs = samp_trajs.flatten(-2, -1) # Roll time backwards, infer initial latent state from known samples if self.encoder_type == "gru": samp_trajs_rev = torch.flip(samp_trajs, dims=(1,)) - out = self.initial.forward(samp_trajs_rev, h) + out = self.initial.forward(samp_trajs_rev) else: + # initialize hidden state for unrolled RNN + h = self.initial.init_hidden(batch_dim) for t in reversed(range(samp_trajs.shape[1])): # batch, time, nfeatures obs = samp_trajs[:, t, :] out, h = self.initial.forward(obs, h) + # In graph cde, out needs shape [nbatch, n_nodes, n_latent] + if self.num_nodes > 0: + # pytorch GRU does not have a node dimension, must expand output shape + # to [nbatch, n_nodes, n_latent] + out = out.reshape((batch_dim, self.num_nodes, -1)) + assert out.shape[1] == self.num_nodes + qz0_mean, qz0_logvar = out[..., :self.latent_dim], out[..., self.latent_dim:] if self.is_vae_model: # random sampling initial latent state while training @@ -311,10 +331,23 @@ class InitialLatentStateEncoderFCN(torch.nn.Module): assert len(y0.shape) >= 3 local_times = times - times[0] X0 = Xf_evaluate_reshape(X_f, local_times[0:self.z0_nsamples], n_nodes=self.num_nodes)[..., 1:] + # FIXME: takes only the first time, meaning only z0_nsamples==1 is supported at this time - aug_X0 = torch.cat((X0, y0), -1)[:, 0, :] + # aug_X0 = torch.cat((X0, y0[:, 0:self.z0_nsamples]), -1)[:, 0, :] + + aug_X0 = torch.cat((X0, y0[:, 0:self.z0_nsamples]), -1)[:, 0:self.z0_nsamples, :] + if self.num_nodes > 0: + assert aug_X0.shape[2] == self.num_nodes + aug_X0 = aug_X0.permute((0, 2, 1, 3)) + aug_X0 = aug_X0.flatten(2, -1) + else: + aug_X0 = aug_X0.flatten(-2, -1) out = self.initial(aug_X0) + # In graph cde, out needs shape [nbatch, n_nodes, n_latent] + if self.num_nodes > 0: + assert out.shape[1] == self.num_nodes + qz0_mean, qz0_logvar = out[..., :self.latent_dim], out[..., self.latent_dim:] if self.is_vae_model: # random sampling initial latent state while training diff --git a/aires_node/neural_cde_models.py b/aires_node/neural_cde_models.py index c3a3841ed281237143edce34937fd0eae4ac9c35..9e72ac00a377aaeda08479008d2b03f710cd4831 100644 --- a/aires_node/neural_cde_models.py +++ b/aires_node/neural_cde_models.py @@ -40,7 +40,7 @@ class CDEFuncTemporal(torch.nn.Module): def __call__(self, t, h): # control_gradient is of shape (..., input_channels) dX_dt = self.Xf_t.derivative(t) - vector_field_f = self.func_f(h) + vector_field_f = self.func_f(t, h) out = (vector_field_f @ dX_dt.unsqueeze(-1)).squeeze(-1) return out @@ -217,14 +217,11 @@ class NeuralCDE(torch.nn.Module): cdeint_kwargs: arguments passed to the cde integrator. 
diff --git a/aires_node/neural_cde_models.py b/aires_node/neural_cde_models.py
index c3a3841ed281237143edce34937fd0eae4ac9c35..9e72ac00a377aaeda08479008d2b03f710cd4831 100644
--- a/aires_node/neural_cde_models.py
+++ b/aires_node/neural_cde_models.py
@@ -40,7 +40,7 @@ class CDEFuncTemporal(torch.nn.Module):
     def __call__(self, t, h):
         # control_gradient is of shape (..., input_channels)
         dX_dt = self.Xf_t.derivative(t)
-        vector_field_f = self.func_f(h)
+        vector_field_f = self.func_f(t, h)
         out = (vector_field_f @ dX_dt.unsqueeze(-1)).squeeze(-1)
         return out
 
@@ -217,14 +217,11 @@ class NeuralCDE(torch.nn.Module):
             cdeint_kwargs: arguments passed to the cde integrator.
                 Eg: 'atol' and 'rtol' control desired integration accuracy.
             z0_fn_control: latent variable initial value mode.
-                0 - uses only inital system variables to estimate
-                    inital latent variable values at t0. (default)
-                1 - uses only control variable initial values to estimate
-                    inital latent variable values at t0.
-                2 - uses both system variables and control variable inital values
-                    to estimate inital latent variable values at t0.
+                0 - Uses a fully connected network to estimate the initial latent state.
+                -1 - Uses an RNN model. May have issues with longer lookback windows.
+                -2 - Uses a GRU model. Typically better than the RNN model.
             z0_nsamples: number of observations used to infer the inital
-                values of the latent vars
+                values of the latent vars. This is similar to lookback in other models.
             encoder_kwargs: arguments passed to the encoder
             decoder_kwargs: arguments passed to the decoder
         """
@@ -253,51 +250,24 @@ class NeuralCDE(torch.nn.Module):
         self.latent_dim = hidden_channels
         self.verbose = verbose
 
-        # Netowrk that represents the ODE in the latent space
-        self.func = CDEFunc(
+        # Network that represents the dynamics in the latent space
+        self.func = CDEFuncTemporal(CDEFunc(
             input_channels, hidden_channels, hidden_width,
-            hidden_depth, **cde_kwargs)
-
-        # Network that maps from input space, X to latent sapace, z
-        # self.initial = torch.nn.Linear(output_channels, hidden_channels)
-        encoder_layer_list = []
-        encoder_depth = encoder_kwargs.get("encoder_depth", 0)
-        encoder_hidden_width = encoder_kwargs.get("encoder_hidden_width", 32)
-        if self.z0_fn_control in (1, -1):
-            encoder_input_channels = input_channels-1
-        elif self.z0_fn_control in (2,):
-            encoder_input_channels = input_channels-1 + output_channels
-        elif self.z0_fn_control == -2:
-            # remove time as feature channel in encoder network
-            # encoder_input_channels = input_channels + output_channels
-            encoder_input_channels = input_channels-1 + output_channels
-        else:
-            encoder_input_channels = output_channels
-        if encoder_depth < 0:
-            encoder_layer_list.append(Pad1D(z0_nsamples * encoder_input_channels, self.latent_dim))
-        else:
-            encoder_layer_list.append(torch.nn.Linear(z0_nsamples * encoder_input_channels, hidden_channels))
-        for i in range(int(encoder_depth)):
-            encoder_layer_list.append(torch.nn.Tanh())
-            encoder_layer_list.append(torch.nn.Linear(hidden_channels, self.latent_dim))
+            hidden_depth, **cde_kwargs))
 
-        # Init RNN encoder model
-        self.encoder_type = encoder_kwargs.get("encoder_type", "gru")
-        if self.z0_fn_control < 0:
-            assert z0_nsamples >= 3
-            if self.encoder_type == "gru":
-                self.initial = RecognitionGRU(
-                    latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
-                    nhidden=encoder_hidden_width, nbatch=self.nbatch)
-            else:
-                self.initial = RecognitionRNN(
-                    latent_dim=self.latent_dim, obs_dim=encoder_input_channels,
-                    nhidden=encoder_hidden_width, nbatch=self.nbatch)
-        else:
-            self.initial = torch.nn.Sequential(*encoder_layer_list)
+        # Initial latent value estimation network
+        self.initial = initial_state_encoder_factory(
+            input_channels,
+            output_channels,
+            z0_fn_control,
+            z0_nsamples,
+            self.latent_dim,
+            is_vae=self.is_vae_model,
+            num_nodes=0,
+            **encoder_kwargs
+        )
 
         # Network that maps from latent sapace, z to data space (output)
-        # self.readout = torch.nn.Linear(hidden_channels, output_channels)
         decoder_layer_list = []
         decoder_layer_list.append(torch.nn.Linear(self.latent_dim, output_channels))
         for i in range(int(decoder_kwargs.get("decoder_depth", 0))):
             decoder_layer_list.append(torch.nn.Tanh())
@@ -305,144 +275,52 @@ class NeuralCDE(torch.nn.Module):
             decoder_layer_list.append(torch.nn.Linear(output_channels, output_channels))
         self.readout = torch.nn.Sequential(*decoder_layer_list)
 
-        pytorch_total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
-        print("=== Number of trainable parameters in NeuralCDE: %d" % int(pytorch_total_params))
-
-        # init network weights and biases
-        if cde_kwargs.get("init_std", True):
-            for m in self.func.cde_nn.modules():
-                if isinstance(m, torch.nn.Linear):
-                    torch.nn.init.normal_(
-                        m.weight, mean=0., std=cde_kwargs.get("w_std", 0.01))
-                    torch.nn.init.normal_(
-                        m.bias, mean=0., std=cde_kwargs.get("b_std", 0.001))
-
-    def forward(self, z0, X_f, times, eval_idx=None):
+    def forward(self, y0, X_f, times, eval_idx=None):
         """
         Forward pass through the NCDE.  Integrates the NCDE forward in time.
-        Computes and returns results are input times.
+        Computes and returns results at desired times.
 
         Args:
             z0: prior observations dims - (n_batch, n_time, n_features)
            X_f: :module:`torchcde.LinearInterpolation` callable interpolation
                 function which implements an evaluate method.
-            times: Times at which to evaluate the model.
+            times: Desired times at which to evaluate the model.
             eval_idx: Only evaluate at marked times.
         """
-        # Initial hidden state should be a function of the first
-        # and/or prior observation(s).
         qz0_mean, qz0_logvar = torch.Tensor([]), torch.Tensor([])
-        assert len(z0.shape) == 3
-        assert z0.shape[1] == self.z0_nsamples
+        assert len(y0.shape) == 3
+        assert y0.shape[1] == self.z0_nsamples
         local_times = times - times[0]
         assert isinstance(local_times, torch.Tensor)
-        if self.z0_fn_control == 0:
-            z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
-            z0_0 = self.initial(z0_r)
-        elif self.z0_fn_control == 1:
-            # In the torchcde examples, z0 is a funciton of X;
-            # however, the inital state of the latent vars, z
-            # may also be a function of the dependant system observables.
-            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
-            z0_0 = self.initial(X0)
-        elif self.z0_fn_control == 2:
-            # X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
-            # z_r shape = (batch dim, n_time*n_features)
-            # z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
-            # z0_a = torch.cat((X0, z0_r), 1)
-            # z0_0 = self.initial(z0_a)
-            # Use both control and sys var initial condition
-            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, :, 1:]
-            aug_X0 = torch.cat((X0, z0), 2)[:, 0, :]
-            z0_0 = self.initial(aug_X0)
-        elif self.z0_fn_control < 0:
-            # Use RNN encoder to build estimate for hidden intial conds, z0
-            # Requres z0 to be of length 3 or greater!
-            # in this case z0 is sequence of observed samples prior to
-            # forward forecast.
-            # backward in time to infer q(z_0)
-            if self.z0_fn_control == -2:
-                # concat along feature dim
-                # the 0th feature in X_f.evaluate is reserved for time-as-a-feature
-                X0 = X_f.evaluate(local_times[0:self.z0_nsamples+1])
-                # dt_trajs = X0[:,:-1,:] - X0[:,1:,0]
-                samp_trajs = torch.cat((X0[:,:-1,1:], z0), 2)
-                # samp_trajs = torch.cat((X0[:,:-1,1:], z0, dt_trajs), 2)
-            else:
-                samp_trajs = z0
-
-            # Initialize RNN encoder hidden state
-            batch_dim = z0.shape[0]
-            h = self.initial.init_hidden(batch_dim)
-
-            # Roll time backwards, infer initial latent state from known samples
-            if self.encoder_type == "gru":
-                samp_trajs_rev = torch.flip(samp_trajs, dims=(1,))
-                out = self.initial.forward(samp_trajs_rev, h)
-            else:
-                for t in reversed(range(samp_trajs.shape[1])):
-                    # batch, time, nfeatures
-                    obs = samp_trajs[:, t, :]
-                    out, h = self.initial.forward(obs, h)
-
-            qz0_mean, qz0_logvar = out[:, :self.latent_dim], out[:, self.latent_dim:]
-            device = z0.device
-            if self.training and self.is_vae_model:
-                # random sampling initial latent state while training
-                epsilon = torch.randn(qz0_mean.size()).to(device)
-            else:
-                # frozen mean initial latent state while testing/eval
-                # print("====== qz0_mean: %s " % (qz0_mean))
-                # print("====== qz0_sd: %s" % (torch.sqrt(torch.exp(qz0_logvar))) )
-                epsilon = torch.zeros(qz0_mean.size()).to(device)
-            # VAE sample initial conditions of latent var
-            z0_0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
-        else:
-            raise NotImplementedError("Invalid z0 Estimation Method")
+        self.func.Xf_t = X_f
 
-        # z0_0 has shape (n_batch, self.latent_dim)
-        assert z0_0.shape[1] == self.latent_dim
-
-        # for logging only
-        if self.verbose:
-            t_range = X_f.interval
-            print("X_f.interval: ", t_range, "times range: ",
-                  local_times[0], local_times[-1])
-            X_f_vals = X_f.evaluate(local_times)
-            X_f_times = X_f_vals[:,:,0]
-            X_f_vars = X_f_vals[:,:,1:]
-            print('Min control time Xf_t', torch.min(X_f_times),
-                  'Max control time Xf_t', torch.max(X_f_times))
-            print('Min control vars Xf', torch.min(X_f_vars),
-                  'Max control vars Xf', torch.max(X_f_vars))
+        # Predict the initial latent state, z0
+        z0, qz0_mean, qz0_logvar = self.initial(y0, X_f, times)
+        assert z0.shape[1] == self.latent_dim
 
         # Only output at selected times.
-        # This allows adaptive stepping to be more efficient since the
-        # ODE integration does not need to stop at every support
-        # point to eval results.
         if eval_idx is not None:
             assert len(eval_idx) <= len(local_times)
             selected_times = local_times[eval_idx]
         else:
             selected_times = local_times
 
-        # Solve the CDE.
-        try:
-            z_T = torchcde.cdeint(X=X_f,
-                                  z0=z0_0,
-                                  func=self.func,
-                                  t=selected_times,
-                                  rtol=self.cdeint_kwargs.get("rtol", 1e-5),
-                                  atol=self.cdeint_kwargs.get("atol", 1e-6),
-                                  backend='torchdiffeq',
-                                  adjoint=self.cdeint_kwargs.get("adjoint", False))
-        except:
-            # z_T has shape (n_batches, time, latent_dim)
-            z_T = torch.zeros(
-                (z0_0.shape[0], len(selected_times), self.latent_dim))
-            print("WARNING: cdeint failed")
+        # Integrate the NCDE to obtain latent state, z(t), trajectories
+        if self.cdeint_kwargs.get("adjoint", False):
+            z_T = odeint_adjoint(func=self.func, y0=z0, t=selected_times,
+                                 rtol=self.cdeint_kwargs.get("rtol", 1e-5),
+                                 atol=self.cdeint_kwargs.get("atol", 1e-6),
+                                 adjoint_rtol=self.cdeint_kwargs.get("adjoint_rtol", 1e-3),
+                                 adjoint_atol=self.cdeint_kwargs.get("adjoint_atol", 1e-4),
+                                 method=self.cdeint_kwargs.get("method", 'dopri5'))
+        else:
+            z_T = odeint(func=self.func, y0=z0, t=selected_times,
+                         rtol=self.cdeint_kwargs.get("rtol", 1e-5),
+                         atol=self.cdeint_kwargs.get("atol", 1e-6),
+                         method=self.cdeint_kwargs.get("method", 'dopri5'))
 
         # Decode from latent space to data space
+        z_T = torch.permute(z_T, (1, 0, 2))
         pred_y = self.readout(z_T)
         return pred_y, qz0_mean, qz0_logvar
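The forward pass now integrates the latent CDE as an ODE, dz/dt = f_theta(t, z(t)) @ dX/dt(t), attaching the interpolant to `self.func` and calling torchdiffeq directly; odeint returns time-major output, hence the final permute. A self-contained sketch of the same trick (`ToyCDEFunc` and all sizes are illustrative, not this repo's classes; the torchcde/torchdiffeq calls are their public APIs):

```python
import torch
import torchcde
from torchdiffeq import odeint

class ToyCDEFunc(torch.nn.Module):
    def __init__(self, input_channels=2, latent_dim=4):
        super().__init__()
        self.net = torch.nn.Linear(latent_dim, latent_dim * input_channels)
        self.Xf_t = None                     # interpolant, attached before integrating
        self.latent_dim = latent_dim
        self.input_channels = input_channels

    def forward(self, t, z):
        dX_dt = self.Xf_t.derivative(t)      # (n_batch, input_channels)
        f = self.net(z).view(-1, self.latent_dim, self.input_channels)
        return (f @ dX_dt.unsqueeze(-1)).squeeze(-1)   # dz/dt = f(t, z) @ dX/dt

x = torch.randn(3, 10, 2)                    # (n_batch, n_times, channels)
func = ToyCDEFunc()
func.Xf_t = torchcde.LinearInterpolation(torchcde.linear_interpolation_coeffs(x))
times = torch.linspace(0., 1., 10)           # within the default 0..n_times-1 grid
z_T = odeint(func, torch.zeros(3, 4), times, rtol=1e-5, atol=1e-6)
z_T = z_T.permute(1, 0, 2)                   # time-major -> (n_batch, n_times, latent)
```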
+ """ + return self.zh_cde.func_g.normalized_agc_adjacency(), self.zh_cde.func_g.agc_weights() + + +class NeuralCDE_deprecated(torch.nn.Module): + def __init__(self, + input_channels, + hidden_channels, + output_channels, + hidden_width=100, + hidden_depth=1, + z0_fn_control=0, + z0_nsamples=1, + n_batches=1, + cdeint_kwargs={}, + cde_kwargs={}, + encoder_kwargs={}, + decoder_kwargs={}, + is_vae_model=False, + verbose=False, + ): + super(NeuralCDE_deprecated, self).__init__() + self.z0_fn_control = z0_fn_control + self.cdeint_kwargs = cdeint_kwargs + self.nbatch = n_batches + self.z0_nsamples = z0_nsamples + self.is_vae_model = is_vae_model + self.latent_dim = hidden_channels + self.verbose = verbose + + # Netowrk that represents the ODE in the latent space + self.func = CDEFunc( + input_channels, hidden_channels, hidden_width, + hidden_depth, **cde_kwargs) + + # Network that maps from input space, X to latent sapace, z + # self.initial = torch.nn.Linear(output_channels, hidden_channels) + encoder_layer_list = [] + encoder_depth = encoder_kwargs.get("encoder_depth", 0) + encoder_hidden_width = encoder_kwargs.get("encoder_hidden_width", 32) + if self.z0_fn_control in (1, -1): + encoder_input_channels = input_channels-1 + elif self.z0_fn_control in (2,): + encoder_input_channels = input_channels-1 + output_channels + elif self.z0_fn_control == -2: + # remove time as feature channel in encoder network + # encoder_input_channels = input_channels + output_channels + encoder_input_channels = input_channels-1 + output_channels + else: + encoder_input_channels = output_channels + if encoder_depth < 0: + encoder_layer_list.append(Pad1D(z0_nsamples * encoder_input_channels, self.latent_dim)) + else: + encoder_layer_list.append(torch.nn.Linear(z0_nsamples * encoder_input_channels, hidden_channels)) + for i in range(int(encoder_depth)): + encoder_layer_list.append(torch.nn.Tanh()) + encoder_layer_list.append(torch.nn.Linear(hidden_channels, self.latent_dim)) + + # Init RNN encoder model + self.encoder_type = encoder_kwargs.get("encoder_type", "gru") + if self.z0_fn_control < 0: + assert z0_nsamples >= 3 + if self.encoder_type == "gru": + self.initial = RecognitionGRU( + latent_dim=self.latent_dim, obs_dim=encoder_input_channels, + nhidden=encoder_hidden_width, nbatch=self.nbatch) + else: + self.initial = RecognitionRNN( + latent_dim=self.latent_dim, obs_dim=encoder_input_channels, + nhidden=encoder_hidden_width, nbatch=self.nbatch) + else: + self.initial = torch.nn.Sequential(*encoder_layer_list) + + # Network that maps from latent sapace, z to data space (output) + # self.readout = torch.nn.Linear(hidden_channels, output_channels) + decoder_layer_list = [] + decoder_layer_list.append(torch.nn.Linear(self.latent_dim, output_channels)) + for i in range(int(decoder_kwargs.get("decoder_depth", 0))): + decoder_layer_list.append(torch.nn.Tanh()) + decoder_layer_list.append(torch.nn.Linear(output_channels, output_channels)) + self.readout = torch.nn.Sequential(*decoder_layer_list) + + pytorch_total_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + print("=== Number of trainable parameters in NeuralCDE: %d" % int(pytorch_total_params)) + + # init network weights and biases + if cde_kwargs.get("init_std", True): + for m in self.func.cde_nn.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.normal_( + m.weight, mean=0., std=cde_kwargs.get("w_std", 0.01)) + torch.nn.init.normal_( + m.bias, mean=0., std=cde_kwargs.get("b_std", 0.001)) + + def forward(self, z0, X_f, times, 
+        # Initial hidden state should be a function of the first
+        # and/or prior observation(s).
+        qz0_mean, qz0_logvar = torch.Tensor([]), torch.Tensor([])
+        assert len(z0.shape) == 3
+        assert z0.shape[1] == self.z0_nsamples
+        local_times = times - times[0]
+        assert isinstance(local_times, torch.Tensor)
+        if self.z0_fn_control == 0:
+            z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
+            z0_0 = self.initial(z0_r)
+        elif self.z0_fn_control == 1:
+            # In the torchcde examples, z0 is a function of X;
+            # however, the initial state of the latent vars, z
+            # may also be a function of the dependent system observables.
+            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
+            z0_0 = self.initial(X0)
+        elif self.z0_fn_control == 2:
+            # X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, 0, 1:]
+            # z_r shape = (batch dim, n_time*n_features)
+            # z0_r = z0.reshape((-1, z0.shape[1]*z0.shape[2]))
+            # z0_a = torch.cat((X0, z0_r), 1)
+            # z0_0 = self.initial(z0_a)
+            # Use both control and sys var initial condition
+            X0 = X_f.evaluate(local_times[0:self.z0_nsamples])[:, :, 1:]
+            aug_X0 = torch.cat((X0, z0), 2)[:, 0, :]
+            z0_0 = self.initial(aug_X0)
+        elif self.z0_fn_control < 0:
+            # Use RNN encoder to build estimate for hidden initial conds, z0
+            # Requires z0 to be of length 3 or greater!
+            # in this case z0 is a sequence of observed samples prior to
+            # the forward forecast.
+            # backward in time to infer q(z_0)
+            if self.z0_fn_control == -2:
+                # concat along feature dim
+                # the 0th feature in X_f.evaluate is reserved for time-as-a-feature
+                X0 = X_f.evaluate(local_times[0:self.z0_nsamples+1])
+                # dt_trajs = X0[:,:-1,:] - X0[:,1:,0]
+                samp_trajs = torch.cat((X0[:,:-1,1:], z0), 2)
+                # samp_trajs = torch.cat((X0[:,:-1,1:], z0, dt_trajs), 2)
+            else:
+                samp_trajs = z0
+
+            # Initialize RNN encoder hidden state
+            batch_dim = z0.shape[0]
+            h = self.initial.init_hidden(batch_dim)
+
+            # Roll time backwards, infer initial latent state from known samples
+            if self.encoder_type == "gru":
+                samp_trajs_rev = torch.flip(samp_trajs, dims=(1,))
+                out = self.initial.forward(samp_trajs_rev, h)
+            else:
+                for t in reversed(range(samp_trajs.shape[1])):
+                    # batch, time, nfeatures
+                    obs = samp_trajs[:, t, :]
+                    out, h = self.initial.forward(obs, h)
+
+            qz0_mean, qz0_logvar = out[:, :self.latent_dim], out[:, self.latent_dim:]
+            device = z0.device
+            if self.training and self.is_vae_model:
+                # random sampling initial latent state while training
+                epsilon = torch.randn(qz0_mean.size()).to(device)
+            else:
+                # frozen mean initial latent state while testing/eval
+                # print("====== qz0_mean: %s " % (qz0_mean))
+                # print("====== qz0_sd: %s" % (torch.sqrt(torch.exp(qz0_logvar))) )
+                epsilon = torch.zeros(qz0_mean.size()).to(device)
+            # VAE sample initial conditions of latent var
+            z0_0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
+        else:
+            raise NotImplementedError("Invalid z0 Estimation Method")
+
+        # z0_0 has shape (n_batch, self.latent_dim)
+        assert z0_0.shape[1] == self.latent_dim
+
+        # for logging only
+        if self.verbose:
+            t_range = X_f.interval
+            print("X_f.interval: ", t_range, "times range: ",
+                  local_times[0], local_times[-1])
+            X_f_vals = X_f.evaluate(local_times)
+            X_f_times = X_f_vals[:,:,0]
+            X_f_vars = X_f_vals[:,:,1:]
+            print('Min control time Xf_t', torch.min(X_f_times),
+                  'Max control time Xf_t', torch.max(X_f_times))
+            print('Min control vars Xf', torch.min(X_f_vars),
+                  'Max control vars Xf', torch.max(X_f_vars))
+
+        # Only output at selected times.
+        # This allows adaptive stepping to be more efficient since the
+        # ODE integration does not need to stop at every support
+        # point to eval results.
+        if eval_idx is not None:
+            assert len(eval_idx) <= len(local_times)
+            selected_times = local_times[eval_idx]
+        else:
+            selected_times = local_times
+
+        # Solve the CDE.
+        try:
+            z_T = torchcde.cdeint(X=X_f,
+                                  z0=z0_0,
+                                  func=self.func,
+                                  t=selected_times,
+                                  rtol=self.cdeint_kwargs.get("rtol", 1e-5),
+                                  atol=self.cdeint_kwargs.get("atol", 1e-6),
+                                  backend='torchdiffeq',
+                                  adjoint=self.cdeint_kwargs.get("adjoint", False))
+        except Exception:
+            # z_T has shape (n_batches, time, latent_dim)
+            z_T = torch.zeros(
+                (z0_0.shape[0], len(selected_times), self.latent_dim))
+            print("WARNING: cdeint failed")
+
+        # Decode from latent space to data space
+        pred_y = self.readout(z_T)
+        return pred_y, qz0_mean, qz0_logvar
+
+    def save_model(self, file_path: str):
+        """
+        Save model state to file
+        """
+        torch.save(self.state_dict(), file_path)
+
+    def load_model(self, file_path: str):
+        """
+        Load model from file
+        """
+        state_dict = torch.load(file_path)
+        self.model.load_state_dict(state_dict)
diff --git a/aires_node/neural_cde_vecfields.py b/aires_node/neural_cde_vecfields.py
index 2dc78428a9550f31d82967dabf443485b26a6298..ae39be94cfcd496c1e25fdb5d6e1523c5f3d82c4 100644
--- a/aires_node/neural_cde_vecfields.py
+++ b/aires_node/neural_cde_vecfields.py
@@ -121,6 +121,9 @@ class VectorFieldAGC_g(torch.nn.Module):
     Choi, J. et al. Graph Neural Controlled Differential Equations
     for Traffic Forecasting.
 
+    Wu, Z. et al. Graph WaveNet for Deep Spatial-Temporal Graph Modeling
+    https://arxiv.org/abs/1906.00121
+
     Args:
         input_channels: number of input features
         hidden_channels: number of input features
@@ -171,6 +174,12 @@ class VectorFieldAGC_g(torch.nn.Module):
             z = z.tanh()
         return z
 
+    def normalized_agc_adjacency(self):
+        return F.softmax(F.relu(torch.mm(self.node_embeddings, self.node_embeddings.transpose(0, 1))), dim=1)
+
+    def agc_weights(self):
+        return self.weights_pool
+
     def agc(self, z):
         """
         Adaptive Graph Convolution
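The new debug accessors expose a Graph WaveNet-style adaptive adjacency built from learned node embeddings, A = softmax(relu(E E^T)). A standalone illustration (sizes arbitrary):

```python
import torch
import torch.nn.functional as F

num_nodes, embed_dim = 5, 16
E = torch.randn(num_nodes, embed_dim)      # learned node embeddings
# adaptive adjacency: each row is a normalized distribution over nodes
A = F.softmax(F.relu(E @ E.t()), dim=1)
assert torch.allclose(A.sum(dim=1), torch.ones(num_nodes))
```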
"loss_train_ncde.png") + curriculum_epochs = kwargs.get("curriculum_epochs", []) + # Plot training loss curve + plt.figure() + plt.plot(loss_hist[:, 0], loss_hist[:, 2], label="Train Loss") + plt.legend() + plt.xlabel("Epoch") + plt.ylabel(yl) + plt.grid(ls='--', alpha=0.5) + if len(curriculum_epochs) > 0: + for ce in curriculum_epochs: + plt.axvline(x=ce, c='k', alpha=0.8) + plt.savefig(fig_f, dpi=180) + plt.close() + + # Plot test loss curve if data is available + if loss_hist.shape[1] > 3 and kwargs.get("plot_test_loss", True): + fig_f = kwargs.get("fig_f_test", "loss_test_ncde.png") + plt.figure() + plt.plot(loss_hist[:, 0], loss_hist[:, 3], label="Test Loss") + plt.legend() + plt.xlabel("Epoch") + plt.ylabel(yl) + plt.grid(ls='--', alpha=0.5) + if len(curriculum_epochs) > 0: + for ce in curriculum_epochs: + plt.axvline(x=ce, c='k', alpha=0.8) + plt.savefig(fig_f, dpi=180) + plt.close() + + +def multicase_transform(cases_x_list, std_scale=False): + """ + Perform minmax or standard scaling on data and return + the transformed data and the fitted transform to + allow inverse of this transform to be performed in post. + + Args: + cases_x_list: Input data has shape: [n_batch, n_times, n_features] + """ + if std_scale: + X_tr = StandardScaler() + else: + X_tr = MinMaxScaler() + tmp_x = np.concatenate(cases_x_list) + X_tr.fit(tmp_x) + out_x = [] + for case_x in cases_x_list: + out_x.append(X_tr.transform(case_x)) + return out_x, X_tr + + +def multicase_transform_graph(cases_x_list, std_scale=False): + """ + Similar to multicase_transform but for graph data. + + Args: + cases_x_list: Input data has shape: [n_batch, n_times, n_nodes, n_features] + """ + num_nodes = cases_x_list[0].shape[-2] + num_feat = cases_x_list[0].shape[-1] + if std_scale: + X_tr = StandardScaler() + else: + X_tr = MinMaxScaler() + tmp_x = np.concatenate(cases_x_list).reshape(-1, num_nodes*num_feat) + X_tr.fit(tmp_x) + out_x = [] + for case_x in cases_x_list: + tr_x = X_tr.transform(case_x.reshape(-1, num_nodes*num_feat)) + out_x.append(tr_x.reshape(-1, num_nodes, num_feat)) + return out_x, X_tr + diff --git a/aires_node/utils_loss.py b/aires_node/utils_loss.py index 6ec4ab691d40419ac2335b997a5480892fea0a1d..c3ca07c0b3bfccfa7620e1351b2379a324054290 100644 --- a/aires_node/utils_loss.py +++ b/aires_node/utils_loss.py @@ -94,9 +94,23 @@ def custom_elbo_loss(x_pred, x_true, pred_logvar, pred_mu_z0, pred_logvar_z0, ta Computes -L(p(x)) + \beta*KL(q(q_0)||p(z_0)) """ # sum along time and feature dim - log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(-1).sum(-1) + # log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(-1).sum(-1) + log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(tuple(range(1, len(x_true.shape)))) # sum along z dim (n_latent vars) - prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(-1) + # prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(-1) + prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum( + tuple(range(1, len(pred_mu_z0.shape)))) # avg along batch dim + # kl_term = beta * torch.mean(prior_kl, dim=0) + # gnll_term = torch.mean(-log_px, dim=0) + # elbo_loss = gnll_term + kl_term elbo_loss = torch.mean(-log_px + beta * prior_kl, dim=0) + #gnll_loss_fn = torch.nn.GaussianNLLLoss(reduction='none', full=True) + #gnll_loss_fns = torch.nn.GaussianNLLLoss(reduction='none', full=True) + # gnll_torch = gnll_loss_fn(x_pred, x_true, torch.exp(pred_logvar)) + #gnll_torch = gnll_loss_fn(x_pred, x_true, 
diff --git a/aires_node/utils_loss.py b/aires_node/utils_loss.py
index 6ec4ab691d40419ac2335b997a5480892fea0a1d..c3ca07c0b3bfccfa7620e1351b2379a324054290 100644
--- a/aires_node/utils_loss.py
+++ b/aires_node/utils_loss.py
@@ -94,9 +94,23 @@ def custom_elbo_loss(x_pred, x_true, pred_logvar, pred_mu_z0, pred_logvar_z0, ta
     Computes -L(p(x)) + \beta*KL(q(q_0)||p(z_0))
     """
     # sum along time and feature dim
-    log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(-1).sum(-1)
+    # log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(-1).sum(-1)
+    log_px = log_normal_pdf(x_pred, x_true, pred_logvar).sum(tuple(range(1, len(x_true.shape))))
     # sum along z dim (n_latent vars)
-    prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(-1)
+    # prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(-1)
+    prior_kl = kl_normal(pred_mu_z0, pred_logvar_z0, targ_mu_z0, targ_logvar_z0).sum(
+        tuple(range(1, len(pred_mu_z0.shape))))
     # avg along batch dim
+    # kl_term = beta * torch.mean(prior_kl, dim=0)
+    # gnll_term = torch.mean(-log_px, dim=0)
+    # elbo_loss = gnll_term + kl_term
     elbo_loss = torch.mean(-log_px + beta * prior_kl, dim=0)
+    # NOTE: an equivalent reconstruction term can be computed with
+    # torch.nn.GaussianNLLLoss(reduction='none', full=True) applied as
+    # gnll(x_pred, x_true, torch.exp(pred_logvar)), summed over the
+    # non-batch dims and averaged along the batch dim.
     return elbo_loss
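The reductions above sum the pointwise log-likelihood over every non-batch dimension, so one expression covers both flat (batch, time, feat) and graph (batch, time, node, feat) tensors. A numeric cross-check, assuming `log_normal_pdf` has the standard Gaussian log-density form (the sketch function is an assumption, not the repo's helper):

```python
import math
import torch

def log_normal_pdf_sketch(x, mean, logvar):
    # assumed elementwise Gaussian log-density: log N(x; mean, exp(logvar))
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mean) ** 2 / torch.exp(logvar))

x_true = torch.randn(4, 10, 5, 3)     # graph-shaped: (batch, time, node, feat)
x_pred = torch.randn_like(x_true)
logvar = torch.zeros_like(x_true)
# sum over every non-batch dim -> one value per batch element
log_px = log_normal_pdf_sketch(x_pred, x_true, logvar).sum(
    tuple(range(1, len(x_true.shape))))
# cross-check against torch's Gaussian NLL (full=True keeps the 2*pi constant)
gnll = torch.nn.GaussianNLLLoss(reduction='none', full=True)
alt = gnll(x_pred, x_true, torch.exp(logvar)).sum(tuple(range(1, len(x_true.shape))))
assert torch.allclose(-log_px, alt, atol=1e-4)
```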
diff --git a/examples/pred_prey_new_ncde.py b/examples/pred_prey_new_ncde.py
index b2300685be4ded147d5ed901e5f3c6220166a7f2..a7faa6f9a1bebf68dde1c05cb9eb6c6fa8262a84 100644
--- a/examples/pred_prey_new_ncde.py
+++ b/examples/pred_prey_new_ncde.py
@@ -6,8 +6,8 @@ from sklearn.model_selection import train_test_split
 from aires_node.utils_common import BatchGenerator, n_trainable_params
 from aires_node.neural_cde_driver import NeuralCDEDriver
 from aires_node.neural_cde_models import NeuralCDE
+from aires_node.utils_common import plot_loss_curve
 from pred_prey_sys import node_lv_ctrl_sys
-from pred_prey_vae_ncde import plot_loss_curve
 
 import torch
 
diff --git a/examples/pred_prey_vae_ncde.py b/examples/pred_prey_vae_ncde.py
index 31f75d95d293ce77abc040d5df0293527b63ea1b..70d318fc34c8f51393e40bd6fea8faf19448ea76 100644
--- a/examples/pred_prey_vae_ncde.py
+++ b/examples/pred_prey_vae_ncde.py
@@ -6,7 +6,8 @@ import torch
 # internal imports
 from aires_node.utils_common import BatchGenerator, n_trainable_params
 from aires_node.neural_cde_driver import NeuralCDEDriver
-from aires_node.neural_cde_models import NeuralCDE
+from aires_node.neural_cde_models import NeuralCDE, NeuralCDE_deprecated
+from aires_node.utils_common import plot_loss_curve
 from pred_prey_sys import node_lv_ctrl_sys
 
 
@@ -37,9 +38,10 @@ def fit_vae_ncde(n_epochs=6000, **kwargs):
     cdeint_kwargs = {'rtol': 1e-3, 'atol': 1e-4}
     encoder_kwargs = {'encoder_depth': 1,
                       'encoder_hidden_width': 24,
-                      'encoder_type': "gru"}
+                      'encoder_type': "rnn"}
     decoder_kwargs = {'decoder_depth': 0}
     z0_nsamples = 12
+    # model = NeuralCDE_deprecated(input_channels=3, hidden_channels=8,
     model = NeuralCDE(input_channels=3, hidden_channels=8,
                       output_channels=2, hidden_width=60,
                       hidden_depth=2, z0_fn_control=-2,
@@ -73,7 +75,10 @@ def fit_vae_ncde(n_epochs=6000, **kwargs):
     driver.load_best_test_state()
 
     # Plot training curve
-    plot_loss_curve(driver.loss_history, curriculum_epochs=curriculum_epochs)
+    plot_loss_curve(driver.loss_history, curriculum_epochs=curriculum_epochs,
+                    ylabel="ELBO Loss",
+                    fig_f_train="loss_train_pred_prey_vae_ncde.png",
+                    fig_f_test="loss_test_pred_prey_vae_ncde.png")
 
     # Plot results on test set
     pred_y_trajs, _z0_mu, _z0_logvar = driver.eval(t_s_test, X_s_test, u_s_test, nt=30)
@@ -114,34 +119,5 @@ def plot_test_data(model, t_test, pred_y_trajs, y_test, u_test, **kwargs):
         plt.close()
 
 
-def plot_loss_curve(loss_history, **kwargs):
-    fig_f = kwargs.get("fig_f", "loss_train_pred_prey_vae_ncde.png")
-    curriculum_epochs = kwargs.get("curriculum_epochs", [])
-    plt.figure()
-    loss_hist = np.asarray(loss_history)
-    plt.plot(loss_hist[:, 0], loss_hist[:, 2], label="Train Loss")
-    plt.legend()
-    plt.xlabel("Epoch")
-    plt.ylabel("ELBO Loss")
-    plt.grid(ls='--', alpha=0.5)
-    if len(curriculum_epochs) > 0:
-        for ce in curriculum_epochs:
-            plt.axvline(x=ce, c='k', alpha=0.8)
-    plt.savefig(fig_f, dpi=180)
-    plt.close()
-    fig_f = kwargs.get("fig_f_test", "loss_test_pred_prey_vae_ncde.png")
-    plt.figure()
-    loss_hist = np.asarray(loss_history)
-    plt.plot(loss_hist[:, 0], loss_hist[:, 3], label="Test Loss")
-    plt.legend()
-    plt.xlabel("Epoch")
-    plt.ylabel("ELBO Loss")
-    plt.grid(ls='--', alpha=0.5)
-    if len(curriculum_epochs) > 0:
-        for ce in curriculum_epochs:
-            plt.axvline(x=ce, c='k', alpha=0.8)
-    plt.savefig(fig_f, dpi=180)
-    plt.close()
-
 if __name__ == "__main__":
     fit_vae_ncde(n_epochs=2000)