目录
🌟 前言:当神经网络学会"推理"时会发生什么?
"银行如何预测连锁违约风险?疾控中心怎样模拟病毒传播路径?答案藏在概率图的因果迷宫中!"
本文将用PyTorch+Pyro(以及pgmpy)实现贝叶斯推理引擎,手把手构建金融风险传染模型与疫情传播模拟器,揭秘概率图如何让AI在不确定性下进行概率推理。文末附赠概率图可视化工具包,让你像福尔摩斯一样破解变量间的隐藏关系!
一、结构化概率模型:世界的因果罗盘
1.1 概率图模型的三原色
import networkx as nx
import matplotlib.pyplot as plt
# Build and draw the structure of a heart-disease Bayesian network (edges only, no CPDs).
# matplotlib's default font has no CJK glyphs, so Chinese node labels would render
# as empty boxes; list common CJK fonts first and keep DejaVu Sans as the fallback.
plt.rcParams["font.sans-serif"] = [
    "SimHei", "Microsoft YaHei", "Noto Sans CJK SC", "PingFang SC", "DejaVu Sans",
]

G = nx.DiGraph()
edges = [
    ("年龄", "动脉硬化"),
    ("吸烟", "动脉硬化"),
    ("动脉硬化", "心绞痛"),
    ("动脉硬化", "心肌梗塞"),
    ("运动量", "心绞痛"),
]
G.add_edges_from(edges)

# Seeded so the layout (and any later reuse of `pos`) is identical on every run.
pos = nx.spring_layout(G, seed=42)

fig, ax = plt.subplots(figsize=(10, 6))
nx.draw(G, pos, ax=ax, with_labels=True, node_size=3000, node_color="#FFAAAA",
        arrowsize=20, font_size=12)
ax.set_title("心脏疾病预测贝叶斯网络")
plt.show()
二、两大核心模型代码实战
2.1 贝叶斯网络:因果推理引擎
import torch
import pyro
import pyro.distributions as dist
def _heart_model():
    """Generative model: age and smoking -> arteriosclerosis.

    Encoding (index 1 is the "yes"/positive state everywhere):
        age:           0 = young, 1 = old
        smoke:         0 = non-smoker, 1 = smoker
        art_sclerosis: 0 = absent, 1 = present

    Returns:
        The sampled ``art_sclerosis`` value.
    """
    # Priors
    age = pyro.sample("age", dist.Categorical(torch.tensor([0.3, 0.7])))  # 0: young, 1: old
    smoke = pyro.sample("smoke", dist.Bernoulli(0.2))  # 20% smokers
    # CPT indexed [age][smoke] -> [P(absent), P(present)]. Column 1 is "present" so that
    # conditioning on art_sclerosis == 1 means "has arteriosclerosis";
    # P(present) = 0.1 / 0.3 / 0.4 / 0.8 rises with age and smoking.
    art_sclerosis_prob = torch.tensor([
        [[0.9, 0.1],   # young, non-smoker
         [0.7, 0.3]],  # young, smoker
        [[0.6, 0.4],   # old, non-smoker
         [0.2, 0.8]],  # old, smoker
    ])
    # Bernoulli samples are float tensors; tensor indexing needs integer indices.
    return pyro.sample("art_sclerosis",
                       dist.Categorical(art_sclerosis_prob[age.long(), smoke.long()]))


def bayesian_network(num_samples=1000):
    """Estimate P(smoke = 1 | age = old, art_sclerosis = present) by importance sampling.

    Inference runs on the separate generative model ``_heart_model`` with the
    evidence attached through ``pyro.condition``.

    Args:
        num_samples: number of importance samples.

    Returns:
        The estimated posterior smoking probability as a float (also printed).
    """
    condition = {"age": torch.tensor(1), "art_sclerosis": torch.tensor(1)}
    conditioned = pyro.condition(_heart_model, data=condition)
    posterior = pyro.infer.Importance(conditioned, num_samples=num_samples).run()
    # Weighted mean of the 0/1 "smoke" samples = P(smoke = 1 | evidence).
    smoke_prob = pyro.infer.EmpiricalMarginal(posterior, sites="smoke").mean.item()
    print(f"已知老年且动脉硬化,吸烟概率: {smoke_prob*100:.1f}%")
    return smoke_prob


# Run inference (exact value: 0.2*0.8 / (0.2*0.8 + 0.8*0.4) = 1/3).
bayesian_network()
2.2 马尔可夫随机场:变量间的能量博弈
import pgmpy.models
from pgmpy.inference import BeliefPropagation
# Social-network trust model: each user is untrustworthy (0) or trustworthy (1).
from pgmpy.factors.discrete import DiscreteFactor

users = ["用户A", "用户B", "用户C", "用户D"]
# pgmpy's undirected discrete model class is MarkovNetwork.
model = pgmpy.models.MarkovNetwork()
model.add_nodes_from(users)
# Every edge must use the full node names: an unknown name such as "C" would
# silently add a new, separate node instead of linking the existing user.
model.add_edges_from([("用户A", "用户B"), ("用户B", "用户C"),
                      ("用户C", "用户D"), ("用户A", "用户D")])

# Potential functions
factors = [
    DiscreteFactor(variables=[node], cardinality=[2],  # 0: untrustworthy, 1: trustworthy
                   values=[0.5, 0.5])  # uniform prior
    for node in model.nodes()
]
factors += [
    # Diagonal potential 2 vs off-diagonal 1: neighbours prefer the same state
    # (higher potential = lower energy).
    DiscreteFactor(variables=list(edge), cardinality=[2, 2], values=[[2, 1], [1, 2]])
    for edge in model.edges()
]
model.add_factors(*factors)
model.check_model()  # fail fast if a factor does not match the graph

# Belief-propagation inference (exact here: 13 / 41 ≈ 31.7%).
infer = BeliefPropagation(model)
infer.calibrate()
query = infer.query(variables=["用户D"], evidence={"用户A": 0})
print(f"当用户A不可信时,用户D可信概率: {query.values[1]*100:.1f}%")
三、工业级应用案例解析
3.1 金融风险传染模拟
import pyro
import pyro.distributions as dist
def financial_network():
    """Joint model of bank defaults with pairwise contagion factors.

    Each bank first gets an independent 10% base default prior. Then, for each
    bank, a contagion factor scores its default state with probability
    sigmoid(2 * neighbor_risk - 1), where neighbor_risk is the
    exposure-weighted sum of the defaults that feed into that bank.

    Each bank's default is its own latent site ``default_<bank>`` so that it
    can be conditioned on and queried by name.

    Returns:
        dict mapping bank name -> sampled default indicator (0./1. tensor).
    """
    banks = ["BankA", "BankB", "BankC"]
    # (source, target) -> weight with which source's default adds to target's risk
    exposures = {
        ("BankA", "BankB"): 0.3,
        ("BankB", "BankC"): 0.5,
        ("BankA", "BankC"): 0.2,
    }
    # Base default prior (10%)
    default = {bank: pyro.sample(f"default_{bank}", dist.Bernoulli(0.1)) for bank in banks}
    # Contagion: score each bank's default state given its counterparties' defaults.
    for bank in banks:
        neighbor_risk = sum(
            weight * default[src]
            for (src, dst), weight in exposures.items() if dst == bank
        )
        # logits = 2*risk - 1 is the same as probs = sigmoid(2*risk - 1)
        contagion = dist.Bernoulli(logits=torch.as_tensor(2 * neighbor_risk - 1.0))
        pyro.factor(f"contagion_{bank}", contagion.log_prob(default[bank]))
    return default


# Crisis scenario: BankA has defaulted (evidence attached via pyro.condition).
condition = {"default_BankA": torch.tensor(1.)}
conditioned_network = pyro.condition(financial_network, data=condition)
posterior = pyro.infer.Importance(conditioned_network, num_samples=1000).run()
for bank in ("BankB", "BankC"):
    # Weighted mean of 0/1 samples = posterior default probability
    # (exact enumeration: ≈0.048 for BankB, ≈0.062 for BankC).
    marginal = pyro.infer.EmpiricalMarginal(posterior, sites=f"default_{bank}")
    print(f"{bank}违约概率:", marginal.mean.item())
3.2 疫情传播预测
class SEIRModel(pyro.nn.PyroModule):
    """Discrete-time SEIR epidemic model fitted to daily infection counts.

    Each day the compartments move by their expected flows

        S -> E : S * (1 - exp(-beta * I / N))
        E -> I : E * (1 - exp(-sigma))
        I -> R : I * (1 - exp(-gamma))

    Every flow is a fraction in [0, 1) of its source compartment, so S, E, I and
    R never go negative and the Poisson observation rate stays valid.

    ``beta`` (transmission rate) and ``gamma`` (recovery rate) are positive
    global Pyro params named "beta" and "gamma". ``sigma`` (E -> I onset rate)
    is fixed at construction.
    """

    def __init__(self, population, sigma=0.5):
        super().__init__()
        self.population = population
        # NOTE(review): 0.5/day (mean incubation 2 days) is an assumed default; tune per disease.
        self.sigma = sigma

    def forward(self, days=30, observed_I=None):
        """Simulate ``days`` steps starting from a single infectious case.

        Args:
            days: number of simulated days.
            observed_I: optional sequence of observed infectious counts.
                ``observed_I[t]`` is scored against I after t days, so
                ``observed_I[0]`` corresponds to the initial state. Entries
                beyond ``days`` are ignored.

        Returns:
            dict with keys "S", "E", "I", "R", each a list of ``days + 1``
            scalar tensors.
        """
        positive = torch.distributions.constraints.positive
        # Read params on every call rather than once in __init__, so each SVI
        # step uses the freshly optimized values and builds a fresh autograd graph.
        beta = pyro.param("beta", torch.tensor(0.3), constraint=positive)    # transmission rate
        gamma = pyro.param("gamma", torch.tensor(0.1), constraint=positive)  # recovery rate
        n = float(self.population)
        p_onset = -torch.expm1(torch.tensor(-float(self.sigma)))  # 1 - exp(-sigma)
        p_recover = -torch.expm1(-gamma)                          # 1 - exp(-gamma)

        S = [torch.tensor(n - 1.0)]
        E = [torch.tensor(0.0)]
        I = [torch.tensor(1.0)]
        R = [torch.tensor(0.0)]
        for _ in range(days):
            new_exposed = S[-1] * -torch.expm1(-beta * I[-1] / n)  # S -> E
            new_infectious = E[-1] * p_onset                       # E -> I
            new_recovered = I[-1] * p_recover                      # I -> R
            S.append(S[-1] - new_exposed)
            E.append(E[-1] + new_exposed - new_infectious)
            I.append(I[-1] + new_infectious - new_recovered)
            R.append(R[-1] + new_recovered)

        if observed_I is not None:
            for t, count in enumerate(observed_I[: days + 1]):
                # Poisson observation noise around the modeled I; the clamp keeps the rate > 0.
                pyro.sample(f"obs_{t}", dist.Poisson(I[t].clamp(min=1e-6)),
                            obs=torch.tensor(float(count)))
        return {"S": S, "E": E, "I": I, "R": R}


# Parameter inference: maximum-likelihood fit of beta/gamma to the observed infections.
observations = {"I": [1, 3, 9, 20, 50, 80, 100, 95, 70, 45]}  # observed infection counts
seir = SEIRModel(1000)


def guide(days=30, observed_I=None):
    """Empty guide: the model has no latent sample sites, only params and observations."""


pyro.clear_param_store()
# Named `svi` rather than `infer` so the BeliefPropagation engine from section 2.2 is not shadowed.
svi = pyro.infer.SVI(seir, guide, pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.Trace_ELBO())
days = len(observations["I"]) - 1  # observations["I"][0] is the initial state
for epoch in range(1000):
    loss = svi.step(days=days, observed_I=observations["I"])
    if epoch % 100 == 0:
        print(f"Epoch {epoch} Loss: {loss}")
print("推断传播率:", pyro.param("beta").item())
print("推断康复率:", pyro.param("gamma").item())
四、概率图模型工具包
4.1 动态概率传播可视化
import matplotlib.animation as animation
# Animate trust beliefs in the social-network MRF (`model` from section 2.2)
# while evidence about individual users is revealed one frame at a time.
mrf_graph = nx.Graph()
mrf_graph.add_nodes_from(model.nodes())
mrf_graph.add_edges_from(model.edges())
mrf_pos = nx.spring_layout(mrf_graph, seed=0)
# A dedicated engine: the global name `infer` is reused for other objects in this script.
mrf_bp = BeliefPropagation(model)
# (user, observed state) pairs revealed one per frame; frame 0 shows the prior beliefs.
evidence_schedule = [("用户A", 0), ("用户C", 1)]


def evidence_up_to(frame):
    """Return the evidence dict revealed by ``frame`` (frame 0 means no evidence)."""
    return dict(evidence_schedule[:frame])


def node_beliefs(evidence):
    """Map each node to P(node = 1 | evidence); observed nodes get their observed state."""
    beliefs = {}
    for node in mrf_graph.nodes:
        if node in evidence:
            beliefs[node] = float(evidence[node])
        else:
            # Marginal query (not map_query, which returns a single MAP state).
            marginal = mrf_bp.query(variables=[node], evidence=evidence)
            beliefs[node] = float(marginal.values[1])
    return beliefs


fig, ax = plt.subplots()


def update(frame):
    """Redraw the graph, colouring each node by its belief of being trustworthy (state 1)."""
    ax.clear()
    beliefs = node_beliefs(evidence_up_to(frame))
    nx.draw(mrf_graph, mrf_pos, ax=ax, node_color=[beliefs[n] for n in mrf_graph.nodes],
            cmap=plt.cm.Reds, vmin=0, vmax=1, with_labels=True, node_size=3000)
    return ax,


ani = animation.FuncAnimation(fig, update, frames=len(evidence_schedule) + 1, interval=500)
plt.show()
🔥 模型选型指南
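模型类型 | 方向性 | 常用推理方式 | 适用场景 | 典型库 |
---|---|---|---|---|
贝叶斯网络 | 有向(无环) | 精确推断(变量消元、联结树)或近似推断(重要性采样、MCMC) | 因果推理 | pgmpy, Pyro |
马尔可夫网络 | 无向 | 精确推断(联结树)或近似推断(环状信念传播、Gibbs采样) | 关系建模 | pgmpy, OpenGM, libDAI |
因子图 | 二部图(变量节点+因子节点),可表示有向与无向模型 | 消息传递(和积/最大积算法) | 复杂系统 | ForneyLab |
条件随机场 | 无向(判别式模型) | 推断:前向-后向/维特比;训练:梯度优化 | 序列标注 | CRFsuite |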