# StudentModel.py
# Distill the teacher's output distribution into the student for 5 epochs.
# Teacher weights are never updated: its logits are detached so no gradient
# flows back, and it is switched to eval mode so dropout/batch-norm layers
# behave deterministically during distillation.
teacher_model.eval()

for epoch in range(5):
    # student must be in train mode (enables dropout, batch-norm updates)
    student_model.train()

    # running sum of per-batch losses, reported at the end of the epoch
    epoch_loss = 0.0
    num_batches = 0

    for data in trainloader:
        inputs, labels = data  # labels unused: distillation trains on teacher outputs only
        student_optimizer.zero_grad()

        # forward pass through the student (gradients tracked)
        student_logits = student_model(inputs)

        # forward pass through the teacher; detach so backprop stops here
        teacher_logits = teacher_model(inputs).detach()

        # NOTE(review): KL_loss is defined elsewhere — if it wraps
        # torch.nn.KLDivLoss, the first argument must be LOG-probabilities
        # (log_softmax of the logits), not raw logits. Confirm KL_loss
        # applies the softmax/log_softmax internally.
        loss = KL_loss(student_logits, teacher_logits)

        # backpropagate and update the student only
        loss.backward()
        student_optimizer.step()

        # .item() converts the 0-dim loss tensor to a Python float
        epoch_loss += loss.item()
        num_batches += 1

    # report mean loss; guard against an empty loader
    if num_batches:
        print(f"epoch {epoch}: mean distillation loss {epoch_loss / num_batches:.6f}")