您的位置:首页 > 编程语言

AI challenger 场景分类 PyTorch 迁移学习 Places365-CNNs 启动代码

2017-10-13 10:28 1086 查看
分享个简单的启动代码。

'''
CHANGES:
- imagenet cnns: resnet: http://pytorch.org/docs/master/torchvision/models.html
- places 365 cnns: resnet 18, 50: https://github.com/CSAILVision/places365
- top3 accuracy: https://github.com/pytorch/examples/blob/master/imagenet/main.py
- 训练-验证流程: http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#

TODO:
- (训练-验证)-(测试)总流程,代码模块化: https://zhuanlan.zhihu.com/p/29024978
- places: densenet 161
- 测试其他imagenet cnn

- 数据增强,各种套路逐一实现===========================================

- mxnet resnet 152? https://github.com/YanWang2014/iNaturalist
- tf inception-resnet v2? http://blog.csdn.net/wayne2019/article/details/78210172
'''
#pkill -9 python
#nvidia-smi
import copy
import json
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# ---------------------------------------------------------------------------
# Load the pretrained Places365 model.
#
# The whole-model checkpoints are unpickled with encoding="latin1"
# (presumably because they were written under Python 2 -- confirm).
# torch.load uses whatever module is passed as pickle_module, so its
# load/Unpickler entry points are wrapped with that encoding.
# NOTE(review): this patches the global `pickle` module for the whole
# process, so every other pickle.load/Unpickler user is affected as well.
from functools import partial
import pickle
pickle.load = partial(pickle.load, encoding="latin1")
pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")

# The architecture to use; candidates noted by the author: AlexNet, ResNet18,
# ResNet50, DenseNet161.
# NOTE(review): the fine-tuning code later replaces `model_conv.fc`, so a model
# without an `fc` attribute would need that code adapted.
arch = 'resnet18_places365'

model_weight = 'whole_%s.pth.tar' % arch
# Use the GPU only when one is actually available; a hard-coded "GPU on" flag
# makes torch.load fail on CPU-only machines for GPU-saved checkpoints.
use_gpu = torch.cuda.is_available()
if use_gpu:
    model_conv = torch.load(model_weight, pickle_module=pickle)
else:
    # Remap every storage onto the CPU, so a model trained on a GPU can be
    # deployed on a CPU-only machine.
    model_conv = torch.load(model_weight, map_location=lambda storage, loc: storage, pickle_module=pickle)

# ---------------------------------------------------------------------------
# Load the official train / validation annotation lists (label files).
# JSON is UTF-8 by specification; be explicit rather than relying on the
# platform's locale-dependent default encoding.
with open('../ai_challenger_scene_train_20170904/scene_train_annotations_20170904.json', 'r', encoding='utf-8') as f:
    label_raw_train = json.load(f)
with open('../ai_challenger_scene_validation_20170908/scene_validation_annotations_20170908.json', 'r', encoding='utf-8') as f:
    label_raw_val = json.load(f)

class SceneDataset(Dataset):
    """Image/label dataset backed by an AI Challenger scene annotation list.

    Each entry of ``json_labels`` must provide an ``'image_id'`` (file name
    relative to ``root_dir``) and a ``'label_id'`` (anything ``int()`` accepts).
    """

    def __init__(self, json_labels, root_dir, transform=None):
        """
        Args:
            json_labels (list): entries read from the official json file.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                to each image before it is returned.
        """
        self.label_raw = json_labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.label_raw)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for annotation entry ``idx``.

        The image is decoded eagerly and forced to 3-channel RGB, so that
        grayscale / palette / RGBA files do not break a 3-channel Normalize
        transform. The ``with`` block closes the underlying file immediately
        instead of leaving it open until garbage collection, which matters
        with many DataLoader worker processes.
        """
        entry = self.label_raw[idx]
        img_name = os.path.join(self.root_dir, entry['image_id'])
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        label = int(entry['label_id'])

        if self.transform:
            image = self.transform(image)

        return image, label

# torchvision renamed RandomSizedCrop -> RandomResizedCrop and Scale -> Resize;
# the old names were deprecated and are gone from current releases. Pick
# whichever name the installed version provides (same default behavior).
_RandomResizedCrop = (transforms.RandomResizedCrop
                      if hasattr(transforms, 'RandomResizedCrop')
                      else transforms.RandomSizedCrop)
_Resize = transforms.Resize if hasattr(transforms, 'Resize') else transforms.Scale

# Per-channel mean/std normalization (the usual ImageNet statistics), shared
# by both pipelines.
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

data_transforms = {
    # Training: random crop resized to 224x224 plus horizontal flips (augmentation).
    'train': transforms.Compose([
        _RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _normalize,
    ]),
    # Validation: deterministic resize to 256 then a 224x224 center crop.
    'val': transforms.Compose([
        _Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _normalize,
    ]),
}
transformed_dataset_train = SceneDataset(json_labels=label_raw_train,
                                         root_dir='../ai_challenger_scene_train_20170904/scene_train_images_20170904',
                                         transform=data_transforms['train'])
transformed_dataset_val = SceneDataset(json_labels=label_raw_val,
                                       root_dir='../ai_challenger_scene_validation_20170908/scene_validation_images_20170908',
                                       transform=data_transforms['val'])
batch_size = 64
dataloader = {
    'train': DataLoader(transformed_dataset_train, batch_size=batch_size, shuffle=True, num_workers=8),
    # Validation metrics are order-independent; a fixed order keeps runs reproducible.
    'val': DataLoader(transformed_dataset_val, batch_size=batch_size, shuffle=False, num_workers=8),
}
dataset_sizes = {'train': len(label_raw_train), 'val': len(label_raw_val)}

######################################################################
class AverageMeter:
    """Keep the latest value and a weighted running mean of a metric.

    Public attributes (all zero after ``reset``):
        val   -- most recent value given to ``update``
        sum   -- running total of ``value * weight``
        count -- running total of the weights
        avg   -- ``sum / count``
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch size) and refresh ``avg``."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute precision@k (as a percentage) for each k in ``topk``.

    Args:
        output: score tensor; top-k is taken along dim 1, so rows are samples
            and columns are classes.
        target: tensor of true class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k: the percentage of samples
        whose true class is among their k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, row i of `pred` holds every sample's i-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # On newer PyTorch `correct` keeps the transposed, non-contiguous layout
        # of `pred`, and `.view(-1)` raises for k >= 2; making it contiguous
        # first is valid on every version.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def _to_float(value):
    """Return a one-element tensor (or a plain Python number) as a float.

    Indexing the flattened tensor yields a Python number on old PyTorch and a
    0-dim tensor on newer releases; ``float()`` accepts both.
    """
    if torch.is_tensor(value):
        value = value.view(-1)[0]
    return float(value)


def train_model(model, criterion, optimizer, scheduler, num_epochs, total_steps):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' phase (scheduler step plus optimizer updates) and
    a 'val' phase over the module-level ``dataloader``. Loss and accuracy are
    normalized with the module-level ``dataset_sizes``. The weights with the
    highest validation accuracy are snapshotted and restored at the end. At
    epochs 0, 10, 20, ... the best weights so far are saved to
    ``'<arch>_model_wts_<epoch>.pth'``.

    Args:
        model: network to train (moved to the GPU beforehand when ``use_gpu``).
        criterion: loss, called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped once per training phase.
        num_epochs (int): number of train/val epochs.
        total_steps: expected total batch count; only used in progress messages.

    Returns:
        ``model`` with the best validation weights loaded.
    """
    since = time.time()

    print('total_steps is %d' % total_steps)
    mystep = 0

    # state_dict() returns references to the live parameter tensors, so a deep
    # copy is required; otherwise the "best" weights would silently keep
    # tracking the current weights as training continues.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Periodic checkpoint of the best weights found so far.
        if epoch % 10 == 0:
            torch.save(best_model_wts, ('%s_model_wts_%d.pth') % (arch, epoch))

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            if phase == 'train':
                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
                # optimizer.step(); this follows the older convention.
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            top1 = AverageMeter()
            top3 = AverageMeter()

            for inputs, labels in dataloader[phase]:
                mystep += 1
                if mystep % 100 == 0:
                    duration = time.time() - since
                    print('step %d vs %d in %.0f s' % (mystep, total_steps, duration))

                # Wrap the batch in Variables (on the GPU when enabled).
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                optimizer.zero_grad()

                # Forward pass.
                # NOTE(review): the 'val' phase also records autograd history;
                # disabling gradient tracking there would save memory.
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # Backward + optimize only in the training phase.
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # Statistics. The loss is weighted by the batch size (assumes a
                # mean-reduced criterion such as the default nn.CrossEntropyLoss()),
                # so dividing by the dataset size below gives the per-sample mean
                # even when the last batch is smaller than batch_size.
                batch_n = inputs.data.size(0)
                running_loss += _to_float(loss.data) * batch_n
                running_corrects += _to_float(torch.sum(preds == labels.data))
                prec1, prec3 = accuracy(outputs.data, labels.data, topk=(1, 3))
                top1.update(_to_float(prec1), batch_n)
                top3.update(_to_float(prec3), batch_n)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.6f} Acc: {:.6f}'.format(
                phase, epoch_loss, epoch_acc))
            print(' * Prec@1 {top1.avg:.6f} Prec@3 {top3.avg:.6f}'.format(top1=top1, top3=top3))

            # Snapshot (deep copy) the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.6f}'.format(best_acc))

    # Load the best model weights before returning.
    model.load_state_dict(best_model_wts)
    return model

# --- Which layers are trained ----------------------------------------------
# Freeze every pretrained weight. (An ImageNet starting point would be
# torchvision.models.resnet18(pretrained=True).)
for weight in model_conv.parameters():
    weight.requires_grad = False

# Swap in an 80-way classifier head. A freshly constructed module's parameters
# have requires_grad=True by default, so this head is the only trainable part.
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(in_features=num_ftrs, out_features=80)

if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()

# --- Optimizer: SGD with momentum over the new head's parameters only -------
optimizer_conv = optim.SGD(params=model_conv.fc.parameters(), lr=1e-4, momentum=0.9)

# --- Schedule: multiply the LR by 0.1 every 100 scheduler steps -------------
# (train_model steps the scheduler once per epoch, i.e. every 100 epochs.)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_conv, step_size=100, gamma=0.1)

# ---------------------------------------------------------------------------
# Train and evaluate, then save the final (best-validation) weights.
num_epochs = 2
# Approximate number of batches over all epochs (train + val); only used for
# progress messages inside train_model.
images_per_epoch = len(label_raw_train) + len(label_raw_val)
total_steps = float(num_epochs * images_per_epoch) / batch_size
print(total_steps)
model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler,
                         num_epochs=num_epochs, total_steps=total_steps)
final_weights_path = '%s_best_model_wts_final.pth' % arch
torch.save(model_conv.state_dict(), final_weights_path)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: