算法介绍:【数学】时间复杂度O(1)的离散采样算法—— Alias method/别名采样方法
先上代码
import numpy as np


class alias_sampling(object):
    """O(1)-per-draw discrete sampler using the alias method (Vose's variant).

    Building the table is O(n); afterwards every draw costs one uniform
    integer and one uniform float, independent of ``n``.

    Attributes:
        n: number of outcomes.
        accept: ``accept[i]`` is the probability of keeping column ``i`` after
            it has been picked uniformly at random.
        alias: ``alias[i]`` is the outcome returned when column ``i`` is
            rejected.
    """

    def __init__(self, prob):
        """Build the alias table for the distribution ``prob``.

        Args:
            prob: 1-D sequence or NumPy array of non-negative weights. Plain
                Python lists are accepted; weights are normalised to sum to 1.

        Raises:
            ValueError: if ``prob`` is empty or not 1-D, contains negative or
                non-finite values, or sums to zero.
        """
        p = np.asarray(prob, dtype=float)
        if p.ndim != 1 or p.size == 0:
            raise ValueError("prob must be a non-empty 1-D sequence")
        if not np.all(np.isfinite(p)) or np.any(p < 0):
            raise ValueError("prob must contain finite, non-negative values")
        total = p.sum()
        if total <= 0:
            raise ValueError("prob must have a positive sum")
        self.n = p.size
        # Scale so the average column height is exactly 1. Converting to an
        # array first matters: for a Python list, ``prob * n`` would repeat
        # the list n times instead of scaling its values.
        self.__create_alias_table(p / total * self.n)

    def __create_alias_table(self, area_ratio):
        """Fill ``self.accept`` and ``self.alias`` from scaled column heights.

        ``area_ratio`` is modified in place. Each column shorter than 1 is
        topped up with mass taken from a column at least 1 tall; that donor
        becomes its alias.
        """
        small, large = [], []
        self.accept = [0] * self.n
        self.alias = [0] * self.n
        for i, ratio in enumerate(area_ratio):
            (small if ratio < 1.0 else large).append(i)
        while small and large:
            lo, hi = small.pop(), large[-1]
            self.accept[lo] = area_ratio[lo]
            self.alias[lo] = hi
            # The donor gives away exactly the mass that fills `lo` up to 1.
            area_ratio[hi] -= 1.0 - area_ratio[lo]
            if area_ratio[hi] < 1.0:
                large.pop()
                small.append(hi)
        # Leftover columns have height 1 up to floating-point error, so they
        # always accept themselves.
        for i in large + small:
            self.accept[i] = 1.0

    def get_sample(self):
        """Draw a single outcome index in ``[0, n)``."""
        idx = np.random.randint(0, self.n)
        # Keep the column with probability accept[idx], otherwise its alias.
        if self.accept[idx] > np.random.uniform(0, 1):
            return idx
        return self.alias[idx]

    def get_samples(self, n):
        """Lazily yield ``n`` independent draws, one ``get_sample`` call each."""
        for _ in range(n):
            yield self.get_sample()
def test(n=100, k=1000000, seed=None):
    """Empirically check ``alias_sampling`` against a random distribution.

    Draws a random ``n``-outcome distribution and takes ``k`` samples from it.
    It then prints the L2 distance between the target probabilities and the
    observed frequencies, and returns that distance.

    Args:
        n: number of outcomes.
        k: number of samples to draw.
        seed: optional seed for NumPy's global RNG, for reproducible runs.
    """
    import numpy as np

    def gen_prob_dist(size):
        # Random positive weights normalised into a probability vector.
        weights = np.random.uniform(0, 10, size)
        return weights / np.sum(weights)

    if seed is not None:
        np.random.seed(seed)
    p = gen_prob_dist(n)
    sampling = alias_sampling(p)
    # Count how often each outcome was drawn, in native code.
    counts = np.bincount(
        np.fromiter(sampling.get_samples(k), dtype=np.int64, count=k),
        minlength=n,
    )
    p_sample = counts / k
    error = np.linalg.norm(p - p_sample, 2)
    print(error)
    return error
if __name__ == '__main__':
    # Repeat the sampler self-check ten times.
    for _run in range(10):
        test()
假设概率分布为 $\{p_i\}_{i=0}^{n-1}$