PyTorch
Official documentation (Chinese): https://www.pytorchtutorial.com/docs/
Official documentation: https://pytorch.org/docs/stable/index.html
1 torchvision.models
Models for image-processing tasks.
torchvision.models: https://pytorch.org/vision/0.9/models.html


2 Example: Classification with VGG
1. Download ImageNet

The dataset is too large to download for this example.
2. Download the model

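    import torchvision

    # Same architecture twice: once with random weights, once with ImageNet weights.
    # Note: torchvision 0.13+ deprecates `pretrained=` in favor of a `weights=` argument.
    vgg16_pretrained_false = torchvision.models.vgg16(pretrained=False)
    vgg16_pretrained_true = torchvision.models.vgg16(pretrained=True)  # downloads the weights
    print(vgg16_pretrained_true)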

The printed model ends with a linear layer whose out_features=1000, one output per ImageNet class.
3. Add Layers
    from torch import nn

    # Add a layer: append a 1000 -> 10 linear layer to the end of the classifier.
    vgg16_pretrained_true.classifier.add_module(
        name="add_linear",
        module=nn.Linear(
            in_features=1000,
            out_features=10,
        ),
    )
    print(vgg16_pretrained_true)

4. Modify Layers
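    from torch import nn

    # Modify a layer: replace the final classifier layer with a 10-class output.
    # classifier[6] is fed by a 4096-unit layer, so in_features must be 4096, not 1000.
    vgg16_pretrained_false.classifier[6] = nn.Linear(
        in_features=4096,
        out_features=10,
    )
    print(vgg16_pretrained_false)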

3 Save & Load Models
1. Save
Method 1
    import torch
    import torchvision

    vgg16 = torchvision.models.vgg16(pretrained=False)

    # Save method 1: pickle the whole model (structure + parameters).
    # The ./models directory must already exist.
    torch.save(vgg16, "./models/vgg16_method1.pth")
Method 2 (Recommended)

    import torch
    import torchvision

    vgg16 = torchvision.models.vgg16(pretrained=False)

    # Save method 2: save only the parameters as a dict (preferred).
    torch.save(vgg16.state_dict(), "./models/vgg16_method2.pth")
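To make the "parameters as a dict" idea concrete, the sketch below prints what a state_dict holds. It uses a tiny stand-in model (hypothetical, not from the original notes) so the output stays short; VGG16's state_dict has the same form with more entries.

```python
from torch import nn

# A tiny stand-in model (hypothetical) with two linear layers.
tiny = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict() maps parameter names to tensors; layers without
# parameters (here the ReLU at index 1) contribute no entries.
state = tiny.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```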
2. Load
Method 1
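Loads the structure and the parameters together.

    import torch
    import torchvision

    # Load method 1: unpickle the whole model object.
    # On PyTorch 2.6+, torch.load defaults to weights_only=True, so a whole
    # pickled model needs weights_only=False (only for files you trust).
    model1 = torch.load("./models/vgg16_method1.pth")
    print(model1)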
Method 2 (Recommended)
Rebuild the model structure first, then load the saved parameters into it.

    import torch
    import torchvision

    # Load method 2: create the structure, then fill in the saved parameters.
    vgg16 = torchvision.models.vgg16()
    model2_param_dict = torch.load("./models/vgg16_method2.pth")
    vgg16.load_state_dict(model2_param_dict)
    print(vgg16)
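A quick way to see that method 2 restores a model exactly is a save/load round trip. The sketch below is illustrative only: `make_model` is a hypothetical helper building a small model, the file goes to a temporary directory, and `map_location="cpu"` is added so that a checkpoint saved on a GPU would also load on a CPU-only machine.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical small model; any nn.Module is handled the same way.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


source = make_model()
restored = make_model()  # same structure, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_method2.pth")
    torch.save(source.state_dict(), path)
    state = torch.load(path, map_location="cpu")

# Copy the saved parameters into the freshly built model.
restored.load_state_dict(state)

all_equal = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), restored.state_dict().values())
)
print(all_equal)  # True
```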
3. Some Errors
- Save a custom model with method 1:

    # File: model_save.py
    import torch
    from torch import nn


    class MyNet(nn.Module):
        def __init__(self):
            super(MyNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

        def forward(self, x):
            x = self.conv1(x)
            return x


    my_net = MyNet()
    torch.save(my_net, "my_net_method1.pth")
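- Load the model in another Python file:

    import torch

    # Calling torch.load("my_net_method1.pth") without the class definition fails
    # (typically with an AttributeError such as "Can't get attribute 'MyNet'"),
    # because the file stores a reference to the class, not its code.
    # Importing the definition first fixes this.
    from model_save import *

    model = torch.load("my_net_method1.pth")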
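Saving only the state_dict (method 2) sidesteps this error, since the file then contains named tensors rather than a pickled class reference. A minimal sketch under that assumption, reusing the MyNet definition from above and writing to a temporary directory; the file name is hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "my_net_method2.pth")
    torch.save(MyNet().state_dict(), path)
    # The file holds only named tensors, so loading it needs no class lookup.
    state = torch.load(path, map_location="cpu")

print(sorted(state.keys()))  # ['conv1.bias', 'conv1.weight']

# The class is only needed to rebuild the structure explicitly.
model = MyNet()
model.load_state_dict(state)
out = model(torch.zeros(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 64, 6, 6])
```

The loading script still has to define or import MyNet, but a missing class then shows up as an ordinary NameError at construction time instead of an error inside torch.load.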