Python 标准库中最有用的装饰器

众所周知,Python 语言灵活、简洁,对程序员友好,但在性能上有点不太令人满意,这一点通过一个递归的求斐波那契额函数就可以说明:


def
fib(
n):


if
n
<=
1:


return
n


return
fib(
n
-
1)
+
fib(
n
-
2)


1.

2.

3.

4.

在我的 MBP 上计算 fib(40) 花费了 33 秒:

import
time




def
main():


start
=
time
.
time()


result
=
fib(
40)


end
=
time
.
time()


cost
=
end
-
start


print(
f
"{result = } {cost = :.4f}")




if
__name__
==
'__main__':


main()


1.

2.

3.

4.

5.

6.

7.

8.

9.

10.

11.

但是,假如使用标准库中的这个装饰器,那结果完全不一样 

from
functools
import
lru_cache




@lru_cache


def
fib(
n):


if
n
<=
1:


return
n


return
fib(
n
-
1)
+
fib(
n
-
2)


1.

2.

3.

4.

5.

6.

7.

这次的结果是 0 秒,你没看错,我保留了 4 位小数,后面的忽略了。

提升了多少倍?我已经计算不出来了。

为什么 lru_cache 装饰器这么牛逼,它到底做了什么事情?今天就来聊一聊这个最有用的装饰器。

如果看过计算机操作系统的话,你对 LRU 一定不会陌生,这就是著名的最近最久未使用缓存淘汰算法。

而 lru_cache 就是这个算法的具体实现。(这个算法可是面试经常考的哦,有的面试官要求现场手写代码)

现在,我们来看一个 lru_cache 的源代码,其中的英文注释,我已经为你翻译为中文:

def
lru_cache(
maxsize
=
128,
typed
=
False):


""
"LRU 缓存装饰器




如果
*
maxsize
*
是
None,
将不会淘汰缓存,缓存大小也不做限制




如果
*
typed
*
是
True,
不同类型的参数将独立做缓存,比如
f(
3.0)
and
f(
3)
将认为是不同的函数调用而缓存在两个缓存节点上。




函数的参数必须可以被
hash




查看缓存信息使用的是命名元组 (
hits,
misses,
maxsize,
currsize)


查看缓存信息:user_func
.
cache_info()
.
清理缓存信息:user_func
.
cache_clear()
.




LRU
算法:
http:
//en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)




""
"




#
lru_cache
的内部实现是线程安全的




if
isinstance(
maxsize,
int):


#
负数转换为
0


if
maxsize
<
0:


maxsize
=
0


elif
callable(
maxsize)
and
isinstance(
typed,
bool):


#如果被装饰的函数(user_function)直接通过
maxsize
参数传入


user_function,
maxsize
=
maxsize,
128


wrapper
=
_lru_cache_wrapper(
user_function,
maxsize,
typed,
_CacheInfo)


return
update_wrapper(
wrapper,
user_function)


elif
maxsize
is
not
None:


raise
TypeError(


'Expected first argument to be an integer, a callable, or None')




def
decorating_function(
user_function):


wrapper
=
_lru_cache_wrapper(
user_function,
maxsize,
typed,
_CacheInfo)


return
update_wrapper(
wrapper,
user_function)




return
decorating_function


1.

2.

3.

4.

5.

6.

7.

8.

9.

10.

11.

12.

13.

14.

15.

16.

17.

18.

19.

20.

21.

22.

23.

24.

25.

26.

27.

28.

29.

30.

31.

32.

33.

34.

35.

36.

这里面有两个参数,一个是 maxsize,表示缓存的大小,当传入负数时,自动设置为 0,如果不传入 maxsize,或者设置为 None,表示缓存没有大小限制,此时没有缓存淘汰。还有一个是 type,当 type 传入 True 时,不同的参数类型会当作不同的 key 存到缓存当中。

接下来,lru_cache 的核心在这个函数上 ​​_lru_cache_wrapper​​,建议有感情的阅读、背诵并默写。我们来看下它的源代码

def
_lru_cache_wrapper(
user_function,
maxsize,
typed,
_CacheInfo):


#
所有
lru
cache
实例共享的常量:


sentinel
=
object()
#
用来表示缓存未命中的唯一对象


make_key
=
_make_key
#
build
a
key
from
the
function
arguments


PREV,
NEXT,
KEY,
RESULT
=
0,
1,
2,
3
#
names
for
the
link
fields




cache
= {}


hits
=
misses
=
0


full
=
False


cache_get
=
cache
.
get
#
绑定函数来获取缓存中
key
的值


cache_len
=
cache
.
__len__
#
绑定函数获取缓存大小


lock
=
RLock()
#
因为链表上的更新是线程不安全的


root
= []
#
循环双向链表的根节点


root[:]
= [
root,
root,
None,
None]
#
初始化根节点的前后指针都指向它自己




if
maxsize
==
0:




def
wrapper(
*
args,
**
kwds):


#
没有缓存,仅更新统计信息


nonlocal
misses


misses
+=
1


result
=
user_function(
*
args,
**
kwds)


return
result




elif
maxsize
is
None:




def
wrapper(
*
args,
**
kwds):


#
仅仅排序,不考虑排序和缓存大小限制


nonlocal
hits,
misses


key
=
make_key(
args,
kwds,
typed)


result
=
cache_get(
key,
sentinel)


if
result
is
not
sentinel:


hits
+=
1


return
result


misses
+=
1


result
=
user_function(
*
args,
**
kwds)


cache[
key]
=
result


return
result




else:




def
wrapper(
*
args,
**
kwds):


#
大小有限制,并跟踪最近使用的缓存


nonlocal
root,
hits,
misses,
full


key
=
make_key(
args,
kwds,
typed)


with
lock:


link
=
cache_get(
key)


if
link
is
not
None:


#
缓存命中,将命中的缓存移动到循环双向链表的头部


link_prev,
link_next,
_key,
result
=
link


link_prev[
NEXT]
=
link_next


link_next[
PREV]
=
link_prev


last
=
root[
PREV]


last[
NEXT]
=
root[
PREV]
=
link


link[
PREV]
=
last


link[
NEXT]
=
root


hits
+=
1


return
result


misses
+=
1


result
=
user_function(
*
args,
**
kwds)


with
lock:


if
key
in
cache:


#
走到这里说明
key
已经放在了缓存,且锁已经释放了,链表已经更新了,这里什么也不需要做了,最后只需要返回计算的结果就可以了。


pass


elif
full:


#
如果缓存满了,
使用最老的根节点来存储新节点就可以了,链表上不需要删除(是不是很聪明)


oldroot
=
root


oldroot[
KEY]
=
key


oldroot[
RESULT]
=
result


root
=
oldroot[
NEXT]


oldkey
=
root[
KEY]


oldresult
=
root[
RESULT]


root[
KEY]
=
root[
RESULT]
=
None




#
最后,我们需要从缓存中清除这个
key,因为它已经无效了。


del
cache[
oldkey]


#
新值放入缓存


cache[
key]
=
oldroot


else:


#
如果没有满,将新的结果放入循环双向链表的头部


last
=
root[
PREV]


link
= [
last,
root,
key,
result]


last[
NEXT]
=
root[
PREV]
=
cache[
key]
=
link


#
使用
cache_len
绑定方法而不是
len()
函数,后者可能会被包装在
lru_cache
本身中


full
= (
cache_len()
>=
maxsize)


return
result




def
cache_info():


""
"报告缓存统计信息"
""


with
lock:


return
_CacheInfo(
hits,
misses,
maxsize,
cache_len())




def
cache_clear():


""
"清理缓存信息"
""


nonlocal
hits,
misses,
full


with
lock:


cache
.
clear()


root[:]
= [
root,
root,
None,
None]


hits
=
misses
=
0


full
=
False




wrapper
.
cache_info
=
cache_info


wrapper
.
cache_clear
=
cache_clear


return
wrapper


1.

2.

3.

4.

5.

6.

7.

8.

9.

10.

11.

12.

13.

14.

15.

16.

17.

18.

19.

20.

21.

22.

23.

24.

25.

26.

27.

28.

29.

30.

31.

32.

33.

34.

35.

36.

37.

38.

39.

40.

41.

42.

43.

44.

45.

46.

47.

48.

49.

50.

51.

52.

53.

54.

55.

56.

57.

58.

59.

60.

61.

62.

63.

64.

65.

66.

67.

68.

69.

70.

71.

72.

73.

74.

75.

76.

77.

78.

79.

80.

81.

82.

83.

84.

85.

86.

87.

88.

89.

90.

91.

92.

93.

94.

95.

96.

97.

98.

99.

100.

101.

102.

103.

104.

如果我写的注释你都看明白了,那也不用看我下面的废话了,如果还有点不太明白,我啰嗦几句,也许你就明白了。

第一、所谓缓存,用的仍然是内存,为了快速存取,用的就是一个 hash 表,也就是 Python 的字典,都是在内存里的操作。


cache
= {}


1.

第二、如果 maxsize == 0,就相当于没有使用缓存,每调用一次,未命中数就 + 1,代码逻辑是这样的:


def
wrapper(
*
args,
**
kwds):


nonlocal
misses


misses
+=
1
#
未命中数


result
=
user_function(
*
args,
**
kwds)


return
result


1.

2.

3.

4.

5.

第三、如果 maxsize == None,相当于缓存无限制,也就不需要考虑淘汰,这个实现非常简单,我们直接在函数中用一个字典就可以实现,比如说:

cache
= {}


def
fib(
n):




if
n
in
cache:


return
cache[
n]




if
n
<=
1:


return
n


result
=
fib(
n
-
1)
+
fib(
n
-
2)


cache[
n]
=
result


return
result


1.

2.

3.

4.

5.

6.

7.

8.

9.

10.

11.

运行时间:

 理解了这一点,在装饰器中,这段逻辑就不难看懂:

def
wrapper(
*
args,
**
kwds):


nonlocal
hits,
misses


key
=
make_key(
args,
kwds,
typed)


result
=
cache_get(
key,
sentinel)


if
result
is
not
sentinel:


hits
+=
1


return
result


misses
+=
1


result
=
user_function(
*
args,
**
kwds)


cache[
key]
=
result


return
result


1.

2.

3.

4.

5.

6.

7.

8.

9.

10.

11.

第四、真正的缓存淘汰算法。

为了实现缓存(键值对)的淘汰,我们需要对缓存按时间进行排序,这就需要用到链表,链表的头部是最新插入的,尾部是最老插入的,当缓存数量已经达到最大值时,我们删除最久未使用的链尾节点,为了不删除链尾,我们可以使用循环链表,当缓存满了,直接更新链尾节点赋值为新节点,并把它做为新的链头就可以了。

当缓存命中时,我们需要把这个节点移动到链表的头部,保证链表的头部是最近经常使用的,为了移动方便,我们需要双向链表。

双向循环链表在 Python 中实现,可以简单的这么写:

PREV,
NEXT,
KEY,
RESULT
=
0,
1,
2,
3
#
names
for
the
link
fields


root
= []
#
root
of
the
circular
doubly
linked
list


root[:]
= [
root,
root,
None,
None]
#
initialize
by
pointing
to
self


1.

2.

3.

可能有些朋友看不懂最后那行代码:​​root[:] = [root, root, None, None]​​,画个图你就理解了:

这些箭头指向的都是节点的内存地址,随着节点的增多,就是这个样子的:

对比这个图,再看源代码,就很容易看懂了。尤其是这块的代码逻辑,是面试常考的重点,如果你能手写出这样线程安全的 LRU 缓存淘汰算法,那无疑是非常优秀的。

其他 LRU 算法的实现

其他关于 LRU 算法的实现,我自己写了两个,可以看这里:

LRU 缓存淘汰算法-双链表+hash 表[1]

LRU 缓存淘汰算法-Python 有序字典[2]

最后的话

装饰器 lru_cache 的作用就是把函数的计算机结果保存下来,下次用的时候可以直接从 hash 表中取出,避免重复计算从而提升效率,简单点的,直接在函数中使用个字典就搞定了,复杂点的,请看 lru_cache 的代码实现。另一方面,递归函数慢的一个主要原因就是重复计算。

Python 标准库的源码,是学习编程最有营养的原料,当你有好奇心时,不妨去窥探一下源码,相信你有定会有新的收获。今天的分享就到这里,如果有收获的话,请点赞、在看、转发、关注,感谢你的支持。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值