# train.py
def train(model,
          device,
          train_loader,
          optimizer,
          criterion,
          epoch) -> float:
    """Run one training epoch over ``train_loader`` and report the mean loss.

    For every ``(data, target)`` batch the inputs are moved to ``device``,
    gradients are zeroed, the model output is scored with ``criterion``,
    the loss is back-propagated and ``optimizer.step()`` is applied.

    Args:
        model: Module to train; ``model.train()`` is called first so layers
            such as dropout/batch-norm use training behaviour.
        device: Target passed to ``.to()`` for both inputs and targets.
        train_loader: Iterable yielding ``(data, target)`` pairs
            (presumably a ``torch.utils.data.DataLoader`` -- confirm).
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss with ``backward()`` and ``item()``.
        epoch: Epoch label, used only in the progress message.

    Returns:
        The mean of the per-batch loss values for this epoch. If the
        criterion averages within a batch, this is a mean of batch means,
        so a smaller final batch is weighted like a full one. Returns
        0.0 when the loader yields no batches.
    """
    model.train()

    total_loss = 0.0
    num_batches = 0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # .item() pulls a Python float so the running total does not keep
        # the autograd graph (or device memory) alive across batches.
        total_loss += loss.item()
        num_batches += 1

    # Report the average rather than the raw sum: the sum scales with the
    # number of batches and is not comparable across epochs/datasets.
    mean_loss = total_loss / num_batches if num_batches else 0.0

    print(f'Train Epoch: {epoch} \t Loss: {mean_loss:.6f}')
    return mean_loss