实现基本的分类任务

bluesky1年前 ⋅ 1114 阅读
import torch 
from torch import nn 
import torchvision 
import torchvision.transforms as transforms

# 定义数据预处理方式,将PIL Image或者numpy.ndarray转换成tensor格式
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize(mean=(0.5, ), std=(0.5,))])

# 使用pytorch自带的torchvision加载cifar10数据库
train_data = torchvision.datasets.CIFAR10(root="../data/", 
                                         train=True, 
                                         transform=transform, 
                                         download=True)

# 构建DataLoader,用于每一个batch做数据对应
trainloader = torch.utils.data.DataLoader(train_data,
                                         batch_size=64,
                                         shuffle=True)

# 构建卷积神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

# 实例化一个CNN网络
net = CNN()
# 定义损失函数和反向传播的优化函数
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 一轮训练
for epoch in range(2):
  
  running_loss = 0.0
  for i, data in enumerate(trainloader, 0):
    inputs, labels = data
    # 梯度清零
    optimizer.zero_grad()
    # 计算输出
    outputs = net(inputs)
    # 计算损失
    loss = criterion(outputs, labels)
    # 反向传播
    loss.backward()
    # 更新权重
    optimizer.step()

    # 打印log
    running_loss += loss.item()
    if i % 2000 == 1999:
      print('[%d, %5d] loss: %.3f' %
            (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0

print('Finished Training')


全部评论: 0

    相关推荐