您的位置:首页 > 其它

COCO数据集根据给定的种类id获得符合条件的图片id

2020-06-17 16:58 93 查看

代码参考 cocoapi。
核心代码如下:

`catId` 为用于筛选的类别id。首先,利用 `self.catToImgs[catId]` 获取包含该类别的图片id;然后,利用集合 set 的交集操作 `&`,筛选出同时满足多个类别条件的图片id。

# Start from the explicitly requested image ids (may be an empty set).
ids = set(imgIds)
for i, catId in enumerate(catIds):
    if i == 0 and len(ids) == 0:
        # No image filter was supplied: seed the result with every image
        # that contains the first category.
        ids = set(self.catToImgs[catId])
    else:
        # Keep only the images that also contain this category.
        ids &= set(self.catToImgs[catId])

完整代码如下

import time as time
import json
from collections import defaultdict

def _isArrayLike(obj):
    """Return True when *obj* is both iterable and sized.

    NOTE: a plain ``str`` passes both checks, so callers that use this helper
    to decide whether to wrap a scalar in a list will NOT wrap a string.
    """
    return hasattr(obj, '__iter__') and hasattr(obj, '__len__')


class COCO:
    """In-memory index over a COCO-style annotation dictionary.

    The instance keeps the raw ``dataset`` dict plus lookup tables built by
    :meth:`createIndex`:

    * ``anns``      : annotation id -> annotation
    * ``imgs``      : image id      -> image record
    * ``cats``      : category id   -> category record
    * ``imgToAnns`` : image id      -> list of annotations
    * ``catToImgs`` : category id   -> list of image ids (one per annotation)
    """

    def __init__(self, annotation_file=None):
        """
        Load an annotation file (if given) and build the lookup tables.
        :param annotation_file (str): location of annotation file; when None
                                      an empty index is created
        :return:
        """
        # start with empty tables so the query methods work without a file
        self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(), dict()
        self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
        if annotation_file is not None:
            print('loading annotations into memory...')
            tic = time.time()
            # the context manager closes the handle even if parsing fails
            with open(annotation_file, 'r') as f:
                dataset = json.load(f)
            assert isinstance(dataset, dict), 'annotation file format {} not supported'.format(type(dataset))
            print('Done (t={:0.2f}s)'.format(time.time() - tic))
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        """
        Rebuild every lookup table from ``self.dataset``.
        Missing 'annotations' / 'images' / 'categories' sections are treated
        as empty.
        :return:
        """
        print('creating index...')
        anns = {}
        imgToAnns, catToImgs = defaultdict(list), defaultdict(list)

        # catToImgs is only filled when the dataset also declares categories
        index_cats = 'categories' in self.dataset
        if 'annotations' in self.dataset:
            for ann in self.dataset['annotations']:
                imgToAnns[ann['image_id']].append(ann)
                anns[ann['id']] = ann
                if index_cats:
                    catToImgs[ann['category_id']].append(ann['image_id'])

        imgs = {}
        if 'images' in self.dataset:
            imgs = {img['id']: img for img in self.dataset['images']}

        cats = {}
        if index_cats:
            cats = {cat['id']: cat for cat in self.dataset['categories']}

        print('index created!')

        # create class members
        self.anns = anns
        self.imgToAnns = imgToAnns
        self.catToImgs = catToImgs
        self.imgs = imgs
        self.cats = cats

    def getCatIds(self, catNms=(), supNms=(), catIds=()):
        """
        Get category ids that satisfy every given filter. An empty filter is
        skipped; with no filters at all every category id is returned.
        :param catNms (str array)  : get cats for given cat names
        :param supNms (str array)  : get cats for given supercategory names
        :param catIds (int array)  : get cats for given cat ids
        :return: ids (int array)   : integer array of cat ids
        """
        catNms = catNms if _isArrayLike(catNms) else [catNms]
        supNms = supNms if _isArrayLike(supNms) else [supNms]
        catIds = catIds if _isArrayLike(catIds) else [catIds]

        # a dataset without a 'categories' section simply has no categories
        cats = self.dataset.get('categories', [])
        # len() rather than truthiness: array-like inputs need not define bool()
        if len(catNms) == len(supNms) == len(catIds) == 0:
            print('进入if,不进行筛选时默认获取全部的cats')
        else:
            print('进入else,根据筛选条件对cats进行筛选')
            if len(catNms) != 0:
                cats = [cat for cat in cats if cat['name'] in catNms]
            if len(supNms) != 0:
                cats = [cat for cat in cats if cat['supercategory'] in supNms]
            if len(catIds) != 0:
                cats = [cat for cat in cats if cat['id'] in catIds]
        print(self.dataset.keys())
        print(cats)
        return [cat['id'] for cat in cats]

    def loadCats(self, ids=()):
        """
        Load cats with the specified ids.
        :param ids (int array)       : integer ids specifying cats
        :return: cats (object array) : loaded cat objects; None when ids is
                                       neither array-like nor an int
        """
        if _isArrayLike(ids):
            return [self.cats[cat_id] for cat_id in ids]
        if isinstance(ids, int):
            return [self.cats[ids]]
        return None

    def getImgIds(self, imgIds=(), catIds=()):
        '''
        Get img ids that satisfy given filter conditions.
        :param imgIds (int array) : get imgs for given ids
        :param catIds (int array) : get imgs with all given cats
        :return: ids (int array)  : integer array of img ids
        '''
        imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
        catIds = catIds if _isArrayLike(catIds) else [catIds]

        if len(imgIds) == len(catIds) == 0:
            return list(self.imgs)

        ids = set(imgIds)
        print(' ')
        print(ids)
        print(catIds)
        # catIds is empty when only imgIds was given; indexing it would raise
        if len(catIds) != 0:
            # .get() so an unknown id does not add an entry to the defaultdict
            print(self.catToImgs.get(catIds[0], []))
        print(' ')
        for i, catId in enumerate(catIds):
            cat_imgs = set(self.catToImgs.get(catId, ()))
            if i == 0 and len(ids) == 0:
                # no image filter: seed the result with the first category
                ids = cat_imgs
            else:
                # keep only images that also contain this category
                ids &= cat_imgs
            print(ids)
            print()

        return list(ids)
# Locate the instances annotation file for the chosen split, relative to
# the current working directory.
dataDir = '../..'
dataType = 'val2017'
annDir = dataDir + '/annotations'
annFile = annDir + '/instances_' + dataType + '.json'

# Loading the file also builds the lookup tables.
coco = COCO(annFile)

# Resolve the category names to ids, then list the images that contain
# every one of those categories.
catIds = coco.getCatIds(catNms=['person', 'dog', 'skateboard'])
print(catIds)
imgIds = coco.getImgIds(catIds=catIds)
print(imgIds)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: