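With the dependencies and probabilities shown in the figure above, we build a Bayesian network and draw samples from it using direct prior sampling.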
Meaning of each variable:
- C: Cloudy
- S: Sprinkler
- R: Rain
- W: WetGrass
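The network definition gives the parent sets: Sprinkler and Rain each depend on Cloudy, and WetGrass depends on Sprinkler and Rain. Under that structure, the standard Bayesian-network factorization of the joint distribution is shown below. Direct prior sampling draws each variable from one of these factors, in this order.

```latex
P(C, S, R, W) = P(C)\,P(S \mid C)\,P(R \mid C)\,P(W \mid S, R)
```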
Bayes' formula:
$$P(\text{Cause} \mid \text{Effect}) = \frac{P(\text{Effect} \mid \text{Cause})\,P(\text{Cause})}{P(\text{Effect})}$$
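Here is a worked example that applies the formula to the Cloudy → Rain edge. It uses the CPT values from the network definition: P(C) = 0.5, P(R | C) = 0.8, P(R | ¬C) = 0.2. The denominator P(Effect) comes from the law of total probability. The function name `bayes_posterior` is hypothetical and used only for illustration.

```python
def bayes_posterior(prior: float, lik_true: float, lik_false: float) -> float:
    """P(Cause | Effect) for a binary cause, via Bayes' formula.

    prior     -- P(Cause)
    lik_true  -- P(Effect | Cause)
    lik_false -- P(Effect | not Cause)
    """
    # P(Effect) by the law of total probability over Cause / not Cause
    evidence = lik_true * prior + lik_false * (1.0 - prior)
    return lik_true * prior / evidence


# P(Cloudy | Rain) with P(C) = 0.5, P(R | C) = 0.8, P(R | not C) = 0.2
p = bayes_posterior(prior=0.5, lik_true=0.8, lik_false=0.2)
print(round(p, 4))  # 0.8
```

Observing rain raises the probability of a cloudy sky from 0.5 to 0.8, because rain is four times as likely under Cloudy as under not Cloudy.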
The Bayesian network code is as follows:
```python
import dataclasses
from typing import Dict, List, Tuple

import numpy as np
```
Representation: a class for each component
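```python
# The @ decorator syntax wraps an existing function or class to add extra behavior;
# @dataclasses.dataclass generates __init__, __repr__ and __eq__ from the annotated fields.
@dataclasses.dataclass
class Probability:
    """Unconditional probability P(node=True)."""
    value: float


@dataclasses.dataclass
class ConditionalProb(Probability):
    """CPT split on one parent: value = (entry if parent is True, entry if parent is False)."""
    condition: str
    value: Tuple[Probability, Probability]


@dataclasses.dataclass
class BoolNode:
    """A Boolean random variable together with its conditional probability table."""
    name: str
    cpt: Probability
```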
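The nested `ConditionalProb` encoding stores a pair (entry if the parent is True, entry if the parent is False) at each level. For a node with two parents, this means branching twice. A common alternative is a flat table keyed by the tuple of parent values. The sketch below encodes the same WetGrass numbers used in the network definition in that form, so that each lookup is a single dictionary access. This is a swapped-in representation, not the original design. The names `WETGRASS_CPT` and `cond_prob` are hypothetical.

```python
from typing import Dict, Tuple

# Hypothetical flat CPT for WetGrass: key = (Sprinkler, Rain), value = P(WetGrass=True | parents).
# The numbers match the nested ConditionalProb used for `wetgrass` in the network definition.
WETGRASS_CPT: Dict[Tuple[bool, ...], float] = {
    (True, True): 0.99,
    (True, False): 0.9,
    (False, True): 0.9,
    (False, False): 0.01,
}


def cond_prob(cpt: Dict[Tuple[bool, ...], float],
              parent_values: Tuple[bool, ...],
              value: bool = True) -> float:
    """Return P(node=value | parents) from a table that stores P(node=True | parents)."""
    p = cpt[parent_values]
    return p if value else 1.0 - p


print(cond_prob(WETGRASS_CPT, (False, True)))  # 0.9
print(cond_prob(WETGRASS_CPT, (True, True), value=False))
```

The nested form mirrors the tree structure and needs no key convention. The flat form makes the parent order explicit in the key, and it stays a single lookup however many parents a node has.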
Network Definition
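```python
cloudy = BoolNode(name='Cloudy', cpt=Probability(value=0.5))
# For a ConditionalProb: value = (P(node=True | parent=True), P(node=True | parent=False))
sprinkler = BoolNode(
    name='Sprinkler',
    cpt=ConditionalProb(condition='Cloudy', value=(Probability(value=0.1), Probability(value=0.5))),
)
rain = BoolNode(
    name='Rain',
    cpt=ConditionalProb(condition='Cloudy', value=(Probability(value=0.8), Probability(value=0.2))),
)
# WetGrass has two parents: branch on Sprinkler first, then on Rain
wetgrass = BoolNode(
    name='WetGrass',
    cpt=ConditionalProb(
        condition='Sprinkler',
        value=(
            ConditionalProb(condition='Rain', value=(Probability(value=0.99), Probability(value=0.9))),
            ConditionalProb(condition='Rain', value=(Probability(value=0.9), Probability(value=0.01))),
        ),
    ),
)

# Listed in topological order: every parent appears before its children
nodes = [cloudy, sprinkler, rain, wetgrass]
node_dict: Dict[str, BoolNode] = {n.name: n for n in nodes}


def get_cond(cpt: Probability) -> List[str]:
    """Return the parent names referenced anywhere in a (possibly nested) CPT."""
    if not isinstance(cpt, ConditionalProb):
        return []
    parents = [cpt.condition] + [p for child in cpt.value for p in get_cond(child)]
    return list(dict.fromkeys(parents))  # de-duplicate while keeping first-seen order


# Adjacency list: node name -> names of its parents
network: Dict[str, List[str]] = {n.name: get_cond(n.cpt) for n in nodes}
print(node_dict)
print(network)
print(get_cond(wetgrass.cpt))
```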
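Direct prior sampling only works if every node is sampled after its parents. The `nodes` list is written by hand in such an order. The sketch below checks or derives that order from a `network`-style dictionary (node → list of parents) using the standard-library `graphlib.TopologicalSorter` (Python 3.9+), which accepts exactly this shape. The helper `is_topological` is hypothetical.

```python
from graphlib import TopologicalSorter
from typing import Dict, List

# Same shape as `network` above: node name -> list of parent names
network: Dict[str, List[str]] = {
    "Cloudy": [],
    "Sprinkler": ["Cloudy"],
    "Rain": ["Cloudy"],
    "WetGrass": ["Sprinkler", "Rain"],
}

# static_order() yields every node after all of its predecessors (here: its parents);
# it raises graphlib.CycleError if the graph contains a cycle.
order = list(TopologicalSorter(network).static_order())
print(order)


def is_topological(order: List[str], graph: Dict[str, List[str]]) -> bool:
    """True if every parent appears before each of its children in `order`."""
    pos = {name: i for i, name in enumerate(order)}
    return all(pos[p] < pos[child] for child, parents in graph.items() for p in parents)


print(is_topological(["Cloudy", "Sprinkler", "Rain", "WetGrass"], network))  # True
```

The relative order of Sprinkler and Rain is not fixed, because neither is a parent of the other. Any valid topological order gives the same sampling distribution.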
Direct Prior Sampling
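```python
# Direct prior sampling: visit nodes in topological order (parents before children)
# and draw each node from its CPT given the values already sampled for its parents.
def lookup(cpt: Probability, sample: Dict[str, bool]) -> float:
    """Walk a (possibly nested) CPT down to P(node=True | parents) for this sample."""
    while isinstance(cpt, ConditionalProb):
        # value[0] applies when the parent is True, value[1] when it is False
        cpt = cpt.value[0] if sample[cpt.condition] else cpt.value[1]
    return cpt.value


samples: List[Dict[str, bool]] = []
for _ in range(10):
    s: Dict[str, bool] = {}
    for n in nodes:  # `nodes` is already in topological order
        s[n.name] = bool(np.random.rand() < lookup(n.cpt, s))
    samples.append(s)
    print(s)
```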
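Prior samples become useful once they are counted. The sketch below is a vectorized NumPy variant of the same procedure. It is an assumption, not the original code, and it hard-codes the CPTs as arrays instead of using the classes above. It draws 100,000 samples and estimates P(WetGrass=True) as a sample frequency. It then estimates P(Rain=True | WetGrass=True) by keeping only the samples with WetGrass=True. That second step is rejection sampling, a standard extension of prior sampling. Both estimates are checked against exact values obtained by summing the factorized joint distribution. The helper names `prior_sample`, `bern` and `joint` are hypothetical.

```python
import itertools

import numpy as np

# CPT values copied from the network definition; arrays are indexed by int(parent value).
P_C = 0.5                      # P(Cloudy=True)
P_S = np.array([0.5, 0.1])     # P(Sprinkler=True | Cloudy=False), P(Sprinkler=True | Cloudy=True)
P_R = np.array([0.2, 0.8])     # P(Rain=True | Cloudy=False), P(Rain=True | Cloudy=True)
P_W = np.array([[0.01, 0.9],   # P(WetGrass=True | Sprinkler=False, Rain=False/True)
                [0.9, 0.99]])  # P(WetGrass=True | Sprinkler=True, Rain=False/True)


def prior_sample(n: int, rng: np.random.Generator) -> dict:
    """Draw n joint samples, each variable after its parents (vectorized prior sampling)."""
    c = rng.random(n) < P_C
    s = rng.random(n) < P_S[c.astype(int)]
    r = rng.random(n) < P_R[c.astype(int)]
    w = rng.random(n) < P_W[s.astype(int), r.astype(int)]
    return {"Cloudy": c, "Sprinkler": s, "Rain": r, "WetGrass": w}


def bern(p: float, value: bool) -> float:
    """P(X=value) for a Boolean X with P(X=True) = p."""
    return p if value else 1.0 - p


def joint(c: bool, s: bool, r: bool, w: bool) -> float:
    """Exact P(C, S, R, W) = P(C) P(S|C) P(R|C) P(W|S,R)."""
    return (bern(P_C, c) * bern(P_S[int(c)], s) * bern(P_R[int(c)], r)
            * bern(P_W[int(s), int(r)], w))


rng = np.random.default_rng(0)
samples = prior_sample(100_000, rng)

# P(WetGrass=True): fraction of all samples with WetGrass=True
est_w = samples["WetGrass"].mean()
# P(Rain=True | WetGrass=True): drop samples inconsistent with the evidence, then count
est_r_given_w = samples["Rain"][samples["WetGrass"]].mean()

tf = (True, False)
exact_w = sum(joint(c, s, r, True) for c, s, r in itertools.product(tf, repeat=3))
exact_rw = sum(joint(c, s, True, True) for c, s in itertools.product(tf, repeat=2))
exact_r_given_w = exact_rw / exact_w

print(f"P(W)   ~ {est_w:.3f}  (exact {exact_w:.3f})")
print(f"P(R|W) ~ {est_r_given_w:.3f}  (exact {exact_r_given_w:.3f})")
```

With these CPTs the exact values are P(W) = 0.65 and P(R|W) = 0.4581 / 0.65 ≈ 0.705. The sampled estimates should agree to roughly two decimal places, and the error shrinks about as 1/√N.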