您的位置:首页 > 其它

pytorch迁移学习,使用预训练模型

2019-07-06 18:14 441 查看
# Download and load the pretrained ResNet-18.
resnet = torchvision.models.resnet18(pretrained=True)

# 将参数设置为不可修改
for param in resnet.parameters():
param.requires_grad = False

# 替换网络的顶层
resnet.fc = nn.Linear(resnet.fc.in_features, 100)  # resnet.fc.in_features是输入维度,100为输出维度,此时网络的顶层是可以训练的,实现了网络的微调

# Forward pass.
images = torch.randn(64, 3, 224, 224)
outputs = resnet(images)
print (outputs.size())     # (64, 100)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐