Commit 291e7563 authored by Zhang, Chen's avatar Zhang, Chen
Browse files

cast rcurve to log before cal loss

parent 3a518b5b
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -40,6 +40,9 @@ class CompositeLoss(nn.Module):
            np.apply_along_axis(param_to_rcurve, 1, true_params.detach().cpu().numpy())
        ).float()
        # calculate the loss
        # cast to log first
        pred_rcurve = torch.log(pred_rcurve)
        true_rcurve = torch.log(true_rcurve)
        loss_rcurve = self.rcurve_loss(pred_rcurve, true_rcurve)
        # combine the losses
        loss = self.lambda_param * loss_param + self.lambda_curve * loss_rcurve