Source code for zea.backend.torch.losses

"""Container for custom loss functions."""

import torch
from torch import nn


class SMSLE(nn.Module):
    """Loss function for the Signed-Mean-Squared-Logarithmic-Error (SMSLE).

    Each signal is split into its positive and its negative part. Both parts
    are normalized by the maximum absolute value of the signal they come from,
    converted to decibels (``20 * log10``) and clipped to
    ``[-dynamic_range, 0]`` dB. The loss is the average of the mean squared
    error between the positive-part log images and the mean squared error
    between the negative-part log images.

    See https://doi.org/10.1109/TMI.2020.3008537 for more information.

    Args:
        dynamic_range (float): Dynamic range in dB. Log-compressed values are
            clipped to ``[-dynamic_range, 0]``. Defaults to 60.
    """

    def __init__(self, dynamic_range=60):
        super().__init__()
        self.dynamic_range = dynamic_range

    @staticmethod
    def _safe_max_abs(x):
        """Return ``max(|x|)`` over the whole tensor, with 0 replaced by 1.

        An all-zero tensor would otherwise be normalized as ``0 / 0 = NaN``.
        Dividing by 1 instead maps its entries to the lower clip value. A NaN
        maximum is left untouched so NaN inputs still surface as a NaN loss.
        """
        x_max = torch.max(torch.abs(x))
        return torch.where(x_max == 0, torch.ones_like(x_max), x_max)

    def _log_compress(self, x, x_max):
        """Return ``x / x_max`` in dB, clipped to ``[-dynamic_range, 0]``.

        Entries that are not strictly positive (i.e. belong to the opposite
        sign's part) are floored at float32 machine epsilon before the log,
        so they end up at the lower clip value.
        """
        # NOTE: 20 * log10(float32 eps) is about -138 dB, so dynamic ranges
        # larger than that have no additional effect.
        eps = torch.finfo(torch.float32).eps
        ratio = torch.clamp(x / x_max, min=eps)
        return torch.clamp(20 * torch.log10(ratio), -self.dynamic_range, 0)

    def forward(self, y_true, y_pred):
        """Compute the SMSLE loss.

        Both inputs are normalized by their own maximum absolute value over
        the entire tensor (not per sample) and compared elementwise (standard
        broadcasting applies).

        Args:
            y_true (tensor): Ground truth values.
            y_pred (tensor): The predicted values.

        Returns:
            loss (tensor): Scalar SMSLE loss value.
        """
        y_true_max = self._safe_max_abs(y_true)
        y_pred_max = self._safe_max_abs(y_pred)

        # Positive parts: negative entries fall to -dynamic_range dB.
        pred_pos = self._log_compress(y_pred, y_pred_max)
        true_pos = self._log_compress(y_true, y_true_max)
        # Negative parts: flip the sign so they can be log-compressed too.
        pred_neg = self._log_compress(-y_pred, y_pred_max)
        true_neg = self._log_compress(-y_true, y_true_max)

        loss_pos = torch.mean(torch.square(pred_pos - true_pos))
        loss_neg = torch.mean(torch.square(pred_neg - true_neg))
        return 0.5 * loss_pos + 0.5 * loss_neg