缘由
最近尝试训练了图片分类的模型,发现模型优化确实是一件挺重要的事情。但训练完成之后,如何用训练好的模型对一张新图片进行预测呢?本文记录了具体方法。
思路
读取图片并将图片转化成和训练数据一样的格式
# -*- coding: utf-8 -*-
import torch
from torchvision import transforms
from PIL import Image

# Path to the input image.
# NOTE(review): "trunk.jpeg" may be a typo for "truck" (a CIFAR-10 class
# name) — confirm against the actual file on disk.
img_path = "./trunk.jpeg"

# Preprocessing pipeline matching the training data format: convert the
# (already 32x32) PIL image to a CHW float tensor scaled to [0, 1].
# The original CenterCrop((32, 32)) was a no-op here, because the image is
# resized to exactly 32x32 below before the transform runs, so it has been
# removed.
transform1 = transforms.Compose([
    transforms.ToTensor()
])

# Load the image and force 3-channel RGB (handles grayscale / alpha inputs).
img_PIL = Image.open(img_path).convert('RGB')
# img_PIL.show()  # show the original image

# Resize to the 32x32 input size the model was trained on. NEAREST keeps
# the original script's behavior; BILINEAR would give smoother downscaling
# for photographs.
img_PIL = img_PIL.resize((32, 32), Image.NEAREST)
img_PIL.show()

img_PIL_Tensor = transform1(img_PIL)
print(type(img_PIL))         # <class 'PIL.Image.Image'>
print(type(img_PIL_Tensor))  # <class 'torch.Tensor'>
用模型获取结果
# Restore the trained model and classify the prepared tensor.
# `Net`, `PATH`, and `classes` are assumed to be defined by the training
# script this snippet accompanies — NOTE(review): confirm against it.
net = Net()
net.load_state_dict(torch.load(PATH))
net.eval()  # switch dropout / batch-norm layers to inference mode

# Add a batch dimension: (C, H, W) -> (1, C, H, W); the model expects a batch.
inputs = img_PIL_Tensor.unsqueeze(0)

# Disable gradient tracking — only a forward pass is needed for inference,
# so skip building the autograd graph.
with torch.no_grad():
    new_outputs = net(inputs)

# Index of the highest-scoring class for the single image in the batch.
_, new_predicted = torch.max(new_outputs, 1)
print(classes[new_predicted.item()])
总代码
# -*- coding: utf-8 -*-
# (The duplicated encoding declaration was removed: Python only honors a
# coding comment on the first or second line of a file.)
import torch
from torchvision import transforms
from PIL import Image

# Path to the input image.
# NOTE(review): "trunk.jpeg" may be a typo for "truck" (a CIFAR-10 class
# name) — confirm against the actual file on disk.
img_path = "./trunk.jpeg"

# Preprocessing pipeline matching the training data format: convert the
# (already 32x32) PIL image to a CHW float tensor scaled to [0, 1].
# The original CenterCrop((32, 32)) was a no-op here, because the image is
# resized to exactly 32x32 below before the transform runs, so it has been
# removed.
transform1 = transforms.Compose([
    transforms.ToTensor()
])

# Load the image and force 3-channel RGB (handles grayscale / alpha inputs).
img_PIL = Image.open(img_path).convert('RGB')
# img_PIL.show()  # show the original image

# Resize to the 32x32 input size the model was trained on. NEAREST keeps
# the original script's behavior; BILINEAR would give smoother downscaling
# for photographs.
img_PIL = img_PIL.resize((32, 32), Image.NEAREST)
img_PIL.show()

img_PIL_Tensor = transform1(img_PIL)
print(type(img_PIL))         # <class 'PIL.Image.Image'>
print(type(img_PIL_Tensor))  # <class 'torch.Tensor'>

# Restore the trained model and classify the prepared tensor.
# `Net`, `PATH`, and `classes` are assumed to be defined by the training
# script this snippet accompanies — NOTE(review): confirm against it.
net = Net()
net.load_state_dict(torch.load(PATH))
net.eval()  # switch dropout / batch-norm layers to inference mode

# Add a batch dimension: (C, H, W) -> (1, C, H, W); the model expects a batch.
inputs = img_PIL_Tensor.unsqueeze(0)

# Disable gradient tracking — only a forward pass is needed for inference,
# so skip building the autograd graph.
with torch.no_grad():
    new_outputs = net(inputs)

# Index of the highest-scoring class for the single image in the batch.
_, new_predicted = torch.max(new_outputs, 1)
print(classes[new_predicted.item()])