自定义激活函数（通过自定义 `backward` 实现反向传播）

1
%reset -f

继承 `torch.autograd.Function` 类

1
2
3
from torch.autograd import Function
import torch
import torch.nn as nn

实现静态方法 `forward` 和 `backward`

激活函数为 Swish 函数（代码中的类名为 `Switch`）：
$$f(x) = x \cdot \sigma(\beta x), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}$$
导数：
$$\frac{\partial f}{\partial x} = \sigma(\beta x) + \beta x \cdot \sigma(\beta x)\,[1 - \sigma(\beta x)]$$
下文的网络中取 $\beta = 1$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Switch(Function):
    """Swish activation ``f(x) = x * sigmoid(beta * x)`` with a hand-written backward.

    The class keeps the name ``Switch`` for compatibility; the function itself
    is usually called Swish. Invoke it through the static entry point,
    ``Switch.apply(x)`` or ``Switch.apply(x, beta)``. Autograd ``Function``
    subclasses are not meant to be instantiated, so no ``__init__`` is defined.
    """

    @staticmethod
    def forward(ctx, inputs, beta=1.0):
        """Return ``inputs * sigmoid(beta * inputs)``.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            inputs: input tensor.
            beta: scalar slope inside the sigmoid. The default 1.0 reproduces
                the original fixed ``x * sigmoid(x)``. It is treated as a
                constant: no gradient is computed for it.

        Returns:
            Tensor of the same shape as ``inputs``.
        """
        ctx.save_for_backward(inputs)
        ctx.beta = beta
        return inputs * torch.sigmoid(beta * inputs)

    @staticmethod
    def backward(ctx, grad_output):
        """Chain rule for Swish.

        ``d/dx [x * s(beta x)] = s + beta * x * s * (1 - s)`` with ``s = sigmoid(beta x)``.

        Returns:
            One gradient per ``forward`` input: the gradient w.r.t. ``inputs``
            and ``None`` for ``beta``. Autograd accepts the trailing ``None``
            even when ``beta`` was not passed to ``apply``.
        """
        inputs, = ctx.saved_tensors
        beta = ctx.beta
        sig = torch.sigmoid(beta * inputs)
        grad_inputs = grad_output * (sig + beta * inputs * sig * (1.0 - sig))
        return grad_inputs, None
1
2
3
4
5
6
7
8
9
10
11
12
class SNN(nn.Module):
    """Small regressor: Linear(1, 10) -> Switch activation -> Linear(10, 1)."""

    def __init__(self):
        super().__init__()
        hidden_width = 10
        self.fc1 = nn.Linear(1, hidden_width)
        # A plain callable (the autograd Function's apply), not a registered submodule.
        self.switch = Switch.apply
        self.fc2 = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Map ``x`` (last dimension of size 1) to an output with last dimension 1."""
        hidden = self.switch(self.fc1(x))
        return self.fc2(hidden)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class NN(nn.Module):
    """Baseline regressor: Linear(1, 20) -> sigmoid -> Linear(20, 1).

    The output layer is left linear. A final sigmoid would restrict every
    prediction to (0, 1), so the model could never fit negative targets such
    as the noisy ``sin(x)`` samples used in this notebook.
    """

    def __init__(self):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.ac = nn.Sigmoid()
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        """Map ``x`` (last dimension of size 1) to an unbounded output with last dimension 1."""
        x = self.ac(self.fc1(x))
        # No activation on the output layer: regression targets may be negative.
        return self.fc2(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from torch.utils.data import Dataset, DataLoader
class Data(Dataset):
    """Map-style dataset pairing ``x[i]`` with ``y[i]``.

    Each element is returned with a new leading axis of size 1 (``unsqueeze(0)``).
    For 1-D ``x``/``y`` every sample therefore has shape ``(1,)``. Presumably this
    is so batched inputs fit ``nn.Linear(1, ...)``; confirm against callers.
    """

    def __init__(self, x, y):
        super().__init__()
        self.x, self.y = x, y

    def __len__(self):
        # The dataset length is taken from ``x`` alone.
        return len(self.x)

    def __getitem__(self, index):
        return self.x[index].unsqueeze(0), self.y[index].unsqueeze(0)
1
2
3
4
5
6
7
8
9
10
# 1000 evenly spaced points on [-pi, pi] with targets sin(x) plus Gaussian noise (std 0.2).
x_data = torch.linspace(start=-torch.pi, end=torch.pi, steps=1000)
y_data = torch.sin(x_data) + 0.2 * torch.randn_like(x_data)
datas = Data(x_data, y_data)
train_data = DataLoader(datas, batch_size=100, shuffle=True)

# Model construction order matters for reproducibility: each draws from the RNG.
net = SNN()
net2 = NN()  # baseline; note that only net's parameters go to the optimizer below

loss = nn.MSELoss(reduction="mean")
optim = torch.optim.SGD(params=net.parameters(), lr=0.01)
1
2
3
4
5
6
7
8
9
10
# Train `net` with mini-batch SGD and report the sample-weighted mean loss
# over each epoch (every 10th epoch).
for epoch in range(200):
    epoch_loss_sum = 0.0
    n_samples = 0
    for X, y in train_data:
        y_pred = net(X)
        optim.zero_grad()
        l = loss(y_pred, y)  # loss modules take (input, target)
        # `l` is already a scalar with MSELoss's default 'mean' reduction,
        # so no extra .mean() is needed.
        l.backward()
        optim.step()
        # Weight by batch size so a smaller final batch is averaged correctly.
        batch_size = X.shape[0]
        epoch_loss_sum += l.item() * batch_size
        n_samples += batch_size
    if epoch % 10 == 0:
        mean_loss = epoch_loss_sum / n_samples if n_samples else float("nan")
        print("Epoch:{}, mean_loss:{}".format(epoch, mean_loss))
print("Finished")
Epoch:0, mean_loss:0.36015474796295166
Epoch:10, mean_loss:0.20154818892478943
Epoch:20, mean_loss:0.16524557769298553
Epoch:30, mean_loss:0.20352555811405182
...
Epoch:180, mean_loss:0.05651158466935158
Epoch:190, mean_loss:0.057117704302072525
Finished
1
2
3
4
from matplotlib import pyplot as plt

# Noisy training samples.
plt.scatter(x_data.numpy(), y_data.numpy(), s=5)
# Evaluate the trained net on every x as an (N, 1) column.
# - Call net(...) rather than net.forward(...) so nn.Module.__call__ (hooks)
#   is honoured.
# - no_grad avoids building an autograd graph that is never used, so no
#   .detach() is needed.
with torch.no_grad():
    y_preds = net(x_data.reshape(-1, 1)).reshape(x_data.shape)
plt.plot(x_data.numpy(), y_preds.numpy(), c='r')
[<matplotlib.lines.Line2D at 0x20e1df58cd0>]

png