|
你可以看到我们总共有 16 个使用 3x3 卷积过滤器的卷积层,与最大的池化层来下采样,和由 4096 个单元组成的两个全连接的隐藏层,每个隐藏层之后跟随一个由 1000 个单元组成的致密层,每个单元代表 ImageNet 数据库中的一个分类。我们不需要最后三层,因为我们将使用我们自己的全连接致密层来预测疟疾。我们更关心前五个块,因此我们可以利用 VGG 模型作为一个有效的特征提取器。
我们将使用模型之一作为一个简单的特征提取器,通过冻结五个卷积块的方式来确保它们的位权在每个纪元后不会更新。对于最后一个模型,我们会对 VGG 模型进行微调,我们会解冻最后两个块(第 4 和第 5)因此当我们训练我们的模型时,它们的位权在每个时期(每批数据)被更新。
模型 2:预训练的模型作为一个特征提取器
为了构建这个模型,我们将利用 TensorFlow 载入 VGG-19 模型并冻结卷积块,因此我们能够将它们用作特征提取器。我们在末尾插入我们自己的致密层来执行分类任务。
vgg = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet', input_shape=INPUT_SHAPE)vgg.trainable = False# Freeze the layersfor layer in vgg.layers: layer.trainable = False base_vgg = vggbase_out = base_vgg.outputpool_out = tf.keras.layers.Flatten()(base_out)hidden1 = tf.keras.layers.Dense(512, activation='relu')(pool_out)drop1 = tf.keras.layers.Dropout(rate=0.3)(hidden1)hidden2 = tf.keras.layers.Dense(512, activation='relu')(drop1)drop2 = tf.keras.layers.Dropout(rate=0.3)(hidden2)-
out = tf.keras.layers.Dense(1, activation='sigmoid')(drop2)-
model = tf.keras.Model(inputs=base_vgg.input, outputs=out)model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=1e-4), loss='binary_crossentropy', metrics=['accuracy'])model.summary()-
-
# OutputModel: "model_1"_________________________________________________________________Layer (type) Output Shape Param # =================================================================input_2 (InputLayer) [(None, 125, 125, 3)] 0 _________________________________________________________________block1_conv1 (Conv2D) (None, 125, 125, 64) 1792 _________________________________________________________________block1_conv2 (Conv2D) (None, 125, 125, 64) 36928 _________________________________________________________________......_________________________________________________________________block5_pool (MaxPooling2D) (None, 3, 3, 512) 0 _________________________________________________________________flatten_1 (Flatten) (None, 4608) 0 _________________________________________________________________dense_3 (Dense) (None, 512) 2359808 _________________________________________________________________dropout_2 (Dropout) (None, 512) 0 _________________________________________________________________dense_4 (Dense) (None, 512) 262656 _________________________________________________________________dropout_3 (Dropout) (None, 512) 0 _________________________________________________________________dense_5 (Dense) (None, 1) 513 =================================================================Total params: 22,647,361Trainable params: 2,622,977Non-trainable params: 20,024,384_________________________________________________________________
(编辑:衢州站长网)
【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!
|