下载Pytorch和数据处理

0 创建环境并下载Pytorch

官网:https://pytorch.org

1
conda install pytorch::pytorch torchvision torchaudio -c pytorch

1 加载数据

1. Dataset

提取数据并获取label

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    """Image-folder dataset in which the sub-directory name is the label.

    The samples are the entries of ``root_dir/label_dir``; every sample is
    returned together with ``label_dir`` as its label.

    Args:
        root_dir: Directory that contains one sub-directory per label.
        label_dir: Name of the sub-directory to read; also used as the label.
    """

    def __init__(self, root_dir: str, label_dir: str) -> None:
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and platforms.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, index: int):
        """Return ``(image, label)`` for the entry at position ``index``."""
        img_name = self.img_path[index]
        # self.path already holds root_dir/label_dir, so reuse it.
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self) -> int:
        """Return the number of entries found in the label directory."""
        return len(self.img_path)


# Folder that holds one sub-folder per class.
root_dir = 'dataset/train'
ants_label_dir = 'ants_image'
bees_label_dir = 'bees_image'

# One dataset per class folder: ants first, then bees.
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# Merge by plain concatenation: whichever operand comes first stays first
# in the combined dataset.
train_dataset = ants_dataset + bees_dataset

常见数据集形式

2. Dataloader

为后面的网络提供不同的数据形式(打包)

2 TensorBoard

1. 尝试

1
2
3
4
5
6
7
8
from torch.utils.tensorboard import SummaryWriter

# Event files are written to the "logs" directory.
writer = SummaryWriter(log_dir='logs')

# Log the line y = x: `tag` is the chart title, `scalar_value` the y value,
# `global_step` the x value.
for i in range(50):
    writer.add_scalar(tag='y=x', scalar_value=i, global_step=i)

writer.close()
1
2
3
4
tensorboard --logdir=文件夹名
tensorboard --logdir=logs
# 改端口
tensorboard --logdir=logs --port=6007

2. 训练集练习

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np

# Event files are written to the "logs" directory.
writer = SummaryWriter(log_dir='logs')

# Two sample images from the ants training folder.
image_path1 = 'dataset/train/ants_image/0013035.jpg'
image_path2 = 'dataset/train/ants_image/5650366_e22b7e1065.jpg'

# Load each file with PIL, then convert it to a numpy array for add_image.
image_PIL1 = Image.open(image_path1)
image_array1 = np.array(image_PIL1)

image_PIL2 = Image.open(image_path2)
image_array2 = np.array(image_PIL2)

# Both images share the tag "test" and are logged at steps 1 and 2.
# dataformats='HWC' declares the array layout as height, width, channel.
writer.add_image(tag="test", img_tensor=image_array1, global_step=1, dataformats='HWC')
writer.add_image(tag="test", img_tensor=image_array2, global_step=2, dataformats='HWC')

writer.close()

3 Transforms

> [!NOTE]
> torchvision的模块之一:Transforms

1. ToTensor

  1. transforms.ToTensor将“PIL Image”和“numpy.ndarray”转化为tensor类型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from torchvision import transforms
from PIL import Image
import cv2

# Demonstrates the two input types accepted by transforms.ToTensor.
# The console output recorded when the snippet was run is kept as comments
# so that the snippet itself remains valid Python.
img_path = 'dataset/train/ants_image/5650366_e22b7e1065.jpg'

# Input type 1: PIL Image.
img = Image.open(img_path)
print(img)
# Output:
# <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x375 at 0x13B2D7970>

# Input type 2: numpy.ndarray, as returned by cv2.imread.
cv_img = cv2.imread(img_path)
print(cv_img)
# Output (abridged):
# [[[106 119 97]
# [106 119 97]
# [107 120 98]
# ...
# [110 115 116]
# [110 115 116]
# [110 115 116]]]

# ToTensor is a callable object: build it once, then call it on the image.
tensor_trans = transforms.ToTensor()
tensor_img = tensor_trans(img)
print(tensor_img)
# Output (abridged):
# tensor([[[0.3804, 0.3804, 0.3843, ..., 0.3412, 0.3373, 0.3333],
# [0.3765, 0.3804, 0.3843, ..., 0.3529, 0.3490, 0.3451],
# [0.3804, 0.3804, 0.3843, ..., 0.3725, 0.3686, 0.3647],
# ...,
# [0.5608, 0.5608, 0.5647, ..., 0.4392, 0.4392, 0.4392],
# [0.5412, 0.5529, 0.5608, ..., 0.4353, 0.4353, 0.4353],
# [0.5333, 0.5412, 0.5608, ..., 0.4314, 0.4314, 0.4314]]])

2. transforms使用

Note:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
"""
__call__函数的作用
"""

class Person:
    """Contrasts calling an instance directly with calling a named method."""

    def __call__(self, use):
        """Print ``use`` with the ``__call__`` prefix; runs on ``instance(use)``."""
        message = "__call__函数:" + use
        print(message)

    def func(self, use):
        """Print ``use`` with the ordinary-method prefix; runs on ``instance.func(use)``."""
        message = "一般函数:" + use
        print(message)


person = Person()
# Calling the instance itself dispatches to __call__.
person("call能直接利用类名括号调用")
# An ordinary method has to be reached through its name.
person.func("必须使用.func方式")

4 torchvision数据集

torchvision.datasets

查看官方文档:https://pytorch.org/

0.9版本:https://pytorch.org/vision/0.9/

1. 尝试使用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Transform pipeline applied to the dataset images; it only converts them
# to tensors.
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

# CIFAR10 train and test splits rooted at ./dataset, with download enabled.
train_set = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=True,
    transform=dataset_transforms,
    download=True,
)
test_set = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=dataset_transforms,
    download=True,
)

# Disabled single-sample inspection.
# NOTE(review): the recorded outputs show a PIL image, so they were
# presumably captured without `transform` - confirm before re-enabling.
# img, target = test_set[0]
# print(img)  # <PIL.Image.Image image mode=RGB size=32x32 at 0x14433FCA0>
# print(target)  # 3
# print(test_set.classes[target])  # cat
# img.show()

# Log the first ten test images under one tag, one step per image.
writer = SummaryWriter(log_dir='./logs_cifar')
for i in range(10):
    img, target = test_set[i]
    writer.add_image(tag='test_set', img_tensor=img, global_step=i)

writer.close()

2. 数据集的下载

进入数据集的源码(CIFAR10),查看url即为下载链接

5 Dataloader

官方文档:https://pytorch.org/docs/1.8.1/data.html?highlight=dataloader#torch.utils.data.DataLoader

1. batch_size

2. shuffle

shuffle打乱顺序

3. 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Transform pipeline applied to the dataset images; it only converts them
# to tensors.
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transforms, download=True)

# DataLoader settings used here:
#   batch_size=64   each batch packs 64 samples (not 4) together.
#   shuffle=True    samples are drawn in shuffled order, so a batch is NOT
#                   simply test_set[0..63].
#   drop_last=True  a final batch with fewer than 64 samples is discarded.
#   num_workers=0   loading happens in the main process.
test_loader = DataLoader(dataset=test_set, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

writer = SummaryWriter('dataloader')

# Two passes over the loader; each epoch gets its own tag so the batches of
# both passes can be compared side by side.
for epoch in range(2):
    # enumerate supplies the per-epoch batch index, starting again at 0.
    for step, (imgs, targets) in enumerate(test_loader):
        writer.add_images(f'Epoch: {epoch}', imgs, step)

writer.close()