Loading src/tgreft/nn/loss.py +3 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,9 @@ class CompositeLoss(nn.Module): np.apply_along_axis(param_to_rcurve, 1, true_params.detach().cpu().numpy()) ).float() # calculate the loss # cast to log first pred_rcurve = torch.log(pred_rcurve) true_rcurve = torch.log(true_rcurve) loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve) # combine the losses loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve Loading Loading
src/tgreft/nn/loss.py +3 −0 Original line number Diff line number Diff line Loading @@ -40,6 +40,9 @@ class CompositeLoss(nn.Module): np.apply_along_axis(param_to_rcurve, 1, true_params.detach().cpu().numpy()) ).float() # calculate the loss # cast to log first pred_rcurve = torch.log(pred_rcurve) true_rcurve = torch.log(true_rcurve) loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve) # combine the losses loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve Loading