49 changes: 49 additions & 0 deletions dygraph/tsm/README.md
@@ -0,0 +1,49 @@
# TSM Video Classification Model

This directory contains the TSM video classification model implemented with PaddlePaddle's dynamic graph (dygraph) API. For the static graph implementation, see [TSM video classification model](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo/models/tsm).

---
## Contents

- [Model Introduction](#model-introduction)
- [Data Preparation](#data-preparation)
- [Model Training](#model-training)
- [Model Evaluation](#model-evaluation)


## Model Introduction

The Temporal Shift Module (TSM), proposed by Ji Lin, Chuang Gan, and Song Han of MIT and the IBM Watson AI Lab, improves a network's video understanding ability by shifting features along the temporal dimension. For details, see the paper [Temporal Shift Module for Efficient Video Understanding](https://arxiv.org/abs/1811.08383v1).

## Data Preparation

TSM is trained on the Kinetics-400 action recognition dataset released by DeepMind. For data download and preparation, see the [data documentation](data/dataset/README.md).

### Validation on Small Datasets

To enable fast iteration, we validated dygraph training on smaller datasets, in two experimental settings:

1. One with 8k training samples and 2k test samples.
2. One with the training and test data of ten classes.
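
Such small-scale settings can be reproduced by filtering the full Kinetics-400 file lists before training. Below is a minimal sketch; the `make_subset` helper and its arguments are illustrative (not part of this repo), and it assumes each list line is the path of a pkl file packing `(video-id, label, frames)` as described in the data documentation:

```python
import pickle
import random


def make_subset(list_file, out_file, keep_labels=None, max_items=None):
    """Filter a train/val list of pkl paths down to a small subset.

    keep_labels: optional set of label ids to keep (e.g. a 10-class subset);
    max_items: optional cap on the number of samples (e.g. 8000).
    """
    kept = []
    with open(list_file) as f:
        for line in f:
            path = line.strip()
            if not path:
                continue
            if keep_labels is not None:
                # each pkl packs (video-id, label, frames)
                with open(path, 'rb') as pf:
                    _, label, _ = pickle.load(pf)
                if label not in keep_labels:
                    continue
            kept.append(path)
    if max_items is not None:
        random.shuffle(kept)
        kept = kept[:max_items]
    with open(out_file, 'w') as f:
        f.write('\n'.join(kept) + '\n')
```

The resulting `out_file` can then be used in place of `train.list`/`val.list`.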

## Model Training

Once the data is prepared, training can be started with:

    bash run.sh train

## Model Evaluation

Once the data is prepared, evaluation can be started with:

    bash run.sh eval

On the ten-class subset selected from Kinetics-400:

|Top-1|Top-5|
|:-:|:-:|
|76.56%|98.1%|

Accuracy on the full dataset: Top-1 0.70. For reference, see the [static graph implementation](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/PaddleVideo).
85 changes: 85 additions & 0 deletions dygraph/tsm/config_utils.py
@@ -0,0 +1,85 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml
import logging
logger = logging.getLogger(__name__)

CONFIG_SECS = [
'train',
'valid',
'test',
'infer',
]


class AttrDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            # raise AttributeError so hasattr()/getattr() behave as expected
            raise AttributeError(key)

    def __setattr__(self, key, value):
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self[key] = value


def parse_config(cfg_file):
    """Load a YAML config file into an AttrDict."""
    with open(cfg_file, 'r') as fopen:
        yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.Loader))
    create_attr_dict(yaml_config)
    return yaml_config


def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
return


def merge_configs(cfg, sec, args_dict):
    assert sec in CONFIG_SECS, "invalid config section {}".format(sec)
    sec_dict = getattr(cfg, sec.upper())
    for k, v in args_dict.items():
        if v is None:
            continue
        try:
            if hasattr(sec_dict, k):
                setattr(sec_dict, k, v)
        except (KeyError, AttributeError):
            # AttrDict.__getattr__ may surface a KeyError through hasattr()
            pass
    return cfg


def print_configs(cfg, mode):
logger.info("---------------- {:>5} Arguments ----------------".format(
mode))
for sec, sec_items in cfg.items():
logger.info("{}:".format(sec))
for k, v in sec_items.items():
logger.info(" {}:{}".format(k, v))
logger.info("-------------------------------------------------")
78 changes: 78 additions & 0 deletions dygraph/tsm/data/dataset/README.md
@@ -0,0 +1,78 @@
# Data Usage Instructions

## Kinetics Dataset

Kinetics is a large-scale video action recognition dataset released by DeepMind, available in two versions, Kinetics-400 and Kinetics-600. We use Kinetics-400 here; the preprocessing steps are as follows.

### Downloading the mp4 videos
Create the folders under the Code\_Root directory:

    cd $Code_Root/data/dataset && mkdir kinetics

    cd kinetics && mkdir data_k400 && cd data_k400

    mkdir train_mp4 && mkdir val_mp4

ActivityNet provides an official download tool for Kinetics; see its [official repo](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics) to download the Kinetics-400 mp4 video collection. Download the Kinetics-400 training and validation sets to data/dataset/kinetics/data\_k400/train\_mp4 and data/dataset/kinetics/data\_k400/val\_mp4, respectively.

### Preprocessing the mp4 files

To speed up data reading, the mp4 files are decoded into frames and packed into pickle files ahead of time; the dataloader then reads each video's data from its pkl file (at the cost of extra storage). Each pkl file packs the tuple (video-id, label, [frame1, frame2, ..., frameN]).
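
Concretely, each pkl file is just a pickled 3-tuple. A minimal sketch of writing and reading one entry (the helper names are illustrative; `protocol=2` matches what video2pkl.py uses):

```python
import pickle


def write_video_pkl(pkl_path, video_id, label, frames):
    """Pack one video as (video-id, label, [jpeg_bytes, ...])."""
    with open(pkl_path, 'wb') as f:
        # protocol=2 keeps the files readable from Python 2 as well
        pickle.dump((video_id, label, frames), f, protocol=2)


def read_video_pkl(pkl_path):
    """Read one packed video back; returns (video_id, label, frames)."""
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)
```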

Create the directories train\_pkl and val\_pkl under data/dataset/kinetics/data\_k400:

    cd $Code_Root/data/dataset/kinetics/data_k400

    mkdir train_pkl && mkdir val_pkl

Go to the $Code\_Root/data/dataset/kinetics directory and convert the data with the video2pkl.py script. First download the [train](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_train.csv) and [validation](https://github.com/activitynet/ActivityNet/tree/master/Crawler/Kinetics/data/kinetics-400_val.csv) file lists.

First, generate the dataset label file needed for preprocessing:

    python generate_label.py kinetics-400_train.csv kinetics400_label.txt
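
The generated kinetics400_label.txt holds one `category_name id` pair per line. A minimal sketch of loading it back into a dict, mirroring the parsing done in video2pkl.py (the `load_label_map` name is illustrative):

```python
def load_label_map(label_file):
    """Parse 'category name id' lines into a {category: id} dict.

    Category names may contain spaces, so everything before the last
    token is the name and the last token is the integer label.
    """
    category_label_map = {}
    with open(label_file) as f:
        for line in f:
            tokens = line.strip().split(' ')
            if len(tokens) < 2:
                continue
            category = ' '.join(tokens[:-1])
            category_label_map[category] = int(tokens[-1])
    return category_label_map
```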

Then run:

    python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8  # e.g. with 8 processes

- This script depends on `ffmpeg`; please install `ffmpeg` beforehand.

For the training data:

    Source_dir = $Code_Root/data/dataset/kinetics/data_k400/train_mp4

    Target_dir = $Code_Root/data/dataset/kinetics/data_k400/train_pkl

For the validation data:

    Source_dir = $Code_Root/data/dataset/kinetics/data_k400/val_mp4

    Target_dir = $Code_Root/data/dataset/kinetics/data_k400/val_pkl

This decodes the mp4 files and saves them as pkl files.

### Generating the training and validation lists

    cd $Code_Root/data/dataset/kinetics

    ls $Code_Root/data/dataset/kinetics/data_k400/train_pkl/* > train.list

    ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > val.list

    ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > test.list

    ls $Code_Root/data/dataset/kinetics/data_k400/val_pkl/* > infer.list

This generates the corresponding file lists. Each line of train.list and val.list is the absolute path of one pkl file, for example:

    /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-097
    /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-114
    /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/train_pkl/data_batch_100-118
    ...

or

    /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-085
    /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-086
    /ssd1/user/models/PaddleCV/PaddleVideo/data/dataset/kinetics/data_k400/val_pkl/data_batch_102-090
    ...
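
As a rough sketch of how such a list is consumed, the reader below yields `(frames, label)` samples with uniform temporal sampling. It is illustrative only: the real TSM reader also decodes and augments the JPEG bytes, and `seg_num=8` merely mirrors a typical TSM segment count.

```python
import pickle


def kinetics_reader(list_file, seg_num=8):
    """Yield (frame_bytes_list, label) samples from a pkl file list.

    seg_num: number of frames sampled uniformly per video, matching
    the segment-based sampling TSM uses (a simplified sketch).
    """
    with open(list_file) as f:
        pkl_paths = [line.strip() for line in f if line.strip()]
    for path in pkl_paths:
        with open(path, 'rb') as pf:
            _, label, frames = pickle.load(pf)
        if not frames:
            continue
        # uniformly sample seg_num frame indices across the video
        step = max(len(frames) // seg_num, 1)
        sampled = [frames[min(i * step, len(frames) - 1)]
                   for i in range(seg_num)]
        yield sampled, label
```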
44 changes: 44 additions & 0 deletions dygraph/tsm/data/dataset/kinetics/generate_label.py
@@ -0,0 +1,44 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

# kinetics-400_train.csv should be downloaded first and passed as sys.argv[1]
# sys.argv[2] can be set as kinetics400_label.txt
# python generate_label.py kinetics-400_train.csv kinetics400_label.txt

num_classes = 400

fname = sys.argv[1]
outname = sys.argv[2]
fl = open(fname).readlines()
fl = fl[1:]
outf = open(outname, 'w')

label_list = []
for line in fl:
label = line.strip().split(',')[0].strip('"')
if label in label_list:
continue
else:
label_list.append(label)

assert len(label_list) == num_classes, \
    "there should be {} labels in list, but got {}".format(
        num_classes, len(label_list))

label_list.sort()
for i in range(num_classes):
outf.write('{} {}'.format(label_list[i], i) + '\n')

outf.close()
87 changes: 87 additions & 0 deletions dygraph/tsm/data/dataset/kinetics/video2pkl.py
@@ -0,0 +1,87 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import glob
try:
    import cPickle as pickle  # Python 2
except ImportError:
    import pickle
from multiprocessing import Pool

# example command line: python video2pkl.py kinetics-400_train.csv $Source_dir $Target_dir 8
#
# kinetics-400_train.csv is the training set file of the official K400 release;
# each line contains label,youtube_id,time_start,time_end,split,is_cc

assert (len(sys.argv) == 5)

f = open(sys.argv[1])
source_dir = sys.argv[2]
target_dir = sys.argv[3]
num_threads = int(sys.argv[4])
all_video_entries = [x.strip().split(',') for x in f.readlines()]
all_video_entries = all_video_entries[1:]
f.close()

category_label_map = {}
f = open('kinetics400_label.txt')
for line in f:
ens = line.strip().split(' ')
category = " ".join(ens[0:-1])
label = int(ens[-1])
category_label_map[category] = label
f.close()


def generate_pkl(entry):
mode = entry[4]
category = entry[0].strip('"')
category_dir = category
video_path = os.path.join(
'./',
entry[1] + "_%06d" % int(entry[2]) + "_%06d" % int(entry[3]) + ".mp4")
video_path = os.path.join(source_dir, category_dir, video_path)
label = category_label_map[category]

vid = './' + video_path.split('/')[-1].split('.')[0]
if os.path.exists(video_path):
if not os.path.exists(vid):
os.makedirs(vid)
os.system('ffmpeg -i ' + video_path + ' -q 0 ' + vid + '/%06d.jpg')
else:
        print("File does not exist: {}".format(video_path))
return

images = sorted(glob.glob(vid + '/*.jpg'))
ims = []
for img in images:
f = open(img, 'rb')
ims.append(f.read())
f.close()

output_pkl = vid + ".pkl"
output_pkl = os.path.join(target_dir, output_pkl)
f = open(output_pkl, 'wb')
pickle.dump((vid, label, ims), f, protocol=2)
f.close()

os.system('rm -rf %s' % vid)


pool = Pool(processes=int(num_threads))
pool.map(generate_pkl, all_video_entries)
pool.close()
pool.join()