siamese_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class SiameseDataset(Dataset):
    """Dataset that serves random image pairs for Siamese-network training.

    The visible code reads two instance attributes that must be set
    elsewhere (no ``__init__`` is shown here):

    * ``self.data`` -- an indexable, sized collection whose items unpack
      as ``(image, label)``; it is indexed directly and also sampled with
      ``random.choice``.
    * ``self.transform`` -- an optional callable applied to both images
      of a pair; skipped when falsy.
    """

    def __getitem__(self, index):
        """Return ``(imgA, imgB, pair_label)`` for the anchor at ``index``.

        ``imgA`` is the item stored at ``index``. ``imgB`` is drawn at
        random from the whole dataset; a fair coin flip decides whether it
        must share ``imgA``'s label or carry a different one. Because the
        partner is drawn from all of ``self.data``, a same-class partner
        may be the anchor sample itself.

        ``pair_label`` is a float32 tensor of shape ``(1,)`` holding
        ``1.0`` when the two labels differ and ``0.0`` when they match.

        NOTE: the partner is found by rejection sampling, so the loop
        never terminates if no sample satisfies the requested relation
        (e.g. a different-class partner in a single-class dataset).
        """
        imgA, labelA = self.data[index]

        # Coin flip: 1 -> pair with the same class, 0 -> a different class.
        want_same_class = bool(random.randint(0, 1))

        # Always draw at least once instead of seeding labelB with a
        # sentinel: a sentinel such as -1 collides with a genuine label of
        # -1, which used to skip the loop and leave imgB unassigned.
        while True:
            imgB, labelB = random.choice(self.data)
            labels_match = bool(labelB == labelA)
            if labels_match == want_same_class:
                break

        if self.transform:
            imgA = self.transform(imgA)
            imgB = self.transform(imgB)

        # 1.0 marks a dissimilar pair, 0.0 a similar one.
        pair_label = torch.tensor([(labelA != labelB)], dtype=torch.float32)

        return imgA, imgB, pair_label

Image pairs