class ContrastiveLoss(torch.nn.Module):
    """Contrastive loss for Siamese-style embedding pairs.

    Label convention:
        ``y == 0`` means the pair is similar. It is penalised by the squared distance.
        ``y == 1`` means the pair is dissimilar. It is penalised by the squared hinge
        ``max(margin - distance, 0) ** 2``.

    The per-pair losses are averaged. No 1/2 factor is applied.
    """

    def __init__(self, margin: float = 2.0):
        """
        Args:
            margin: Distance beyond which dissimilar pairs contribute no loss.
        """
        super().__init__()
        self.margin = margin

    def forward(self, outputA, outputB, y):
        """Return the mean contrastive loss over a batch of pairs.

        Args:
            outputA: Embeddings for the first element of each pair.
            outputB: Embeddings for the second element of each pair. Must have
                the same shape as ``outputA``. Distance is taken over the last dim.
            y: Pair labels (0 = similar, 1 = dissimilar). Accepts shape ``(N,)``
                or ``(N, 1)``, any numeric or bool dtype, or a Python scalar.

        Returns:
            A 0-dim tensor holding the mean loss.
        """
        euclidean_distance = F.pairwise_distance(outputA, outputB, keepdim=True)
        # Cast labels so that bool labels work with ``1 - y`` and dtype/device match.
        y = torch.as_tensor(y, dtype=euclidean_distance.dtype,
                            device=euclidean_distance.device)
        # The distance has a trailing singleton dim, e.g. (N, 1). Without this reshape,
        # a (N,) label tensor would silently broadcast to (N, N) and corrupt the mean.
        if y.numel() == euclidean_distance.numel():
            y = y.reshape(euclidean_distance.shape)
        same_class_loss = (1 - y) * euclidean_distance.pow(2)
        diff_class_loss = y * torch.clamp(self.margin - euclidean_distance, min=0.0).pow(2)
        return torch.mean(same_class_loss + diff_class_loss)