def train(model,
          device,
          train_loader,
          optimizer,
          criterion,
          epoch):
    """Run one training epoch over ``train_loader`` and report the loss.

    Puts ``model`` into training mode. For every ``(data, target)`` pair
    yielded by ``train_loader``, it then:

    1. moves both to ``device``;
    2. clears gradients and runs a forward pass;
    3. back-propagates ``criterion(output, target)``;
    4. takes an optimizer step.

    After the loop it prints a single summary line for the epoch.

    Args:
        model: Module being trained; ``model.train()`` is called first.
        device: Target passed to ``Tensor.to`` for each batch.
        train_loader: Iterable of ``(data, target)`` pairs.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        epoch: Epoch label used only in the printed message.

    Returns:
        The sum of the per-batch ``criterion`` values for this epoch, which is
        the same number that is printed. Returns ``0.0`` if the loader yields
        no batches.
    """
    model.train()
    # NOTE(review): this is a *sum* over batches, not a mean, so it grows with
    # the number of batches; divide by len(train_loader) if a mean is wanted.
    loss_value = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        # .item() converts to a plain Python float, so the running total does
        # not hold a reference to each batch's autograd graph.
        loss_value += loss.item()
        optimizer.step()
    print(f'Train Epoch: {epoch} \t Loss: {loss_value:.6f}')
    return loss_value