import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
class Data(Dataset):
    """Map-style dataset that pairs inputs ``x`` with targets ``y`` by index.

    Each item is returned as ``(xi, yi)``, where both values have a new
    leading axis of size 1 inserted (``unsqueeze(0)``). For 1-D ``x``/``y``
    this turns every scalar sample into a length-1 vector, so default-collated
    batches come out with shape ``(batch, 1)``.

    Args:
        x: Indexable collection of inputs (e.g. a tensor). ``len(x)`` defines
            the dataset length.
        y: Indexable collection of targets, indexed in lockstep with ``x``.
            It must be at least as long as ``x``, or indexing will fail.
    """

    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``index``-th ``(input, target)`` pair, each unsqueezed at dim 0.

        ``torch.as_tensor`` returns tensor elements unchanged. It also lets
        plain Python or NumPy sequences work, which ``torch.unsqueeze`` alone
        would reject with a ``TypeError``.
        """
        xi = torch.as_tensor(self.x[index])
        yi = torch.as_tensor(self.y[index])
        return xi.unsqueeze(0), yi.unsqueeze(0)
# Plot the raw data, then train `net` for 200 epochs.
# NOTE(review): x_data, y_data, net, optim, loss and train_data are defined
# outside this view. x_data / y_data are assumed to be CPU tensors that do not
# require grad (``.numpy()`` fails otherwise) -- confirm.
losses = []  # one float per optimisation step (i.e. per batch)
plt.scatter(x_data.numpy(), y_data.numpy(), s=5)

for epoch in range(200):
    epoch_losses = []  # per-batch losses for this epoch, used for the log line
    for X, y in train_data:
        optim.zero_grad()
        y_pred = net(X)
        # PyTorch losses take (input, target). The old (target, input) order
        # was only harmless for symmetric losses such as MSE.
        # `.mean()` reduces an unreduced (reduction='none') loss to a scalar;
        # it is a no-op when the loss is already a scalar.
        batch_loss = loss(y_pred, y).mean()
        batch_loss.backward()
        optim.step()
        # .item() gives a plain Python float, detached from the autograd graph.
        batch_value = batch_loss.item()
        losses.append(batch_value)
        epoch_losses.append(batch_value)
    # Report the mean over the epoch's batches rather than just the last batch.
    # The guard skips logging when train_data yielded nothing.
    if epoch % 10 == 0 and epoch_losses:
        epoch_mean = sum(epoch_losses) / len(epoch_losses)
        print("Epoch:{}, mean_loss:{}".format(epoch, epoch_mean))