[RL]贝叶斯网实现—天气预测

在这里插入图片描述
各依赖关系和概率如上图所示,建立贝叶斯网,使用直接事先抽样方法

各参数的含义:

  • C:Cloud
  • S:Sprinker
  • R:Rain
  • W:WetGrass

贝叶斯公式:
P ( C a u s e ∣ E f f e c t ) = P ( E f f e c t ∣ C a u s e ) P ( C a u s e ) P ( E f f e c t ) P(Cause|Effect) = \frac{P(Effect|Cause)P(Cause)}{P(Effect)} P(CauseEffect)=P(Effect)P(EffectCause)P(Cause)

贝叶斯网代码如下:

import dataclasses
import numpy as np

from typing import Dict, List, Optional, TypeVar, Tuple, Union

Representation 各个类的表示

#python函数修饰符@的作用是为现有函数增加额外的功能
@dataclasses.dataclass
class Probability:
    value: float
        
@dataclasses.dataclass
class ConditionalProb(Probability):
    condition: str
    value: Tuple[Probability, Probability]

@dataclasses.dataclass
class BoolNode:
    name: str
    cpt: Probability

Network Definition 定义网络

cloudy = BoolNode(name='Cloudy', cpt=Probability(value=0.5))

sprinkler = BoolNode(
    name='Sprinkler', 
    cpt=ConditionalProb(condition='Cloudy', value=(Probability(value=0.1), Probability(value=0.5)))
)

rain = BoolNode(
    name='Rain', 
    cpt=ConditionalProb(condition='Cloudy', value=(Probability(value=0.8), Probability(value=0.2)))
)

wetgrass = BoolNode(
    name='WetGrass', 
    cpt=ConditionalProb(condition='Sprinkler',
                        value=(
                            ConditionalProb(condition='Rain',
                                            value=(Probability(value=0.99), Probability(value=0.9))),
                            ConditionalProb(condition='Rain',
                                            value=(Probability(value=0.9), Probability(value=0.01)))
                            ))
)


nodes = [cloudy, sprinkler, rain, wetgrass]
node_dict: Dict[str, BoolNode] = {n.name: n for n in nodes}
get_cond = lambda x: list(set([x.condition] + [v for c in x.value for v in get_cond(c)])) if isinstance(x, ConditionalProb) else []
network: Dict[str, List[str]] = {n.name: get_cond(n.cpt) for n in nodes}
    
print(node_dict)
print(network)
print(get_cond(wetgrass.cpt))

Direct Prior Sampling 直接事先抽样

samples = []

for _ in range(10):
    s = {}
    for n in nodes:
        rnd = np.random.rand()
        node = node_dict[n.name]
        if len(network[n.name]) == 0:
            s[n.name] = True if rnd < node.cpt.value else False
        else:
            get_val = lambda x: (x.value[0] if s[x.condition] else x.value[1]) if isinstance(x.value, tuple) else x.value
            
            val = node.cpt
            while not isinstance(val, float):
                val = get_val(val)
                
            s[n.name] = True if rnd < val else False
    print(s)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

是土豆大叔啊!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值