# Knowledge distillation loss (KL divergence)
import torch.nn.functional as F

def KL_loss(student_logits, teacher_logits):
    
    # convert teacher model outputs to probabilities (targets for F.kl_div)
    p_teacher = F.softmax(teacher_logits, dim=1)
    
    # convert student model outputs to log-probabilities
    # (F.kl_div expects log-probabilities as its first argument)
    log_p_student = F.log_softmax(student_logits, dim=1)

    # compute KL divergence; 'batchmean' sums over classes and averages over
    # the batch, matching the mathematical definition of KL divergence
    loss = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
    
    return loss
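

# Minimal usage sketch: the tensor shapes and random inputs below are
# illustrative assumptions, not part of the original example.
import torch

student_logits = torch.randn(8, 10)   # assumed batch of 8 examples, 10 classes
teacher_logits = torch.randn(8, 10)

loss = KL_loss(student_logits, teacher_logits)
print(loss.item())  # a single scalar, averaged over the batch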