使用 PyTorch 制作和包装数据集
1 2
| import torch from matplotlib import pyplot as plt
|
1 2 3 4
# Toy regression data: 1000 evenly spaced points on [-pi, pi] with a noisy
# sine target (standard-normal noise scaled by 0.3).
x = torch.linspace(-torch.pi, torch.pi, steps=1000)
y = x.sin() + 0.3 * torch.randn_like(x)
plt.scatter(x.numpy(), y.numpy(), s=5)
print(f"x_shape:{x.shape}, y_shape:{y.shape}")
|
x_shape:torch.Size([1000]), y_shape:torch.Size([1000])

继承 Dataset 类，将数据包装成 Dataset 类型
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset pairing the elements of two tensors ``x`` and ``y``.

    Item ``i`` is ``(x[i], y[i])``, each with a new leading dimension of
    size 1 added (so 1-D inputs yield items of shape ``(1,)``).
    """

    def __init__(self, x, y):
        """Store the input and target tensors.

        Args:
            x: input tensor, indexed along its first dimension.
            y: target tensor, indexed along its first dimension.

        Raises:
            ValueError: if ``x`` and ``y`` differ in length.
        """
        super().__init__()
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __len__(self):
        # Use the stored tensor, not a same-named global from the notebook.
        return len(self.x)

    def __getitem__(self, index):
        # unsqueeze(0) turns each scalar element into a 1-element tensor.
        xi = self.x[index].unsqueeze(0)
        yi = self.y[index].unsqueeze(0)
        return xi, yi
mydataset = MyDataset(x=x, y=y)
# Peek at the first (x, y) sample; in a notebook the tuple is echoed.
next(iter(mydataset))
|
(tensor([-3.1416]), tensor([-0.0934]))
使用 DataLoader 将 Dataset 包装成按批（batch）迭代的数据
1 2 3 4 5 6
from torch.utils.data import DataLoader

# Serve the samples in batches of 5, shuffled.
train_data = DataLoader(dataset=mydataset, batch_size=5, shuffle=True)

# Show only the first batch: the loop exits after a single iteration.
for X_train, y_train in train_data:
    print(X_train, y_train, sep="\n")
    break
|
tensor([[-3.1290],
[ 0.0723],
[ 2.8460],
[-1.5629],
[-0.9843]])
tensor([[-0.0016],
[-0.3829],
[ 0.1528],
[-0.6926],
[-0.7684]])
包装图片：收集图片路径及其标签
1 2 3 4 5 6 7 8 9 10 11 12 13
import os

imgs_path = './data/'

# Collect (path, label) pairs for the files in imgs_path, in sorted order
# so the listing is deterministic across platforms. The label is taken
# from characters 3..4 of the file name, e.g. '00028.png' -> '28'
# (assumes the zero-padded naming seen in this data set).
img_and_label_list = []
for file in sorted(os.listdir(imgs_path)):
    # os.path.join avoids the doubled separator of manual '+' concatenation.
    file_path = os.path.join(imgs_path, file)
    # Skip hidden entries and sub-directories; they are not labelled images
    # and would break int(label) / Image.open further down.
    if file.startswith('.') or not os.path.isfile(file_path):
        continue
    label = file[3:5]
    img_and_label_list.append((file_path, label))
img_and_label_list
|
[('./data//00028.png', '28'),
('./data//00029.png', '29'),
('./data//00030.png', '30'),
('./data//00031.png', '31'),
('./data//00032.png', '32')]
将图片路径和标签包装成 Dataset
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26
| from torch.utils.data import Dataset from PIL import Image import torch from torchvision import transforms from matplotlib import pyplot as plt
class ImageData(Dataset):
    """Dataset over ``(image_path, label)`` pairs.

    Each item is returned as ``(image_tensor, label_tensor)``. The image is
    converted with ``torchvision.transforms.ToTensor``. The label is parsed
    with ``int`` into a 1-element integer tensor.
    """

    def __init__(self, lists):
        """Store the list of ``(path, label)`` pairs.

        Args:
            lists: indexable sequence of ``(image_path, label)`` tuples; each
                label must be accepted by ``int``.
        """
        super().__init__()
        self.data_lists = lists
        # Build the converter once instead of on every __getitem__ call.
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.data_lists)

    def __getitem__(self, index):
        img_path, label = self.data_lists[index]
        # Context manager releases the file handle once the pixels have been
        # copied into the tensor.
        with Image.open(img_path) as pil_img:
            img = self._to_tensor(pil_img)
        label = torch.tensor([int(label)])
        return img, label


img_data = ImageData(img_and_label_list)
img1, label1 = next(iter(img_data))
print(label1)
print(img1.shape)
|
tensor([28])
torch.Size([1, 1024, 1272])
1
# img1 comes from ToTensor with a leading channel axis; indexing channel 0
# gives the 2-D array that imshow renders in grayscale.
plt.imshow(X=img1[0], cmap="gray")
|
<matplotlib.image.AxesImage at 0x2163abfd0d0>
