网上相关的实现很多，但都不够工具化（主要是不适合直接放进工程代码中），因此这里贴一段在性能和代码结构上都比较满意的代码片段。
# -*- coding: utf-8 -*-
from array import array
from copy import deepcopy
from random import random
from bisect import bisect
__all__ = ['WeightedRandom', ]
class WeightObj(object):
    """Pair a registered object with the weight it was added with."""

    def __init__(self, weight, obj):
        # weight: this entry's own (non-cumulative) weight.
        # obj: the payload handed back to callers when this entry is drawn.
        self.weight, self.obj = weight, obj


class WeightedRandom(object):
    """Draw registered objects at random, proportionally to their weights.

    Objects are registered with ``addObject``; ``getObject`` draws one object
    (with replacement) and ``exclusiveGet`` draws several distinct objects
    (without replacement).
    """

    def __init__(self):
        """Create an empty sampler.

        Cumulative weights are stored in an ``array('d')`` rather than a list
        to save memory.
        """
        self.__total_weight = 0
        # __weight[i] is the sum of the weights of objects 0..i (prefix sums),
        # so a uniform point in [0, total) can be located with a binary search.
        self.__weight = array('d')
        self.__objects = []

    def addObject(self, weight, obj):
        """Register ``obj`` with the given non-negative ``weight``.

        Raises:
            ValueError: if ``weight`` is negative. A negative weight would make
                the prefix sums non-monotonic and break the binary search.
        """
        if weight < 0:
            raise ValueError('weight must be non-negative, got %r' % (weight,))
        self.__total_weight += weight
        self.__weight.append(self.__total_weight)
        self.__objects.append(WeightObj(weight, obj))

    def getObject(self):
        """Return one registered object, chosen with probability weight/total.

        Objects with weight 0 are never returned, because their prefix-sum
        interval is empty.

        Raises:
            IndexError: if nothing has been added, or every weight is 0.
        """
        index = bisect(self.__weight, random() * self.__total_weight)
        if index >= len(self.__objects):
            raise IndexError('WeightedRandom has no object with positive weight')
        return self.__objects[index].obj

    def exclusiveGet(self, get_count):
        """Draw up to ``get_count`` distinct objects without replacement.

        Each draw chooses among the objects not drawn yet, with probability
        proportional to their weights. Fewer than ``get_count`` objects are
        returned when the candidates run out. Objects with weight 0 are never
        returned, which is consistent with ``getObject``.

        Args:
            get_count: maximum number of objects to return.

        Returns:
            A list of the drawn objects in draw order. These are the registered
            objects themselves, not copies.
        """
        result = []
        # Shallow copy: callers must receive the registered objects themselves,
        # so deep-copying them would be both wrong and expensive.
        remaining = list(self.__objects)
        for _ in range(get_count):
            # Rebuild the prefix sums over the remaining candidates. Simply
            # popping from the original prefix-sum array would leave stale sums
            # that still include removed weights. The final running sum is also
            # the total, so the two can never disagree.
            cumulative = array('d')
            total = 0
            for item in remaining:
                total += item.weight
                cumulative.append(total)
            index = bisect(cumulative, random() * total)
            if index >= len(remaining):
                # No candidates left, or only zero-weight ones remain.
                break
            result.append(remaining.pop(index).obj)
        return result
if __name__ == '__main__':
    from collections import Counter

    factory = WeightedRandom()
    for weight, str_obj in ((1, 'a'), (4, 'b'), (5, 'c')):
        factory.addObject(weight, str_obj)

    # Check the exclusive draw: asking for more items than were added should
    # return each registered object exactly once.
    print(factory.exclusiveGet(100))

    # Draw 10000 times and print the counts, which should be roughly
    # proportional to the weights 1:4:5.
    # range (not xrange) so the demo runs on Python 3 as well as Python 2.
    count_dict = Counter(factory.getObject() for _ in range(10000))
    print(dict(count_dict))
478

被折叠的 条评论
为什么被折叠?



