# train.py
# --- Preprocessing -------------------------------------------------------
# Each sample is converted to a tensor, then normalized with a single
# (mean, std) pair; the one-element tuples mean one channel is expected.
# NOTE(review): 0.1307 / 0.3081 are the figures usually quoted as the MNIST
# training-set mean and standard deviation — confirm before reusing them.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# --- Datasets ------------------------------------------------------------
# Both splits share the './data' root and the same `transform`;
# download=True asks torchvision to fetch the files if they are not present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)