1
# IPython magic: clear every user-defined name from the interactive namespace; -f skips the confirmation prompt.
%reset -f

使用 PyTorch 制作和包装数据集

1
2
import torch
from matplotlib import pyplot as plt
1
2
3
4
# Synthetic regression data: 1000 evenly spaced inputs covering [-pi, pi].
x = torch.linspace(-torch.pi, torch.pi, 1000)
# Targets follow sin(x) with standard-normal noise scaled by 0.3 added on top.
y = 0.3 * torch.randn_like(x) + torch.sin(x)
# Scatter plot of the noisy samples with small markers.
plt.scatter(x=x.numpy(), y=y.numpy(), s=5)
print("x_shape:{}, y_shape:{}".format(x.shape, y.shape))
x_shape:torch.Size([1000]), y_shape:torch.Size([1000])

png

继承Dataset类,包装成Dataset类型的数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Map-style dataset pairing each element of ``x`` with the matching element of ``y``.

    Args:
        x: Indexable collection of inputs (the example passes a 1-D tensor).
        y: Indexable collection of targets, indexed in step with ``x``.
    """

    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples held by this dataset."""
        # Read the data stored on the instance, not a module-level variable
        # that happens to share the name.
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index``.

        Each element gets a leading dimension of size 1, so a sample is a
        one-element tensor rather than a 0-d scalar.
        """
        xi = torch.unsqueeze(self.x[index], 0)
        yi = torch.unsqueeze(self.y[index], 0)
        return xi, yi

# Wrap the tensors in the dataset, then fetch its first (input, target) sample.
# The bare expression on the last line is what the notebook displays.
mydataset = MyDataset(x, y)
next(iter(mydataset))
(tensor([-3.1416]), tensor([-0.0934]))

使用 DataLoader 包装数据集

1
2
3
4
5
6
from torch.utils.data import DataLoader
# Batch the dataset five samples at a time, reshuffled on every pass.
train_data = DataLoader(mydataset, batch_size=5, shuffle=True)
# Only the first batch is inspected, so leave the loop right after printing it.
for X_train, y_train in train_data:
    print(X_train, y_train, sep="\n")
    break
tensor([[-3.1290],
        [ 0.0723],
        [ 2.8460],
        [-1.5629],
        [-0.9843]])
tensor([[-0.0016],
        [-0.3829],
        [ 0.1528],
        [-0.6926],
        [-0.7684]])
1
# IPython magic: clear every user-defined name before the image example; -f skips the confirmation prompt.
%reset -f

包装图片

1
2
3
4
5
6
7
8
9
10
11
12
13
# The raw inputs are image file paths plus a class label for each image
# (here the label is taken from the file name).
imgs_path = './data/'

import os

# Build (file_path, label) pairs for every entry in the image directory.
# - sorted(): os.listdir returns entries in arbitrary order, so sort for a
#   reproducible sample order.
# - os.path.join: inserts a separator only when needed, so the trailing '/'
#   in imgs_path does not yield './data//name'.
# - file[3:5]: characters 3-4 of the file name are the label,
#   e.g. '00028.png' -> '28'.
img_and_label_list = [
    (os.path.join(imgs_path, file), file[3:5])
    for file in sorted(os.listdir(imgs_path))
]

img_and_label_list
[('./data//00028.png', '28'),
 ('./data//00029.png', '29'),
 ('./data//00030.png', '30'),
 ('./data//00031.png', '31'),
 ('./data//00032.png', '32')]

包装数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from torch.utils.data import Dataset
from PIL import Image
import torch
from torchvision import transforms
from matplotlib import pyplot as plt

class ImageData(Dataset):
    """Dataset over a list of ``(image_path, label)`` pairs.

    Args:
        lists: Sequence of ``(image_path, label)`` tuples, where ``label`` is
            something ``int()`` can convert (the example uses digit strings).
    """

    def __init__(self, lists):
        super().__init__()
        self.data_lists = lists

    def __len__(self):
        """Return how many ``(image_path, label)`` pairs are stored."""
        return len(self.data_lists)

    def __getitem__(self, index):
        """Open the image at ``index`` and return ``(image_tensor, label_tensor)``."""
        path, raw_label = self.data_lists[index]
        to_tensor = transforms.ToTensor()
        pixels = to_tensor(Image.open(path))
        # The label becomes a one-element integer tensor rather than a 0-d scalar.
        target = torch.unsqueeze(torch.tensor(int(raw_label)), 0)
        return pixels, target

img_data = ImageData(img_and_label_list)
# Take the first (image, label) sample and print the label, then the image tensor's shape.
img1, label1 = next(iter(img_data))
print(label1, img1.shape, sep="\n")

tensor([28])
torch.Size([1, 1024, 1272])
1
# Display the first slice along dimension 0 of the image tensor using a grayscale colormap.
plt.imshow(img1[0], cmap='gray')
<matplotlib.image.AxesImage at 0x2163abfd0d0>

png