TSN算法的PyTorch代码解读(测试部分)
2018-01-14 19:13
1061 查看
这篇博客介绍TSN算法的PyTorch代码的测试部分,建议先看训练部分的代码解读:TSN算法的PyTorch代码解读(训练部分),test_moels.py是测试模型的入口。
前面模块导入和命令行参数配置方面和训练代码类似,不细讲。
接下来先是根据数据集来确定类别数。然后通过models.py脚本中的TSN类来导入网络结构。另外如果想查看得到的网络net的各层信息,可以通过net.state_dict()来查看。checkpoint = torch.load(args.weights)是导入预训练的模型,在PyTorch中,导入模型都是采用torch.load()接口实现,输入args.weights就是.pth文件,也就是预训练模型。
接下来关于args.test_crops的条件语句是用来对数据做不同的crop操作:简单crop操作和重复采样的crop操作。如果args.test_crops等于1,就先resize到指定尺寸(比如从400resize到256),然后再做center crop操作,最后得到的是net.input_size的尺寸(比如224),注意这里一张图片做完这些crop操作后输出还是一张图片。如果args.test_crops等于10,那么就调用该项目下的transforms.py脚本中的GroupOverSample类进行重复采样的crop操作,最终一张图像得到10张crop的结果,后面会详细介绍GroupOverSample这个类。接下来的数据读取部分和训练时候类似,需要注意的是:1、num_segments的参数默认是25,比训练时候要多的多。2、test_mode=True,所以在调用TSNDataSet类的
transforms.py脚本中的GroupOverSample类。首先
接下来是设置GPU模式、设置模型为验证模式、初始化数据等。
开始循环读取数据,每执行一次循环表示读取一个video的数据。在循环中主要是调用eval_video函数来测试。预测结果和真实标签的结果都保存在output列表中。
eval_video函数是测试的主体,当准备好了测试数据和模型后就通过这个函数进行预测。输入video_data是一个tuple:(i, data, label)。data.view(-1, length, data.size(2), data.size(3))是将原本输入为(1,3*args.test_crops*args.test_segments,224,224)变换到(args.test_crops*args.test_segments,3,224,224),相当于batch size为args.test_crops*args.test_segments。然后用torch.autograd.Variable接口封装成Variable类型数据并作为模型的输入。net(input_var)得到的结果是Variable,如果要读取Tensor内容,需读取data变量,cpu()表示存储到cpu,numpy()表示Tensor转为numpy array,copy()表示拷贝。rst.reshape((num_crop, args.test_segments, num_class))表示将输入维数(二维)变化到指定维数(三维),mean(axis=0)表示对num_crop维度取均值,也就是原来对某帧图像的10张crop或clip图像做预测,最后是取这10张预测结果的均值作为该帧图像的结果。最后再执行一个reshape操作。最后返回的是3个值,分别表示video的index,预测结果和video的真实标签。
接下来要计算video-level的预测结果,这里从np.mean(x[0], axis=0)可以看出对args.test_segments帧图像的结果采取的也是均值方法来计算video-level的预测结果,然后通过np.argmax将概率最大的那个类别作为该video的预测类别。video_labels则是真实类别。cf = confusion_matrix(video_labels, video_pred).astype(float)是调用了混淆矩阵生成结果(numpy array),举个例子,y_true=[2,0,2,2,0,1],y_pred=[0,0,2,2,0,2],那么confusion_matrix(y_true, y_pred)的结果就是array([[2,0,0],[0,0,1],[1,0,2]]),每行表示真实类别,每列表示预测类别。因此cls_cnt = cf.sum(axis=1)表示每个真实类别有多少个video,cls_hit = np.diag(cf)就是将cf的对角线数据取出,表示每个类别的video中各预测对了多少个,因此cls_acc = cls_hit / cls_cnt就是每个类别的video预测准确率。np.mean(cls_acc)就是各类别的平均准确率。最后的if args.save_scores is not None:语句只是用来将预测结果保存成文件。
前面模块导入和命令行参数配置方面和训练代码类似,不细讲。
import argparse import time import numpy as np import torch.nn.parallel import torch.optim from sklearn.metrics import confusion_matrix from dataset import TSNDataSet from models import TSN from transforms import * from ops import ConsensusModule # options parser = argparse.ArgumentParser( description="Standard video-level testing") parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics']) parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff']) parser.add_argument('test_list', type=str) parser.add_argument('weights', type=str) parser.add_argument('--arch', type=str, default="resnet101") parser.add_argument('--save_scores', type=str, default=None) parser.add_argument('--test_segments', type=int, default=25) parser.add_argument('--max_num', type=int, default=-1) parser.add_argument('--test_crops', type=int, default=10) parser.add_argument('--input_size', type=int, default=224) parser.add_argument('--crop_fusion_type', type=str, default='avg', choices=['avg', 'max', 'topk']) parser.add_argument('--k', type=int, default=3) parser.add_argument('--dropout', type=float, default=0.7) parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--gpus', nargs='+', type=int, default=None) parser.add_argument('--flow_prefix', type=str, default='') args = parser.parse_args()
接下来先是根据数据集来确定类别数。然后通过models.py脚本中的TSN类来导入网络结构。另外如果想查看得到的网络net的各层信息,可以通过net.state_dict()来查看。checkpoint = torch.load(args.weights)是导入预训练的模型,在PyTorch中,导入模型都是采用torch.load()接口实现,输入args.weights就是.pth文件,也就是预训练模型。
base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint.state_dict().items())}就是读取预训练模型的层和具体参数并存到base_dict这个字典中。net.load_state_dict(base_dict)就是通过调用torch.nn.Module类的load_state_dict方法,达到用预训练模型初始化net网络的过程。需要注意的是load_state_dict方法还有一个输入:strict,如果该参数为True,就表示网络结构的层信息要和预训练模型的层信息严格相等,反之亦然,该参数默认是True。那么什么时候会用到False呢?就是当你只想用预训练网络初始化你的网络的部分层参数或者说你的预训练网络的层信息和你要被初始化的网络的层信息不完全一致,那样就只会初始化层信息相同的层。
if args.dataset == 'ucf101': num_class = 101 elif args.dataset == 'hmdb51': num_class = 51 elif args.dataset == 'kinetics': num_class = 400 else: raise ValueError('Unknown dataset '+args.dataset) net = TSN(num_class, 1, args.modality, base_model=args.arch, consensus_type=args.crop_fusion_type, dropout=args.dropout) checkpoint = torch.load(args.weights) base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint.state_dict().items())} net.load_state_dict(base_dict)
接下来关于args.test_crops的条件语句是用来对数据做不同的crop操作:简单crop操作和重复采样的crop操作。如果args.test_crops等于1,就先resize到指定尺寸(比如从400resize到256),然后再做center crop操作,最后得到的是net.input_size的尺寸(比如224),注意这里一张图片做完这些crop操作后输出还是一张图片。如果args.test_crops等于10,那么就调用该项目下的transforms.py脚本中的GroupOverSample类进行重复采样的crop操作,最终一张图像得到10张crop的结果,后面会详细介绍GroupOverSample这个类。接下来的数据读取部分和训练时候类似,需要注意的是:1、num_segments的参数默认是25,比训练时候要多的多。2、test_mode=True,所以在调用TSNDataSet类的
__getitem__方法时和训练时候有些差别。
if args.test_crops == 1: cropping = torchvision.transforms.Compose([ GroupScale(net.scale_size), GroupCenterCrop(net.input_size), ]) elif args.test_crops == 10: cropping = torchvision.transforms.Compose([ GroupOverSample(net.input_size, net.scale_size) ]) else: raise ValueError("Only 1 and 10 crops are supported while we got {}".format(args.test_crops)) data_loader = torch.utils.data.DataLoader( TSNDataSet("", args.test_list, num_segments=args.test_segments, new_length=1 if args.modality == "RGB" else 5, modality=args.modality, image_tmpl="img_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg", test_mode=True, transform=torchvision.transforms.Compose([ cropping, Stack(roll=args.arch == 'BNInception'), ToTorchFormatTensor(div=args.arch != 'BNInception'), GroupNormalize(net.input_mean, net.input_std), ])), batch_size=1, shuffle=False, num_workers=args.workers * 2, pin_memory=True)
transforms.py脚本中的GroupOverSample类。首先
__init__中的GroupScale类也是在transforms.py中定义的,其实是对输入的n张图像都做torchvision.transforms.Scale操作,也就是resize到指定尺寸。GroupMultiScaleCrop.fill_fix_offset返回的offsets是一个长度为5的列表,每个值都是一个tuple,其中前4个是四个点坐标,最后一个是中心点坐标,目的是以这5个点为左上角坐标时可以在原图的四个角和中心部分crop出指定尺寸的图,后面有例子介绍。crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))是按照crop_w*crop_h的大小去crop原图像,这里采用的是224*224。flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)是对crop得到的图像做左右翻转。最后把未翻转的和翻转后的列表合并,这样一张输入图像就可以得到10张输出了(5张crop,5张crop加翻转)。举个例子,假设image_w=340,image_h=256,crop_w=224,crop_h=224,那么offsets就是[(0,0),(116,0),(0,32),(116,32),(58,16)],因此第一个crop的结果就是原图上左上角坐标为(0,0),右下角坐标为(224,224)的图,这也就是原图的左上角部分图;第二个crop的结果就是原图上左上角坐标为(116,0),右下角坐标为(340,224)的图,这也就是原图的右上角部分图,其他依次类推分别是原图的左下角部分图和右下角部分图,最后一个是原图正中央crop出来的224*224图。这就是论文中说的corner crop,而且是4个corner和1个center。
class GroupOverSample(object): def __init__(self, crop_size, scale_size=None): self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) if scale_size is not None: self.scale_worker = GroupScale(scale_size) else: self.scale_worker = None def __call__(self, img_group): if self.scale_worker is not None: img_group = self.scale_worker(img_group) image_w, image_h = img_group[0].size crop_w, crop_h = self.crop_size offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) oversample_group = list() for o_w, o_h in offsets: normal_group = list() flip_group = list() for i, img in enumerate(img_group): crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) normal_group.append(crop) flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) if img.mode == 'L' and i % 2 == 0: flip_group.append(ImageOps.invert(flip_crop)) else: flip_group.append(flip_crop) oversample_group.extend(normal_group) oversample_group.extend(flip_group) return oversample_group
接下来是设置GPU模式、设置模型为验证模式、初始化数据等。
if args.gpus is not None: devices = [args.gpus[i] for i in range(args.workers)] else: devices = list(range(args.workers)) net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices) net.eval() data_gen = enumerate(data_loader) total_num = len(data_loader.dataset) output = []
开始循环读取数据,每执行一次循环表示读取一个video的数据。在循环中主要是调用eval_video函数来测试。预测结果和真实标签的结果都保存在output列表中。
proc_start_time = time.time() max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset) for i, (data, label) in data_gen: if i >= max_num: break rst = eval_video((i, data, label)) output.append(rst[1:]) cnt_time = time.time() - proc_start_time print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1, total_num, float(cnt_time) / (i+1)))
eval_video函数是测试的主体,当准备好了测试数据和模型后就通过这个函数进行预测。输入video_data是一个tuple:(i, data, label)。data.view(-1, length, data.size(2), data.size(3))是将原本输入为(1,3*args.test_crops*args.test_segments,224,224)变换到(args.test_crops*args.test_segments,3,224,224),相当于batch size为args.test_crops*args.test_segments。然后用torch.autograd.Variable接口封装成Variable类型数据并作为模型的输入。net(input_var)得到的结果是Variable,如果要读取Tensor内容,需读取data变量,cpu()表示存储到cpu,numpy()表示Tensor转为numpy array,copy()表示拷贝。rst.reshape((num_crop, args.test_segments, num_class))表示将输入维数(二维)变化到指定维数(三维),mean(axis=0)表示对num_crop维度取均值,也就是原来对某帧图像的10张crop或clip图像做预测,最后是取这10张预测结果的均值作为该帧图像的结果。最后再执行一个reshape操作。最后返回的是3个值,分别表示video的index,预测结果和video的真实标签。
def eval_video(video_data): i, data, label = video_data num_crop = args.test_crops if args.modality == 'RGB': length = 3 elif args.modality == 'Flow': length = 10 elif args.modality == 'RGBDiff': length = 18 else: raise ValueError("Unknown modality "+args.modality) input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)), volatile=True) rst = net(input_var).data.cpu().numpy().copy() return i, rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape( (args.test_segments, 1, num_class) ), label[0]
接下来要计算video-level的预测结果,这里从np.mean(x[0], axis=0)可以看出对args.test_segments帧图像的结果采取的也是均值方法来计算video-level的预测结果,然后通过np.argmax将概率最大的那个类别作为该video的预测类别。video_labels则是真实类别。cf = confusion_matrix(video_labels, video_pred).astype(float)是调用了混淆矩阵生成结果(numpy array),举个例子,y_true=[2,0,2,2,0,1],y_pred=[0,0,2,2,0,2],那么confusion_matrix(y_true, y_pred)的结果就是array([[2,0,0],[0,0,1],[1,0,2]]),每行表示真实类别,每列表示预测类别。因此cls_cnt = cf.sum(axis=1)表示每个真实类别有多少个video,cls_hit = np.diag(cf)就是将cf的对角线数据取出,表示每个类别的video中各预测对了多少个,因此cls_acc = cls_hit / cls_cnt就是每个类别的video预测准确率。np.mean(cls_acc)就是各类别的平均准确率。最后的if args.save_scores is not None:语句只是用来将预测结果保存成文件。
video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output] video_labels = [x[1] for x in output] cf = confusion_matrix(video_labels, video_pred).astype(float) cls_cnt = cf.sum(axis=1) cls_hit = np.diag(cf) cls_acc = cls_hit / cls_cnt print(cls_acc) print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100)) if args.save_scores is not None: # reorder before saving name_list = [x.strip().split()[0] for x in open(args.test_list)] order_dict = {e:i for i, e in enumerate(sorted(name_list))} reorder_output = [None] * len(output) reorder_label = [None] * len(output) for i in range(len(output)): idx = order_dict[name_list[i]] reorder_output[idx] = output[i] reorder_label[idx] = video_labels[i] np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)
相关文章推荐
- TSN算法的PyTorch代码解读(训练部分)
- lua:部分常用操作的效率对比及代码优化建议(附测试代码)
- KinectFusion 中的 ICP 算法 GPU 代码解读
- 飞机大战部分算法,思维解析,附有源代码,免费下载.
- C++类模板 实现队列的链式存储结构算法 《数据结构》(北京科海) 部分代码摘抄,自己编写运行
- 2014新跟踪算法KCF笔记 --续(代码部分)
- 时空上下文视觉跟踪(STC)算法的解读与代码复现
- 学习笔记——《机器学习实战》KNN算法实现 约会网站测试,手写数字识别,代码,注释,错误修改
- 通过Java测试几种压缩算法的性能(附测试代码下载)
- 图像跟踪与识别-kcf-c++代码解读,算法主流程
- 从EXCEL导入cxGrid旧版本部分测试代码存档
- 【算法】各种排序算法测试代码
- 时空上下文视觉跟踪(STC)算法的解读与代码复现
- 时空上下文视觉跟踪(STC)算法的解读与代码复现
- 时空上下文视觉跟踪(STC)算法的解读与代码复现
- 《机器学习:算法原理与编程实践》的读书笔记:SMO部分最难,大部分代码基于Scikit-Learn,决策树其实用处不大
- 室内清扫机器人环境搭建与算法测试简易代码
- DFS BFS 示例 以及BFS部分算法代码
- 时空上下文视觉跟踪(STC)算法的解读与代码复现
- tiptop erp 一份出厂测试报告程序部分的代码