编译
重新编译之前,注释掉Makefile.config
中的下面这句话:
WITH_PYTHON_LAYER:=1
重新进行编译:
make clean
make -j
make pycaffe
设置路径
将生成的pycaffe
添加到你的路径$PYTHONPATH
中,以便能够import caffe
export PYTHONPATH=$CAFFE_ROOT/python:$PYTHONPATH
其中$CAFFE_ROOT
为你caffe的安装根目录,根据自己的实际情况进行修改。
当然你也可以在你的python layer
中导入caffe
时使用相对导入,这样也可以避免你在安装了多个caffe
时,导入caffe
包混乱的问题
写好你自己定义的Python Layer
之后,你可以将该*.py
文件放在$PYTHONPATH
中;也可以放在任意路径(例如:$Layer
)中,随后再将该路径添加到$PYTHONPATH
中即可。
编写你的Python Layer
参考Caffe官方样例:$CAFFE_ROOT/examples/pycaffe/layers/pyloss.py
import caffe
import numpy as np
class EuclideanLossLayer(caffe.Layer):
"""
Compute the Euclidean Loss in the same manner as the C++ EuclideanLossLayer
to demonstrate the class interface for developing layers in Python.
"""
def setup(self, bottom, top):
# check input pair
if len(bottom) != 2:
raise Exception("Need two inputs to compute distance.")
def reshape(self, bottom, top):
# check input dimensions match
if bottom[0].count != bottom[1].count:
raise Exception("Inputs must have the same dimension.")
# difference is shape of inputs
self.diff = np.zeros_like(bottom[0].data, dtype=np.float32)
# loss output is scalar
top[0].reshape(1)
def forward(self, bottom, top):
self.diff[...] = bottom[0].data - bottom[1].data
top[0].data[...] = np.sum(self.diff**2) / bottom[0].num / 2.
def backward(self, top, propagate_down, bottom):
for i in range(2):
if not propagate_down[i]:
continue
if i == 0:
sign = 1
else:
sign = -1
bottom[i].diff[...] = sign * self.diff / bottom[i].num
使用你的Python Layer
参考$CAFFE_ROOT/examples/pycaffe/linreg.prototxt
layer {
type: 'Python'
name: 'loss'
top: 'loss'
bottom: 'ipx'
bottom: 'ipy'
python_param {
# the module name -- usually the filename -- that needs to be in $PYTHONPATH
# 文件名
module: 'pyloss'
# the layer name -- the class name in the module
# 该模块中的类名
layer: 'EuclideanLossLayer'
}
# set loss weight so Caffe knows this is a loss layer.
# since PythonLayer inherits directly from Layer, this isn't automatically
# known to Caffe
loss_weight: 1
}
杂七杂八
python_param
还可以有param_str
,用来传递该层参数,是字符串。
若以命令行方式而不是python
接口使用Python Layer
时,往往语句print
没有效果。可以使用,如下方式:
import logging
logging.basicConfig(level=logging.INFO)
logging.info('Info you want to print')