在python中强大的decorator可以发挥很大的作用,而lua中的first-class function想来一定可以实现类似的东西,在lua wiki上有一篇文章专门讲解了lua中的decorator实现(http://lua-users.org/wiki/DecoratorsAndDocstrings),本文在此基础上做了一些简单的扩展。
首先在python中,decorator更像是一种语法糖,比如:
@dec2
@dec1
def foo():
pass
实际上等价于 foo = dec2(dec1(foo)),另一种更复杂的写法:
@dec2(a, b)
@dec1
def foo():
pass
也等价于 foo = dec2(a, b)(dec1(foo))
根据这样的语义在lua中实现decorator并不难,但是在lua中并不支持类似 @dec这样的语法,我们需要定义新的运算符完成语义
比如采用“..”,所以可能看起来像这样:
local foo =
dec2(a, b) ..
dec1 ..
function ()
return 11
end
除了这两种用法,为了实现一个简单易用的库,本文还实现了只针对输入以及只针对输出的几个接口,先看一个使用demo:
local multiply_input = prewrap(function (self, ...)
local paras = {...}
for k, v in pairs(paras) do
paras[k] = paras[k]*self.__paras[1]
end
return paras
end)
这里实现的是一个只针对输入参数进行处理的decorator
local typecheck = fullwrap(
function (self, ...)
local paras = {...}
for k, v in pairs(self.__paras) do
if v == "->" then break end
if type(paras[k]) ~= v then
error(string.format("input type error, expect:%s got:%s",
v, type(paras[k])), 3)
end
end
return paras
end,
function (self, res)
local idx
for k, v in pairs(self.__paras) do
if v == "->" then
idx = k
break
end
end
for k, v in pairs(res) do
if type(v) ~= self.__paras[idx+k] then
error(string.format("result type error, expect:%s, got:%s",
self.__paras[idx+k], type(v)), 3)
end
end
return res
end
)
这里实现了一个即处理输入,同时也校验输出的decorator
local timer = raw_fullwrap(
function (self, ...)
self.begin = os.time()
return {...}
end,
function (self, res)
self.ends = os.time()
print(self.ends - self.begin)
return res
end
)
最后这个是计时decorator
local foo =
typecheck("number", "number", "->", "number") ..
multiply_input(3) ..
timer ..
function (x, y)
local something_fool = function ()
for i = 1, 100000000 do
y = math.sin(2)
end
end; something_fool()
--return "hello"
return x^2+y
end
print(foo(100, 1))
最后这段程序的输出是:7
90000.909297427
最后是整个demo的实现代码:
local M = {}
local mt = { __concat =
function (a, f)
return function (...)
local input = {...}
if a.__pre then input = a:__pre(...) end
local res = {f(unpack(input))}
if a.__after then res = a:__after(res) end
return unpack(res)
end
end
}
local function _wrap(pre, after)
return function (...)
return setmetatable({__pre = pre, __after = after, __paras = {...} }, mt)
end
end
local function prewrap(pre)
return _wrap(pre)
end
M.prewrap = prewrap
local function fullwrap(pre, after)
return _wrap(pre, after)
end
M.fullwrap = fullwrap
local function afterwrap(after)
return _wrap(nil, after)
end
M.afterwrap = afterwrap
local function _rawwrap(pre, after)
return setmetatable({__pre=pre, __after=after}, mt)
end
local function raw_prewrap(pre)
return _rawwrap(pre)
end
M.raw_prewrap = raw_prewrap
local function raw_fullwrap(pre, after)
return _rawwrap(pre, after)
end
M.raw_fullwrap = raw_fullwrap
local function raw_afterwrap(after)
return _rawwrap(nil, after)
end
M.raw_afterwrap = raw_afterwrap