import torch  # harmless no-op if already imported at the top of the file

# Knowledge-distillation training loop: the student is trained to match the
# teacher's output distribution via KL divergence; ground-truth labels are
# not used (pure response-based distillation).

NUM_EPOCHS = 5  # was hard-coded inline; named for clarity

# Teacher is a fixed target provider: eval mode so dropout / batch-norm
# batch statistics don't perturb the distillation targets.
# NOTE(review): original left the teacher in whatever mode it was in.
teacher_model.eval()

for epoch in range(NUM_EPOCHS):
    # Student trains; enables dropout/batch-norm training behavior.
    student_model.train()
    epoch_loss = 0.0  # now actually accumulated (was dead in the original)
    num_batches = 0

    for data in trainloader:
        # `labels` is intentionally unused: distillation targets come from
        # the teacher, not from ground truth.
        inputs, labels = data

        student_optimizer.zero_grad()

        # Student forward pass (gradients required).
        student_logits = student_model(inputs)

        # Teacher forward pass: no_grad() skips building the autograd graph
        # entirely (the original built it and then discarded it with
        # .detach(), wasting time and memory).
        with torch.no_grad():
            teacher_logits = teacher_model(inputs)

        # KL divergence between student and teacher distributions.
        # NOTE(review): assumes KL_loss handles the softmax/log-softmax
        # (and any temperature) internally — defined elsewhere; confirm.
        loss = KL_loss(student_logits, teacher_logits)

        # Backprop through the student only; teacher received no graph.
        loss.backward()
        student_optimizer.step()

        # .item() detaches to a Python float so the running total does not
        # keep the computation graph alive.
        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean batch loss; max(..., 1) guards an empty loader.
    print(
        f"epoch {epoch + 1}/{NUM_EPOCHS}: "
        f"mean KL loss = {epoch_loss / max(num_batches, 1):.4f}"
    )