torch 中各种图像格式转换的实现方法
- PIL:使用python自带图像处理库读取出来的图片格式
- numpy:使用python-opencv库读取出来的图片格式
- tensor:pytorch中训练时所采取的向量格式(当然也可以说图片)
PIL与Tensor相互转换
importtorch
fromPILimportImage
importmatplotlib.pyplotasplt
#loader使用torchvision中自带的transforms函数
loader=transforms.Compose([
transforms.ToTensor()])
unloader=transforms.ToPILImage()
#输入图片地址
#返回tensor变量
defimage_loader(image_name):
image=Image.open(image_name).convert('RGB')
image=loader(image).unsqueeze(0)
returnimage.to(device,torch.float)
#输入PIL格式图片
#返回tensor变量
defPIL_to_tensor(image):
image=loader(image).unsqueeze(0)
returnimage.to(device,torch.float)
#输入tensor变量
#输出PIL格式图片
deftensor_to_PIL(tensor):
image=tensor.cpu().clone()
image=image.squeeze(0)
image=unloader(image)
returnimage
#直接展示tensor格式图片
defimshow(tensor,title=None):
image=tensor.cpu().clone()#weclonethetensortonotdochangesonit
image=image.squeeze(0)#removethefakebatchdimension
image=unloader(image)
plt.imshow(image)
iftitleisnotNone:
plt.title(title)
plt.pause(0.001)#pauseabitsothatplotsareupdated
#直接保存tensor格式图片
defsave_image(tensor,**para):
dir='results'
image=tensor.cpu().clone()#weclonethetensortonotdochangesonit
image=image.squeeze(0)#removethefakebatchdimension
image=unloader(image)
ifnotosp.exists(dir):
os.makedirs(dir)
image.save('results_{}/s{}-c{}-l{}-e{}-sl{:4f}-cl{:4f}.jpg'
.format(num,para['style_weight'],para['content_weight'],para['lr'],para['epoch'],
para['style_loss'],para['content_loss']))
numpy与tensor相互转换
importcv2
importtorch
importmatplotlib.pyplotasplt
deftoTensor(img):
asserttype(img)==np.ndarray,'theimgtypeis{},butndarryexpected'.format(type(img))
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
img=torch.from_numpy(img.transpose((2,0,1)))
returnimg.float().div(255).unsqueeze(0)#255也可以改为256
deftensor_to_np(tensor):
img=tensor.mul(255).byte()
img=img.cpu().numpy().squeeze(0).transpose((1,2,0))
returnimg
defshow_from_cv(img,title=None):
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
plt.figure()
plt.imshow(img)
iftitleisnotNone:
plt.title(title)
plt.pause(0.001)
defshow_from_tensor(tensor,title=None):
img=tensor.clone()
img=tensor_to_np(img)
plt.figure()
plt.imshow(img)
iftitleisnotNone:
plt.title(title)
plt.pause(0.001)
N张图片一起转换.
#将NxHxWXC的numpy格式图片转化为相应的tensor格式 deftoTensor(img): img=torch.from_numpy(img.transpose((0,3,1,2))) returnimg.float().div(255).unsqueeze(0)
参考:https://www.nhooo.com/article/177291.htm
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。