diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 008d789402..3fb7cb9e19 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -160,6 +160,10 @@ def metric_fn(self, pred, label): if self.metric in ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) + elif self.metric == "mse": + mask = ~torch.isnan(label) + weight = torch.ones_like(label) + return -self.mse(pred[mask], label[mask], weight[mask]) raise ValueError("unknown metric `%s`" % self.metric)