class SiameseNetwork(nn.Module):
    """Siamese network that embeds two images with one shared-weight encoder.

    Both inputs of :meth:`forward` go through the same convolutional
    sub-network (``self.cnn``) and the same fully connected head
    (``self.fc``), so the two embeddings are produced by identical weights.

    Args:
        in_channels: Number of channels in each input image (default 1).
        embedding_dim: Size of the output embedding (default 2).
        input_size: Spatial size of each input image, either an int for a
            square image or a ``(height, width)`` pair (default 28). It is
            only used to size the first fully connected layer; the defaults
            reproduce the original ``256 * 3 * 3`` input features.

    Raises:
        ValueError: If any side of ``input_size`` is smaller than 8, which
            would leave no spatial extent after the three 2x2 poolings.
    """

    def __init__(self, in_channels=1, embedding_dim=2, input_size=28):
        super().__init__()
        # Each conv uses "same" padding (k=5/p=2, k=3/p=1, stride 1), so only
        # the MaxPool2d(2, stride=2) layers change H/W, each flooring it in half.
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
        )

        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        # Three successive floor-halvings equal one floor division by 8
        # (e.g. 28 -> 14 -> 7 -> 3 == 28 // 8).
        height, width = (side // 8 for side in input_size)
        if height < 1 or width < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small; "
                "each side must be at least 8 pixels"
            )

        self.fc = nn.Sequential(
            nn.Linear(256 * height * width, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, embedding_dim),
        )

    def forward_once(self, x):
        """Embed one batch of images.

        Args:
            x: Batched image tensor ``(N, in_channels, H, W)`` whose H/W match
                the ``input_size`` given at construction.

        Returns:
            Tensor of shape ``(N, embedding_dim)``.
        """
        features = self.cnn(x)
        # flatten(1) rather than view(N, -1): view raises for an empty batch
        # (N == 0 makes the -1 ambiguous) and for non-contiguous activations
        # (e.g. channels_last); flatten falls back to a copy when needed.
        features = features.flatten(1)
        return self.fc(features)

    def forward(self, inputA, inputB):
        """Embed a pair of image batches with the shared encoder.

        Returns:
            Tuple ``(outputA, outputB)`` of embeddings for ``inputA`` and
            ``inputB`` respectively.
        """
        outputA = self.forward_once(inputA)
        outputB = self.forward_once(inputB)
        return outputA, outputB
# Figure annotations for SiameseNetwork:
# - self.cnn: a CNN sub-network
# - self.fc: a fully connected sub-network
# - forward(): returns the embeddings for both images