Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions image_classification/caffe2paddle/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
## 使用说明

`caffe2paddle.py`提供了将Caffe训练的模型转换为PaddlePaddle可使用的模型的接口`ModelConverter`,其封装了图像领域常用的Convolution、BatchNorm等layer的转换函数,可完成VGG、ResNet等常用模型的转换。模型转换的基本过程是:基于Caffe的Python API加载模型并依次获取每一个layer的信息,将其中的参数根据layer类型与PaddlePaddle适配后序列化保存(对于Pooling等无需训练的layer不做处理),输出可以直接为PaddlePaddle的Python API加载使用的模型文件。

`ModelConverter`的定义及说明如下:

```python
class ModelConverter(object):
#设置Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并使用Caffe API加载模型
def __init__(self, caffe_model_file, caffe_pretrained_file, paddle_tar_name)

#输出保存Paddle模型
def to_tar(self, f)

#将参数值序列化输出为二进制
@staticmethod
def serialize(data, f)

#依次对各个layer进行转换,转换时参照name_map进行layer和参数命名
def convert(self, name_map={})

#对Caffe模型的Convolution层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
@wrap_name_default("img_conv_layer")
def convert_Convolution_layer(self, params, name=None)

#对Caffe模型的InnerProduct层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
@wrap_name_default("fc_layer")
def convert_InnerProduct_layer(self, params, name=None)

#对Caffe模型的BatchNorm层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
@wrap_name_default("batch_norm_layer")
def convert_BatchNorm_layer(self, params, name=None)

#对Caffe模型的Scale层的参数进行转换,将使用name值对Paddle模型中对应layer的参数命名
def convert_Scale_layer(self, params, name=None)

#输入图片路径和均值文件路径,使用加载的Caffe模型进行预测
def caffe_predict(self, img, mean_file)

```

`ModelConverter`的使用方法如下:

```python
#指定Caffe网络配置文件、模型文件路径和要保存为的Paddle模型的文件名,并从指定文件加载模型
converter = ModelConverter("./ResNet-50-deploy.prototxt",
"./ResNet-50-model.caffemodel",
"Paddle_ResNet50.tar.gz")
#进行模型转换
converter.convert(name_map={})
#进行预测并输出预测概率以便对比验证模型转换结果
converter.caffe_predict(img='./caffe/examples/images/cat.jpg')
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

文档觉得需要简洁明了告诉用户怎么用,比如配置好参数之后,运行:

python caffe2paddle.py 

然后再把code相关解释,注意事项列出来。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


为验证并使用转换得到的模型,需基于PaddlePaddle API编写对应的网络结构配置文件,具体可参照PaddlePaddle使用文档,我们这里附上ResNet的配置以供使用。需要注意,上文给出的模型转换在调用`ModelConverter.convert`时传入了空的`name_map`,这将在遍历每一个layer进行参数保存时使用PaddlePaddle默认的layer和参数命名规则:以`wrap_name_default`中的值和调用计数构造layer name,并以此为前缀构造参数名(比如第一个InnerProduct层的bias参数将被命名为`___fc_layer_0__.wbias`);为此,在编写PaddlePaddle网络配置时要保证和Caffe端模型使用同样的拓扑顺序,尤其是对于ResNet这种有分支的网络结构,要保证两分支在PaddlePaddle和Caffe中先后顺序一致,这样才能够使得模型参数正确加载。如果不希望使用默认的layer name,可以使用一种更为精细的方法:建立Caffe和PaddlePaddle网络配置间layer name对应关系的`dict`并在调用`ModelConverter.convert`时作为`name_map`传入,这样在命名保存layer中的参数时将使用相应的layer name,另外这里只针对Caffe网络配置中Convolution、InnerProduct和BatchNorm类别的layer建立`name_map`即可(一方面,对于Pooling等无需训练的layer不需要保存,故这里没有提供转换接口;另一方面,对于Caffe中的Scale类别的layer,由于Caffe和PaddlePaddle在实现上的一些差别,PaddlePaddle中的batch_norm层同时包含BatchNorm和Scale层的复合,故这里对Scale进行了特殊处理)。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

分多个段落吧。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
import os
import functools
import inspect
Expand All @@ -9,6 +8,7 @@
import numpy as np
import caffe
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
from image import load_and_transform


def __default_not_set_callback__(kwargs, name):
Expand Down Expand Up @@ -90,23 +90,27 @@ def func(name=None):

class ModelConverter(object):
def __init__(self, caffe_model_file, caffe_pretrained_file,
paddle_tar_name):
paddle_output_path, paddle_tar_name):
self.net = caffe.Net(caffe_model_file, caffe_pretrained_file,
caffe.TEST)
self.output_path = paddle_output_path
self.tar_name = paddle_tar_name
self.params = dict()
self.pre_layer_name = ""
self.pre_layer_type = ""

def convert(self):
def convert(self, name_map={}):
layer_dict = self.net.layer_dict
for layer_name in layer_dict.keys():
layer = layer_dict[layer_name]
layer_params = layer.blobs
layer_type = layer.type
if len(layer_params) > 0:
self.pre_layer_name = getattr(
self, "convert_" + layer_type + "_layer")(layer_params)
self, "convert_" + layer_type + "_layer")(
layer_params,
name=None
if name_map == None else name_map.get(layer_name))
self.pre_layer_type = layer_type
with gzip.open(self.tar_name, 'w') as f:
self.to_tar(f)
Expand Down Expand Up @@ -136,7 +140,7 @@ def serialize(data, f):
f.write(struct.pack("IIQ", 0, 4, data.size))
f.write(data.tobytes())

@wrap_name_default("conv")
@wrap_name_default("img_conv_layer")
def convert_Convolution_layer(self, params, name=None):
for i in range(len(params)):
data = np.array(params[i].data)
Expand All @@ -149,6 +153,7 @@ def convert_Convolution_layer(self, params, name=None):
param_conf.name = file_name
param_conf.size = reduce(lambda a, b: a * b, data.shape)
self.params[file_name] = (param_conf, data.flatten())

return name

@wrap_name_default("fc_layer")
Expand All @@ -171,9 +176,10 @@ def convert_InnerProduct_layer(self, params, name=None):
self.params[file_name] = (param_conf, data.flatten())
return name

@wrap_name_default("batch_norm")
@wrap_name_default("batch_norm_layer")
def convert_BatchNorm_layer(self, params, name=None):
scale = np.array(params[-1].data)
scale = 1 / np.array(params[-1].data)[0] if np.array(
params[-1].data)[0] != 0 else 0
for i in range(2):
data = np.array(params[i].data) * scale
file_name = "_%s.w%s" % (name, str(i + 1))
Expand Down Expand Up @@ -210,19 +216,7 @@ def caffe_predict(self,
mean_file='./caffe/imagenet/ilsvrc_2012_mean.npy'):
net = self.net

mu = np.load(mean_file)
mu = mu.mean(1).mean(1)

transformer = caffe.io.Transformer({
'data': net.blobs['data'].data.shape
})
transformer.set_transpose('data', (2, 0, 1))
transformer.set_mean('data', mu)
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data', (2, 1, 0))
im = caffe.io.load_image(img)

net.blobs['data'].data[...] = transformer.preprocess('data', im)
net.blobs['data'].data[...] = load_img(img, mean_file)
out = net.forward()

output_prob = net.blobs['prob'].data[0].flatten()
Expand All @@ -231,9 +225,19 @@ def caffe_predict(self,
print 'predicted class is:', output_prob.argmax()


def load_image(file, mean_file):
im = load_and_transform(file, 256, 224, is_train=False)
im = im[(2, 1, 0), :, :]
mu = np.load(mean_file)
mu = mu.mean(1).mean(1)
im = im - mu[:, None, None]
im = im / 255.0
return im


if __name__ == "__main__":
converter = ModelConverter("./VGG_ILSVRC_16_layers_deploy.prototxt",
"./VGG_ILSVRC_16_layers.caffemodel",
"test_vgg16.tar.gz")
converter.convert()
converter.caffe_predict(img='./caffe/examples/images/cat.jpg')
converter = ModelConverter("./resnet50/ResNet-50-deploy.prototxt",
"./resnet50/ResNet-50-model.caffemodel",
"paddle_resnet50.tar.gz")
converter.convert(name_map=dict())
converter.caffe_predict("./images/cat.jpg")
223 changes: 223 additions & 0 deletions image_classification/caffe2paddle/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件不用引入,安装paddle之后,import即可~

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done,删除了image.py,caffe2paddle.py中加入了简单的预处理

try:
import cv2
except:
print(
"import cv2 error, please install opencv-python: pip install opencv-python"
)

__all__ = [
"load_image", "resize_short", "to_chw", "center_crop", "random_crop",
"left_right_flip", "simple_transform", "load_and_transform"
]
"""
This file contains some common interfaces for image preprocess.
Many users are confused about the image layout. We introduce
the image layout as follows.
- CHW Layout
- The abbreviations: C=channel, H=Height, W=Width
- The default layout of image opened by cv2 or PIL is HWC.
PaddlePaddle only supports the CHW layout. And CHW is simply
a transpose of HWC. It must transpose the input image.
- Color format: RGB or BGR
OpenCV use BGR color format. PIL use RGB color format. Both
formats can be used for training. Noted that, the format should
be keep consistent between the training and inference peroid.
"""


def load_image(file, is_color=True):
"""
Load an color or gray image from the file path.
Example usage:

.. code-block:: python
im = load_image('cat.jpg')
:param file: the input image path.
:type file: string
:param is_color: If set is_color True, it will load and
return a color image. Otherwise, it will
load and return a gray image.
"""
# cv2.IMAGE_COLOR for OpenCV3
# cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version
# cv2.IMAGE_GRAYSCALE for OpenCV3
# cv2.CV_LOAD_IMAGE_GRAYSCALE for older OpenCV Version
# Here, use constant 1 and 0
# 1: COLOR, 0: GRAYSCALE
flag = 1 if is_color else 0
im = cv2.imread(file, flag)
return im


def resize_short(im, size):
"""
Resize an image so that the length of shorter edge is size.
Example usage:

.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)

:param im: the input image with HWC layout.
:type im: ndarray
:param size: the shorter edge size of image after resizing.
:type size: int
"""
assert im.shape[-1] == 1 or im.shape[-1] == 3
h, w = im.shape[:2]
h_new, w_new = size, size
if h > w:
h_new = size * h / w
else:
w_new = size * w / h
im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
return im


def to_chw(im, order=(2, 0, 1)):
"""
Transpose the input image order. The image layout is HWC format
opened by cv2 or PIL. Transpose the input image to CHW layout
according the order (2,0,1).
Example usage:

.. code-block:: python
im = load_image('cat.jpg')
im = resize_short(im, 256)
im = to_chw(im)

:param im: the input image with HWC layout.
:type im: ndarray
:param order: the transposed order.
:type order: tuple|list
"""
assert len(im.shape) == len(order)
im = im.transpose(order)
return im


def center_crop(im, size, is_color=True):
"""
Crop the center of image with size.
Example usage:

.. code-block:: python
im = center_crop(im, 224)

:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
:type size: int
:param is_color: whether the image is color or not.
:type is_color: bool
"""
h, w = im.shape[:2]
h_start = (h - size) / 2
w_start = (w - size) / 2
h_end, w_end = h_start + size, w_start + size
if is_color:
im = im[h_start:h_end, w_start:w_end, :]
else:
im = im[h_start:h_end, w_start:w_end]
return im


def random_crop(im, size, is_color=True):
"""
Randomly crop input image with size.
Example usage:

.. code-block:: python
im = random_crop(im, 224)

:param im: the input image with HWC layout.
:type im: ndarray
:param size: the cropping size.
:type size: int
:param is_color: whether the image is color or not.
:type is_color: bool
"""
h, w = im.shape[:2]
h_start = np.random.randint(0, h - size + 1)
w_start = np.random.randint(0, w - size + 1)
h_end, w_end = h_start + size, w_start + size
if is_color:
im = im[h_start:h_end, w_start:w_end, :]
else:
im = im[h_start:h_end, w_start:w_end]
return im


def left_right_flip(im):
"""
Flip an image along the horizontal direction.
Return the flipped image.
Example usage:

.. code-block:: python
im = left_right_flip(im)

:paam im: input image with HWC layout
:type im: ndarray
"""
if len(im.shape) == 3:
return im[:, ::-1, :]
else:
return im[:, ::-1, :]


def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
"""
Simply data argumentation for training. These operations include
resizing, croping and flipping.
Example usage:

.. code-block:: python
im = simple_transform(im, 256, 224, True)
:param im: The input image with HWC layout.
:type im: ndarray
:param resize_size: The shorter edge length of the resized image.
:type resize_size: int
:param crop_size: The cropping size.
:type crop_size: int
:param is_train: Whether it is training or not.
:type is_train: bool
"""
im = resize_short(im, resize_size)
if is_train:
im = random_crop(im, crop_size)
if np.random.randint(2) == 0:
im = left_right_flip(im)
else:
im = center_crop(im, crop_size)
im = to_chw(im)

return im


def load_and_transform(filename,
resize_size,
crop_size,
is_train,
is_color=True):
"""
Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface
for the transform operations.
Example usage:

.. code-block:: python
im = load_and_transform('cat.jpg', 256, 224, True)
:param filename: The file name of input image.
:type filename: string
:param resize_size: The shorter edge length of the resized image.
:type resize_size: int
:param crop_size: The cropping size.
:type crop_size: int
:param is_train: Whether it is training or not.
:type is_train: bool
"""
im = load_image(filename)
im = simple_transform(im, resize_size, crop_size, is_train, is_color)
return im
Loading