# Knowledge distillation loss (KL divergence)
import torch.nn.functional as F

def KL_loss(student_logits, teacher_logits):
    # convert teacher logits to probabilities
    p_teacher = F.softmax(teacher_logits, dim=1)
    # convert student logits to log-probabilities
    # (F.kl_div expects log-probabilities as its input argument)
    log_p_student = F.log_softmax(student_logits, dim=1)
    # KL divergence between student and teacher, averaged over the batch
    loss = F.kl_div(log_p_student, p_teacher, reduction='batchmean')
    return loss
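
# Minimal usage sketch of KL_loss above, with randomly generated logits standing
# in for real model outputs; the batch size and class count are arbitrary placeholders.
import torch

student_logits = torch.randn(8, 10)  # (batch_size, num_classes)
teacher_logits = torch.randn(8, 10)

loss = KL_loss(student_logits, teacher_logits)
print(loss.item())  # a single scalar, since reduction='batchmean' averages over the batch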