PyTorch

Official documentation (Chinese): https://www.pytorchtutorial.com/docs/

Official documentation: https://pytorch.org/docs/stable/index.html

1 torchvision.models

Models for image-processing tasks.

torchvision.models: https://pytorch.org/vision/0.9/models.html

2 Example: Classification with VGG

1. Download ImageNet

Skipped here: the ImageNet dataset is too large to download for this example.

2. Download the model

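import torchvision

# pretrained=False: VGG16 architecture with randomly initialized weights
vgg16_pretrained_false = torchvision.models.vgg16(pretrained=False)
# pretrained=True: downloads the ImageNet-pretrained weights on first use
# (torchvision >= 0.13 deprecates `pretrained` in favor of `weights=`,
#  e.g. weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
vgg16_pretrained_true = torchvision.models.vgg16(pretrained=True)
print(vgg16_pretrained_true)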

The pretrained model's last layer has out_features=1000, one output per ImageNet class.

3. Add Layers

""" 增加层 """
vgg16_pretrained_true.classifier.add_module(
name="add_linear", # 增加新的层名
module=nn.Linear(
in_features=1000,
out_features=10,
), # 增加新的层
)
print(vgg16_pretrained_true) # out_features=10

4. Modify Layers

""" 修改层 """
vgg16_pretrained_false.classifier[6] = nn.Linear(
in_features=1000,
out_features=10,
)
print(vgg16_pretrained_false) # out_features=10

3 Save & Load Models

1. Save

Method 1

  • torch.save(model)
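import os

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories

'''
1. save method 1: save the whole model -> structure + parameters
'''
torch.save(vgg16, "./models/vgg16_method1.pth")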

Method 2 (Recommended)

  • torch.save(model.state_dict())
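import os

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)
os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories

'''
2. save method 2: save only the parameters as a dict (state_dict); recommended
'''
torch.save(vgg16.state_dict(), "./models/vgg16_method2.pth")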

2. Load

Method 1

  • torch.load()

Restores both the structure and the parameters.

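import torch
import torchvision  # the classes referenced by the pickled model must be importable

# load method 1: restores structure + parameters
# weights_only=False is required in PyTorch >= 2.6, where the default became True
model1 = torch.load("./models/vgg16_method1.pth", weights_only=False)
print(model1)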
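If the model was saved on a GPU and is loaded on a machine without one, torch.load needs map_location to move the tensors. A self-contained sketch of method 1 with map_location; the small nn.Sequential is a stand-in for VGG16 to keep the file tiny, and the temporary path is illustrative only.

```python
import os
import tempfile

import torch
import torch.nn as nn

# A small stand-in model (torch.nn classes are importable, so pickling works)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_method1.pth")
    torch.save(model, path)  # method 1: the whole model

    # map_location="cpu" remaps any CUDA tensors onto the CPU while loading;
    # weights_only=False is needed to unpickle a full model in PyTorch >= 2.6
    loaded = torch.load(path, map_location="cpu", weights_only=False)

print(loaded)
print(next(loaded.parameters()).device)  # cpu
```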

Method 2 (Recommended)

  • model.load_state_dict()

Loads the saved parameters into a newly constructed model.

import torch
import torchvision

# load method 2: build the architecture first, then fill in the parameters
vgg16 = torchvision.models.vgg16()
model2_param_dict = torch.load("./models/vgg16_method2.pth")  # dict of parameter tensors
vgg16.load_state_dict(model2_param_dict)
print(vgg16)
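Method 2 also makes it possible to reuse parameters after the architecture has changed, for example after replacing the last layer. A sketch, assuming the mismatched entries should simply be skipped: keys whose shapes differ are filtered out and the rest is loaded with strict=False. The two small models are hypothetical stand-ins for a 1000-class and a 10-class VGG16, used to keep the example fast.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: identical layers except the final output size
source = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1000))
target = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 10))

source_state = source.state_dict()
target_state = target.state_dict()

# Keep only entries whose name and shape both match the target model
compatible = {
    name: tensor
    for name, tensor in source_state.items()
    if name in target_state and tensor.shape == target_state[name].shape
}

# strict=False tolerates the keys that were filtered out
result = target.load_state_dict(compatible, strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']
print(result.unexpected_keys)  # []
print(torch.equal(target[0].weight, source[0].weight))  # True
```

Checking missing_keys afterwards is worthwhile: it lists exactly which layers kept their fresh initialization.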

3. A Common Error

  • Save a custom model with method 1
''' python_file_name: model_save.py '''
import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x


my_net = MyNet()
torch.save(my_net, "my_net_method1.pth")
  • Load the model in another Python file
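'''
error: loading without the MyNet class definition fails with an AttributeError,
because method 1 pickles a reference to the class, not the class code itself
model = torch.load("my_net_method1.pth")
'''

import torch
from model_save import MyNet  # makes MyNet available; note this also runs model_save.py's top-level code

model = torch.load("my_net_method1.pth", weights_only=False)  # weights_only=False is needed in PyTorch >= 2.6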
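Saving the custom model with method 2 avoids pickling the class, although the loading side still needs MyNet to rebuild the architecture. A single-file sketch of the round trip, using a temporary directory instead of the original file names; the 32x32 test input is an arbitrary choice.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)


my_net = MyNet()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "my_net_method2.pth")
    torch.save(my_net.state_dict(), path)  # parameters only, no class reference

    # In another file: import MyNet, build it, then fill in the parameters
    restored = MyNet()
    restored.load_state_dict(torch.load(path))  # a dict of tensors loads with the default settings

x = torch.randn(1, 3, 32, 32)
print(torch.allclose(my_net(x), restored(x)))  # True
```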