1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
| import os from PIL import Image import numpy as np import torch from torch.utils import data from torch.utils.data import DataLoader from torchvision import transforms
# Preprocessing applied to every image: scale the shorter side to 256 px,
# take the central 224x224 crop, convert to a float tensor (values in [0, 1]
# for 8-bit images), then map each of the three channels through
# (x - 0.5) / 0.5, i.e. into [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
        ),
    ]
)
class DogCat(data.Dataset):
    """Dogs-vs-cats image dataset backed by a flat directory of image files.

    Every entry directly inside ``root`` is treated as one sample. The label is
    derived from the file name: ``0`` when it contains ``'dog'``, otherwise ``1``.
    """

    def __init__(self, root, transform=None):
        """Index the files in ``root``.

        Args:
            root: directory whose entries are the image files.
            transform: optional callable applied to each loaded PIL image.
        """
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same file across runs and machines.
        imgs = sorted(os.listdir(root))
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transform = transform

    def __getitem__(self, index):
        """Return one sample as ``(image, label)``."""
        img_path = self.imgs[index]
        # basename works with either path separator, unlike split('/').
        label = 0 if 'dog' in os.path.basename(img_path) else 1
        # Load inside ``with`` so the underlying file handle is closed, and
        # convert to RGB so grayscale/RGBA/palette images also yield three
        # channels (the 3-channel Normalize and batch collation require it).
        with Image.open(img_path) as img:
            sample = img.convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)
dataset = DogCat('DogCat/data/', transform=transform)

from torch.utils.data.sampler import WeightedRandomSampler

# Per-sample sampling weights: each cat image (label 1) gets 0.5 and each dog
# image (label 0) gets 1, so any given dog is twice as likely to be drawn as
# any given cat. The label is read from the file name, mirroring the rule in
# DogCat.__getitem__ ('dog' in the file name -> label 0). Iterating the dataset
# instead would decode and transform every image just to read its label.
weight = [1 if 'dog' in os.path.basename(path) else 0.5 for path in dataset.imgs]

# num_samples=5: the sampler yields only 5 indices in total (with replacement),
# so the loader below produces a single batch of 5 despite batch_size=8.
sampler = WeightedRandomSampler(weight, num_samples=5, replacement=True)
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler, num_workers=0)

for _, labels in dataloader:
    print(labels.tolist())
|