程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2024-11(1)

Pytorch小结

发布于2021-06-07 21:34     阅读(1007)     评论(0)     点赞(13)     收藏(2)


1.加载数据集

(1)直接使用torchversion提供的datasets,则采用以下方式进行加载:

train_set = torchvision.datasets.CIFAR10("./dataset", train=True, transform=torchvision.transforms.ToTensor(),
                                         download=True)

这一代码加载的是CIFAR10这个数据集,只有1G左右比较容易下载。
(2)自己把数据集下载下来后,写自己的数据集类继承torch.utils.data.Dataset,然后重载init和getitem用于获取自己的数据集,demo如下:

from PIL import Image
from torch.utils.data import Dataset
import os

class MyDataSet(Dataset):
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.image_path = os.listdir(self.root_dir)
        self.transform = transform
    def __getitem__(self, idx):
        img = self.image_path[idx]
        img_path = os.path.join(self.root_dir, img)
        image = Image.open(img_path)
        image = self.transform(image)
        return image

如果需要将训练集和测试集分开,改动路径即可

2.可视化(Tensorboard的使用)

入门级别的使用基本就这几句话:

writer = SummaryWriter("./logs")
writer.add_scalar("train", loss.item(), train_all_times)
writer.add_image("train", image_tensor, 0)
writer.close()

另外还有对应的add_scalars和add_images等,而且可以简单的理解为第二个参数是y轴,第三个参数为x轴。

3.三种常用的数据格式的转化(PIL.Image、numpy.ndarray、Tensor的相互转换)

ima_path = "images/airplane.png"
img_PIL = Image.open(ima_path)
img_cv_numpy = cv2.imread(ima_path)

# 转为tensor
to_tensor = transforms.ToTensor()

img_numpy_tensor = to_tensor(img_cv_numpy)
img_PIL_tensor = to_tensor(img_PIL)

# 转为PIL
to_PIL = transforms.ToPILImage()

img_numpy_PIL = to_PIL(img_cv_numpy)
img_tensor_PIL = to_PIL(img_numpy_tensor)

# 转为numpy
img_tensor_numpy = img_PIL_tensor.numpy()
img_numpy = np.array(img_PIL)

4.torchvision.transforms的使用:用来对图像做变换

(1)例如上面的ToTensor()和ToPILImage()用来转化为tensor和pilimage
(2)找到transform查看它的结构,有这么多的类:
在这里插入图片描述
入门的有,像Compose用来同时声明多个操作,Resize用来重新变换高和宽等等,想用哪个直接按住ctrl点左键进入看源码咋解析的

5.torch.utils.data.DataLoader的使用:用来装入数据

对前面得到的dataset进行装载,其中有一个参数要注意,就是num_workers,这个参数再windows下需要置为0,否则会出问题

6.卷积神经网络的构建:

(1)构建模型的时候需要继承torch.nn.Moule,之后在类初始化中将各层书写完整,最好直接写在一个torch.nn.Sequential里面,最基本的有卷积层、池化层、全连接层等等,其中某些方法的参数参照这里面的公式进行设置:
官方参考
(2)loss函数在torch.nn中提供了很多,包括像交叉熵函数等等,如下:
在这里插入图片描述
可以感性的认识一下它,如果其中的class正确,则整个这一项x[class]会是最大的,那么-x[class]就是最小的,整个的loss就会小,这符合常理:预测正确的loss就是小的。
(3)优化器,在torch.optim中提供了不少优化器,每一个优化器都提供了不同的参数用来配置,但是都需要传入模型的参数,只有这样对其进行优化。例如SGD优化器,再另外传入一个学习速率就可以用了,另外torch.optim.lr_scheduler中提供了几种用于以epoch为单位对学习速率进行调整的函数方法从而更好的学习。
(4)具体的使用过程如下:

# 定义一个优化器,并使用StepLR对学习速率每5个epoch进行0.1倍的调整
# lr = 0.01     if epoch < 5
# lr = 0.001    if 5<= epoch < 10
# lr = 0.0001   if 10 <= epoch < 15
.......
optimizer = torch.optim.SGD(module.parameters(), 0.01)
scheduler = StepLR(optim, step_size=5, gamma=0.1)
.......
# 之后在训练过程中搭配loss函数使用
result_loss = loss(outputs, targets)# 利用loss函数获得loss
optim.zero_grad()#本次优化前清空上一次的grad
result_loss.backward()#反向传播求出本次的梯度
scheduler.step()#优化调整

7.保存模型及使用训练好的模型

(1)两种保存模型的方法,只说第二种(推荐使用):将模型训练好的参数保存在mymodule_cuda.pth文件中

torch.save(module.state_dict(), "mymodule_cuda.pth")

(2)读取保存的模型并进行测试:创建一个新的模型并将训练好的参数导入

model = Mymodule()
model.load_state_dict(torch.load("mymodule_cuda_google.pth"))

之后使用这个模型进行测试即可

8.其它:

(1)使用output.argmax(1)可以求出最后的十个分类结果哪个的概率最大。其中参数1指的就是行这一维度
(2)测试时在前面加上with torch.no_grad(),这样不会进行反向传播不会计算梯度
(3)torch.reshape():返回一个张量,内容和原来的张量相同,但是具有不同形状,传入的参数可以有一个-1,表示其具体值由其他维度信息和元素总个数推断出来.
(4)使用GPU计算:有两种方式只说第二种(推荐使用),首先定义一个device变量如下,意思是如果cuda可用就使用GPU计算,否则使用CPU。定义结束后在三个位置使用to(device):分别是模型、loss、输入的数据。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = Mymodule()
module.to(device)

loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)

(5)测试时常用的参数:准确率。其中下面是用于计算64张图的准确预测的图片个数的,因为输入是64张图片且输出为分为10类的各自的概率故output的shape是([64, 10]),之后对其使用argmax(1)就会得到64个最终的结果,例如[1,5,8,4,7,2,3,6…],将这个预测的结果直接跟标准结果来一个==,会得到这种结果[True,True,True,True,True,False,True…],自然也是64个,之后再i来一个sum直接计算出这64个里面为True的个数也就是算出预测对的个数。

accuracy_times = (output.argmax(1) == targets).sum()

原文链接:https://blog.csdn.net/weixin_44142774/article/details/117594234



所属网站分类: 技术文章 > 博客

作者:搜嘎皮卡

链接:http://www.phpheidong.com/blog/article/89610/7d0a54e237d01807f43a/

来源:php黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

13 0
收藏该文
已收藏

评论内容:(最多支持255个字符)