1.数据集的准备
# Train or evaluate
if args.command == "train":
    # Training dataset. Use the training set and 35K from the
    # validation set, as in the Mask RCNN paper.
    dataset_train = CocoDataset()
    dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download)
    dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download)
    # prepare() is called once, after both subsets have been loaded
    # into the same dataset object.
    dataset_train.prepare()

    # Validation dataset
    dataset_val = CocoDataset()
    dataset_val.load_coco(args.dataset, "minival", year=args.year, auto_download=args.download)
    dataset_val.prepare()

    # *** This training schedule is an example. Update to your needs ***
    # NOTE(review): the epochs values grow from stage to stage
    # (40 -> 120 -> 160); presumably they are cumulative epoch indices
    # rather than per-stage counts -- confirm against train_model().

    # Training - Stage 1
    # Train only the layers selected by layers='heads'.
    print("Training network heads")
    model.train_model(dataset_train, dataset_val,
                      learning_rate=config.LEARNING_RATE,
                      epochs=40,
                      layers='heads')

    # Training - Stage 2
    # Finetune layers from ResNet stage 4 and up
    print("Fine tune Resnet stage 4 and up")
    model.train_model(dataset_train, dataset_val,
                      learning_rate=config.LEARNING_RATE,
                      epochs=120,
                      layers='4+')

    # Training - Stage 3
    # Fine tune all layers, with the learning rate divided by 10.
    print("Fine tune all layers")
    model.train_model(dataset_train, dataset_val,
                      learning_rate=config.LEARNING_RATE / 10,
                      epochs=160,
                      layers='all')
a.初始化一个CocoDataset类的对象。
class CocoDataset(utils.Dataset):
此类本身继承自utils.Dataset,其定义如下:
class Dataset(object):
    """The base class for dataset classes.

    To use it, create a new class that adds functions specific to the dataset
    you want to use. For example:

        class CatsAndDogsDataset(Dataset):
            def load_cats_and_dogs(self):
                ...
            def load_mask(self, image_id):
                ...
            def image_reference(self, image_id):
                ...

    See CocoDataset and ShapesDataset as examples.
    """
因此,CocoDataset在初始化时调用的是从Dataset继承来的初始化函数:
def __init__(self, class_map=None):
    """Create the empty containers that describe a dataset.

    class_map: Accepted for interface compatibility; this initializer
        does not read it.
    """
    # Filled in later; starts out empty.
    self._image_ids = []
    # One entry per image added to the dataset; starts out empty.
    self.image_info = []
    # Background is always the first class
    self.class_info = [{"source": "", "id": 0, "name": "BG"}]
    # Starts out as an empty mapping.
    self.source_class_ids = {}
2.load_coco()方法
# Load both training subsets, in order, into the same dataset object.
for subset_name in ("train", "valminusminival"):
    dataset_train.load_coco(args.dataset, subset_name, year=args.year, auto_download=args.download)
def load_coco(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None,
class_map=None, return_coco=False, auto_download=False):
"""Load a subset of the COCO dataset.
dataset_dir: The root directory of the COCO dataset.
subset: What to load (train, val, minival, valminusminival)
year: What dataset year to load (2014, 2017) as a string, not an integer
class_ids: If provided, only loads images that have the given classes.
class_map: TODO: Not implemented yet. Supports maping classes from
different datasets to the same class ID.
return_coco: If True, returns the COCO object.
auto_download: Automatically download and unzip MS-COCO images and annotations
"""
if auto_download is True:
self.auto_download(dataset_dir, subset, year)
coco = COCO("{}/annotations/instances_{}{}.json".format(dataset_dir, subset, year))
if subset == "minival" or subset == "valminusminival":
subset = "val"
image_dir = "{}/{}{}".format(dataset_dir, subset, year)
# Load all classes or a subset?
if not class_ids:
# All classes
class_ids = sorted(coco.getCatIds())
# All images or a subset?
if class_ids:
image_ids = []
for id in class_ids:
image_ids.extend(list(coco.getImgIds(catIds=[id])))
# Remove duplicates
image_ids = list(set(image_id