train.py
from time import perf_counter  # monotonic, high-resolution clock for measuring intervals

# Batches are loaded with 8 worker processes. pin_memory=True makes the loader
# return batches in page-locked host memory. Without pinning, the
# non_blocking=True passed to train() below cannot make the host-to-device
# copy asynchronous.
# NOTE(review): assumes train() forwards non_blocking to Tensor.to(device, ...)
# and that `device` is a CUDA device — confirm. On CPU-only hosts pinning is
# not used (PyTorch may emit a warning).
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

num_epochs = 5

# Time only the training epochs; loader construction above is excluded.
# NOTE(review): if training runs on CUDA and train() does not synchronize
# (e.g. via loss.item()), queued GPU work may not be counted — confirm.
start = perf_counter()
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch, non_blocking=True)
print(f"Total time = {perf_counter() - start:.3f}")
Train Epoch: 1    Loss: 250.351424
Train Epoch: 2    Loss: 105.314569
Train Epoch: 3    Loss: 72.800107
Train Epoch: 4    Loss: 56.309036
Train Epoch: 5    Loss: 45.047929
Total time = 9.155