如何使用训练好的PyTorch图像分类器模型

缘由

最近无聊就来尝试跑跑图片分类的模型,发现模型优化真的是一件挺重要的事情。但突然想到我们一直都是训练,怎么使用呢?就找到了方法!

思路

读取图片并将图片转化成和训练数据一样的格式

#-*-coding:utf-8-*-
import torch
from torchvision import transforms
from PIL import Image

img_path = "./trunk.jpeg"  


transform1 = transforms.Compose([
    transforms.CenterCrop((32,32)), # 只能对PIL图片进行裁剪
    transforms.ToTensor()
    ]
)

## PIL图片与Tensor互转
img_PIL = Image.open(img_path).convert('RGB')
# img_PIL.show() # 原始图片

img_PIL = img_PIL.resize((32,32),Image.NEAREST)
img_PIL.show()

img_PIL_Tensor = transform1(img_PIL)
print(type(img_PIL))
print(type(img_PIL_Tensor))

用模型获取结果

# 恢复模型并测试
net = Net()
net.load_state_dict(torch.load(PATH))

inputs = img_PIL_Tensor.unsqueeze(0)

new_outputs = net(inputs)


_, new_predicted = torch.max(new_outputs, 1)
print(classes[new_predicted])

总代码

#-*-coding:utf-8-*-
#-*-coding:utf-8-*-
import torch
from torchvision import transforms
from PIL import Image

img_path = "./trunk.jpeg"  


transform1 = transforms.Compose([
    transforms.CenterCrop((32,32)), # 只能对PIL图片进行裁剪
    transforms.ToTensor()
    ]
)

## PIL图片与Tensor互转
img_PIL = Image.open(img_path).convert('RGB')
# img_PIL.show() # 原始图片

img_PIL = img_PIL.resize((32,32),Image.NEAREST)
img_PIL.show()

img_PIL_Tensor = transform1(img_PIL)
print(type(img_PIL))
print(type(img_PIL_Tensor))

# 恢复模型并测试
net = Net()
net.load_state_dict(torch.load(PATH))

inputs = img_PIL_Tensor.unsqueeze(0)

new_outputs = net(inputs)


_, new_predicted = torch.max(new_outputs, 1)
print(classes[new_predicted])


其他


Comments

Leave a Reply

Your email address will not be published. Required fields are marked *