您的位置:首页 > 其它

Keras —— 基于InceptionV3模型(不含全连接层)的迁移学习应用

2018-02-25 15:56 459 查看

一、ImageDataGenerator

def image_preprocess():
#  训练集的图片生成器,通过参数的设置进行数据扩增
train_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True
)
#   验证集的图片生成器,不进行数据扩增,只进行数据预处理
val_datagen = ImageDataGenerator(
preprocessing_function=preprocess_input,
)
# 训练数据与测试数据
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(IM_WIDTH, IM_HEIGHT),
batch_size=batch_size, class_mode='categorical')

validation_generator = val_datagen.flow_from_directory(
val_dir,
target_size=(IM_WIDTH, IM_HEIGHT),
batch_size=batch_size, class_mode='categorical')
return train_generator, validation_generator


二、加载InceptionV3模型(不含全连接层)

使用带有预训练权重的InceptionV3模型,但不包括顶层分类器(顶层分类器即全连接层。)

base_model = InceptionV3(weights='imagenet', include_top=False)


三、添加新的顶层分类器

def add_new_last_layer(base_model, nb_classes):
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(FC_SIZE, activation='relu')(x)
predictions = Dense(nb_classes, activation='softmax')(x)
model = Model(input=base_model.input, output=predictions)
return model


四、训练顶层分类器

冻结base_model所有层,这样就可以正确获得bottleneck特征

def setup_to_transfer_learn(model, base_model):
for layer in base_model.layers:
layer.trainable = False
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

setup_to_transfer_learn(model, base_model)
history_tl = model.fit_generator(
train_generator,
epochs=nb_epoch,
steps_per_epoch=nb_train_samples // batch_size,
validation_data=validation_generator,
validation_steps=nb_val_samples // batch_size,
class_weight='auto')


五、对顶层分类器进行fine_tuning

冻结部分层,对顶层分类器进行Fine-tune

Fine-tune以一个预训练好的网络为基础,在新的数据集上重新训练一小部分权重。fine-tune应该在很低的学习率下进行,通常使用SGD优化

def setup_to_finetune(model):
for layer in model.layers[:NB_IV3_LAYERS_TO_FREEZE]:
layer.trainable = False
for layer in model.layers[NB_IV3_LAYERS_TO_FREEZE:]:
layer.trainable = True
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])

setup_to_finetune(model)  # 冻结model的部分层
history_ft = model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
epochs=nb_epoch,
validation_data=validation_generator,
validation_steps=nb_val_samples // batch_size,
class_weight='auto')


源码地址:

https://github.com/Zheng-Wenkai/Keras_Demo
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  keras demo
相关文章推荐