1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Training loop: run `epochs` passes over `train_loader`, performing one
# optimizer update per batch.
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Move both halves of the batch to `device`, passing
        # non_blocking=True to each transfer.
        # NOTE(review): presumably the copy is only asynchronous when the
        # loader yields pinned host memory -- confirm the DataLoader setup.
        data, target = (
            data.to(device, non_blocking=True),
            target.to(device, non_blocking=True),
        )

        # Clear the optimizer's gradients before this batch's backward pass.
        optimizer.zero_grad()

        # Forward pass and loss for the current batch.
        output = model(data)
        loss = criterion(output, target)

        # Backward pass followed by the optimizer update.
        loss.backward()
        optimizer.step()

Specify non_blocking=True when moving each batch (data and target) to the device.