【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练

分类: 外勤365在线登录 时间: 2025-09-16 18:52:54 作者: admin 阅读: 9202
【神经网络】VGG16、VGG16_bn、VGG19_bn详解以及使用pytorch进行模型预训练

目录

一、论文

二、模型介绍

三、模型预训练

一、论文

先来看看VGG这篇论文《Very Deep Convolutional Networks for Large-Scale Image Recognition》论文下载地址

论文中几个模型主要以几下几种方案A、B、C、D、E。目前主要还是采用VGG16和VGG19也就是下图中的分别红框和绿框部分。

二、模型介绍

其实通过上面的表格就已经大致知道模型的框架组成部分了。其实VGG16与VGG19的区别就是前者在三、四、五部分少了一层卷积。这里先附基于pytorch的一些预训练模型预训练模型下载地址。

上图可以看出VGG分有无BatchNormalization。这里先介绍一下VGG16_bn的一些内部层结构。

VGG16_bn序号层结构层数权重0conv1-1164x3x31batchnorm 2relu1-1 3conv1-2264x3x34batchnorm 5relu1-2 6pool1 7conv2-13128x3x38batchnorm 9relu2-1 10conv2-24128x3x311batchnorm 12relu2-2 13pool2 14conv3-15256x3x315batchnorm 16relu3-1 17conv3-26256x3x318batchnorm 19relu3-2 20conv3-37256x3x321batchnorm 22relu3-3 23pool3 24conv4-18512x3x325batchnorm 26relu4-1 27conv4-29512x3x328batchnorm 29relu4-2 30conv4-310512x3x331batchnorm 32relu4-3 33pool4 34conv5-111512x3x335batchnorm 36relu5-1 37conv5-212512x3x338batchnorm 39relu5-2 40conv5-313512x3x341batchnorm 42relu5-3512x3x343pool5 44fc6(4096)14 45relu6 46fc7(4096)15 47relu7 48fc8(1000)16 49prob(softmax)

上表格是VGG16_bn的一些详细层结构,一共有16层(层是指卷积层和全连接层)VGG16则仅仅去掉红色部分的Batch_normalization部分。这里可以看到VGG16_bn的modules共有44个(这里不算全连接层),如果是VGG16则有31个(不算全连接层)。

下图是通过导入VGG16_bn模型在调试过程中的结果,可见与上面是一致。

三、模型预训练

3.1加载整个模型

基于pytorch模型预训练,首先都要导入加载模型。有两种方式,下面一一介绍。

1.采用在线下载,这种一般受网络原因比较慢,不建议。

2.自己先下好预训练模型,从本地加载,这里介绍一下加载预训练模型后后自己提供一张图片进行分类识别。

import torch

import numpy

import torch.nn as nn

import torch.nn.functional as F

from PIL import Image

from torchvision import transforms

import torchvision.models as models

vgg = models.vgg16_bn()

pre=torch.load('./vgg16_bn-6c64b313.pth')

vgg.load_state_dict(pre)

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],#这是imagenet數據集的均值

std=[0.229, 0.224, 0.225])

tran=transforms.Compose([

transforms.Resize((224,224)),

transforms.ToTensor(),

transforms.Normalize(mean=[0.485, 0.456, 0.406],

std=[0.229, 0.224, 0.225])

])

im='./1.jpg'

im=Image.open(im)

im=tran(im)

im.unsqueeze_(dim=0)

print(im.shape)

# input()

out=vgg(im)

outnp=out.data[0]

ind=int(numpy.argmax(outnp))

print(ind)

from cls import d

print(d[ind])

print(out.shape)

# im.show()

3 主要有几个注意的地方。由于是加载VGG模型的,并提供自己一张图像进行预测,输入就必须符合VGG的格式。

VGG模型的图像读入方式采用PIL库所以就得使用PIL库进行读入图片输入图像的尺寸得必须和VGG保持一致224x224的三通道。(因为全连接层用的是VGG的)上面采用的normalization归一化方式 几个固定的参数是因为VGG数据的分布,其均值和方差VGG的最终分类的类别是1000类,最终out=vgg(img)是一个1000元素的张量。

4 查看加載的參數

pre = torch.load('./pretrain/vgg16_bn-6c64b313.pth')

for key, v in pre.items():

print(key, v.size())

加載得到的是VGG網絡參數,可以將其輸出查看,這裏只顯示其size

features.0.weight torch.Size([64, 3, 3, 3])

features.0.bias torch.Size([64])

features.1.weight torch.Size([64])

features.1.bias torch.Size([64])

features.1.running_mean torch.Size([64])

features.1.running_var torch.Size([64])

features.3.weight torch.Size([64, 64, 3, 3])

features.3.bias torch.Size([64])

features.4.weight torch.Size([64])

features.4.bias torch.Size([64])

features.4.running_mean torch.Size([64])

features.4.running_var torch.Size([64])

features.7.weight torch.Size([128, 64, 3, 3])

features.7.bias torch.Size([128])

features.8.weight torch.Size([128])

features.8.bias torch.Size([128])

features.8.running_mean torch.Size([128])

features.8.running_var torch.Size([128])

features.10.weight torch.Size([128, 128, 3, 3])

features.10.bias torch.Size([128])

features.11.weight torch.Size([128])

features.11.bias torch.Size([128])

features.11.running_mean torch.Size([128])

features.11.running_var torch.Size([128])

features.14.weight torch.Size([256, 128, 3, 3])

features.14.bias torch.Size([256])

features.15.weight torch.Size([256])

features.15.bias torch.Size([256])

features.15.running_mean torch.Size([256])

features.15.running_var torch.Size([256])

features.17.weight torch.Size([256, 256, 3, 3])

features.17.bias torch.Size([256])

features.18.weight torch.Size([256])

features.18.bias torch.Size([256])

features.18.running_mean torch.Size([256])

features.18.running_var torch.Size([256])

features.20.weight torch.Size([256, 256, 3, 3])

features.20.bias torch.Size([256])

features.21.weight torch.Size([256])

features.21.bias torch.Size([256])

features.21.running_mean torch.Size([256])

features.21.running_var torch.Size([256])

features.24.weight torch.Size([512, 256, 3, 3])

features.24.bias torch.Size([512])

features.25.weight torch.Size([512])

features.25.bias torch.Size([512])

features.25.running_mean torch.Size([512])

features.25.running_var torch.Size([512])

features.27.weight torch.Size([512, 512, 3, 3])

features.27.bias torch.Size([512])

features.28.weight torch.Size([512])

features.28.bias torch.Size([512])

features.28.running_mean torch.Size([512])

features.28.running_var torch.Size([512])

features.30.weight torch.Size([512, 512, 3, 3])

features.30.bias torch.Size([512])

features.31.weight torch.Size([512])

features.31.bias torch.Size([512])

features.31.running_mean torch.Size([512])

features.31.running_var torch.Size([512])

features.34.weight torch.Size([512, 512, 3, 3])

features.34.bias torch.Size([512])

features.35.weight torch.Size([512])

features.35.bias torch.Size([512])

features.35.running_mean torch.Size([512])

features.35.running_var torch.Size([512])

features.37.weight torch.Size([512, 512, 3, 3])

features.37.bias torch.Size([512])

features.38.weight torch.Size([512])

features.38.bias torch.Size([512])

features.38.running_mean torch.Size([512])

features.38.running_var torch.Size([512])

features.40.weight torch.Size([512, 512, 3, 3])

features.40.bias torch.Size([512])

features.41.weight torch.Size([512])

features.41.bias torch.Size([512])

features.41.running_mean torch.Size([512])

features.41.running_var torch.Size([512])

classifier.0.weight torch.Size([4096, 25088])

classifier.0.bias torch.Size([4096])

classifier.3.weight torch.Size([4096, 4096])

classifier.3.bias torch.Size([4096])

classifier.6.weight torch.Size([1000, 4096])

classifier.6.bias torch.Size([1000])

上面的feature最多是42個,不是44個,因爲relu和pool沒有顯示出來,其分別是features.42 feature.43.因爲加載的參數pre裏面包含的內容是參數,而relu操作和池化操作是不需要參數的,也就是模型保存時並沒有保存下來。

3.2加载部分模型

class VGG(nn.Module):

def __init__(self, weights=False):

super(VGG, self).__init__()

if weights is False:

model = models.vgg19_bn(pretrained=True)

model = models.vgg19_bn(pretrained=False)

pre = torch.load(weights)

model.load_state_dict(pre)

self.vgg19 = model.features

for param in self.vgg19.parameters():

param.requires_grad = False

初始化有个参数权重,当为false时,默认网上下载VGG模型,通常网上下载的比较慢不建议,所以直接本地下载好之后再load即可。这里选择了vgg的features部分,全连接部分没有选择,当然也可以索引或者切片选择任何层的 features。

相关文章

天下手游满级多少级介绍
上面一个般下面一个木念什么字?
声菲特技术资讯·第12期 | EQ的调试