PyTorch:xor分类问题
先贴代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Net(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.fc1=nn.Linear(2,2)
self.fc2=nn.Linear(2,1)
def forward(self,x):
input=F.relu(self.fc1(x))
return self.fc2(input)
device = torch.device('cuda')
net=Net().to(device)
optimizer=optim.SGD(net.parameters(),lr=0.01)
criterion=nn.L1Loss()
datas=torch.tensor([[1,1,0],
[1,0,1],
[0,1,1],
[0,0,0],
]).to(device).float()
for i in range(5000):
for data in datas:
optimizer.zero_grad()
out=net(data[:2])
loss=criterion(out,data[-1])
loss.backward()
optimizer.step()
# print(loss)
for data in datas:
print(net(data[:2]).round())
详解:
代码简述:上述代码使用了两层线性神经网络,并使用relu激活函数连接。使用梯度下降算法(SDG)和L1Loss损失函数优化网络。
什么是xor?
xor,即异或:同则零,异为一
为什么要使用xor来作为神经网络学习案例?
xor问题是传统感知机无法实现的,是”单线性不可分“问题,故适合作为经典的神经网络学习案例
库导入:
import torch
import torch.nn as nn #模型
import torch.nn.functional as F #激活函数
import torch.optim as optim #优化器
定义网络
class Net(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.fc1=nn.Linear(2,2)
self.fc2=nn.Linear(2,1)
def forward(self,x):
input=F.relu(self.fc1(x))
return self.fc2(input)
初始化函数用于定义网络,forword函数用于前向传播,这是网络模型的基础函数
在初始化方法中,我们定义了两层线性神经网络,第一层为二输入二输出,第二层为2输入1输出
上一层网络的输出数量必须等于下一层网络的输入数量
事实上,我们可以定义任意层网络,只要保证网络整体是二输入一输出即可,但对于xor这种简单问题,我们并不需要太多层网络,以减轻计算负担。
我们要根据实际问题的复杂程度来定义网络的拓扑结构
实例化网络,优化器,损失函数和数据集
device = torch.device('cuda')
net=Net().to(device)
optimizer=optim.SGD(net.parameters(),lr=0.01)
criterion=nn.L1Loss()
datas=torch.tensor([[1,1,0],
[1,0,1],
[0,1,1],
[0,0,0],
]).to(device).float()
这里第一行device可以选择cuda或cpu,前者需要n卡支持,用于加速训练,但事实上如此简单的任务cpu即可完成,无需交由显卡计算。以及将计算任务提交至显卡时也需要一定时间,此处可优化
在定义优化器时可以定义学习率,学习率与训练效率息息相关,学习率过高可能导致模型不收敛,学习率过低可能导致学习速度太慢,建议根据学习效果实时调整学习率
训练
for i in range(5000):
for data in datas:
optimizer.zero_grad()
out=net(data[:2])
loss=criterion(out,data[-1])
loss.backward()
optimizer.step()
# print(loss)
遍历数据集训练,这里循环了5000次
训练时:
首先将参数清零,否则会影响后续调整
前向传播
计算损失
根据损失反向传播
迭代优化器
检验成果
for data in datas:
print(net(data[:2]).round())