Numpy中的ravel_multi_index函数

本文详细解析了NumPy中的ravel_multi_index函数的功能与用法,通过实例代码展示了如何将多维索引转换为一维索引,适用于数据处理与索引映射等场景。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近遇到了ravel_multi_index这个函数,官方文档看不明白,Google了一番好不容易才从一堆示例代码里理解函数的意义,记录一下。


官方文档在这
这个函数主要功能为把给定的一个多维数组(函数的第一个参数)看作索引数组,索引什么呢?去索引一个形状为dims(函数的第二个参数),值为依次增大的自然数的数组中的值(可看做由list(range(N))的数组reshape(dims)而来),意义即为用一个唯一的一维数来定位(保存)原数组的二维(或多维(i,j,k,…))的数对的信息。

把文档里的示例代码贴一下来解释:

>>> arr = np.array([[3,6,6],[4,5,1]])
>>> np.ravel_multi_index(arr, (7,6))
array([22, 41, 37])
>>> np.ravel_multi_index(arr, (7,6), order='F')
array([31, 41, 13])
>>> np.ravel_multi_index(arr, (4,6), mode='clip')
array([22, 23, 19])
>>> np.ravel_multi_index(arr, (4,4), mode=('clip','wrap'))
array([12, 13, 13])
>>> np.ravel_multi_index((3,1,4,1), (6,7,8,9))
1621

示例中arr即为要转换的多维数组,把arr的内容当作索引,即[3,6,6]为横坐标,[4,5,1]为纵坐标,去索引形状为(7,6),内容为从0开始,从左往右,从上往下依次增大的自然数的数组中的值。

例如第一个要索引的数[3,4]即为(7,6)数组中第4行,第5列的的值,即为3*6+4=22,即为结果中的第一个数。依次类推。

了解函数功能后,其他参数具体可见官方文档说明。

import numpy as np import pandas as pd import matplotlib.pyplot as plt import networkx as nx from pgmpy.models import BayesianNetwork from pgmpy.factors.discrete import TabularCPD from pgmpy.estimators import MaximumLikelihoodEstimator from pgmpy.inference import VariableElimination # 网络结构定义 model_structure = [ ('Pollution', 'Cancer'), ('Smoker', 'Cancer'), ('Cancer', 'Xray'), ('Cancer', 'Dyspnoea') ] # 节点状态定义 states = { 'Pollution': ['low', 'high'], 'Smoker': ['yes', 'no'], 'Cancer': ['yes', 'no'], 'Xray': ['yes', 'no'], 'Dyspnoea': ['yes', 'no'] } # 专家条件概率表(修复行顺序) expert_cpds = { 'Pollution': TabularCPD( variable='Pollution', variable_card=2, values=[[0.9], [0.1]] ), 'Smoker': TabularCPD( variable='Smoker', variable_card=2, values=[[0.3], [0.7]] ), 'Cancer': TabularCPD( variable='Cancer', variable_card=2, # 行顺序修正:第一行对应 Cancer=yes values=[ [0.97, 0.95, 0.999, 0.98], # Cancer=yes [0.03, 0.05, 0.001, 0.02] # Cancer=no ], evidence=['Smoker', 'Pollution'], evidence_card=[2, 2] ), 'Xray': TabularCPD( variable='Xray', variable_card=2, # 行顺序修正:第一行对应 Xray=yes values=[ [0.1, 0.8], # Xray=yes | Cancer=no, Xray=yes | Cancer=yes [0.9, 0.2] # Xray=no ], evidence=['Cancer'], evidence_card=[2] ), 'Dyspnoea': TabularCPD( variable='Dyspnoea', variable_card=2, # 行顺序修正:第一行对应 Dyspnoea=yes values=[ [0.35, 0.7], # Dyspnoea=yes | Cancer=no, Dyspnoea=yes | Cancer=yes [0.65, 0.3] # Dyspnoea=no ], evidence=['Cancer'], evidence_card=[2] ) } # 数据生成函数(修复索引类型) def generate_random_data(model, n_samples=1000): np.random.seed(42) data = pd.DataFrame() G = nx.DiGraph(model.edges()) sorted_nodes = list(nx.topological_sort(G)) # 预生成状态到索引的映射 state_to_index = {node: {state: idx for idx, state in enumerate(states[node])} for node in states} for node in sorted_nodes: cpd = model.get_cpds(node) if len(cpd.variables) == 1: # 根节点 values = cpd.values.flatten() data[node] = np.random.choice(states[node], p=values, size=n_samples) else: # 非根节点 evidence = cpd.variables[1:] # 父节点列表 evidence_card = [len(states[ev]) for ev in evidence] parent_data = data[evidence] # 检查父节点数据完整性 if parent_data.isnull().any().any(): missing = parent_data.columns[parent_data.isnull().any()].tolist() raise ValueError(f"父节点 {missing} 数据未生成") # 批量计算父节点索引 evidence_indices = np.array([ [state_to_index[ev][val] for ev, val in zip(evidence, parent_data.iloc[i])] for i in range(n_samples) ], dtype=int) # 强制类型转换为整数 # 计算 CPD 列索引 try: cpd_indices = np.ravel_multi_index(evidence_indices.T, evidence_card) cpd_indices = cpd_indices.astype(int) # 强制转换为整数 except ValueError as e: raise ValueError(f"节点 {node} 的索引计算错误: {e}") # 检查索引范围 if cpd_indices.max() >= cpd.values.shape[1]: raise IndexError( f"CPD索引越界: 最大索引 {cpd_indices.max()}, " f"允许最大索引 {cpd.values.shape[1] - 1}" ) # 生成当前节点数据(确保索引为整数) data[node] = [ np.random.choice(states[node], p=cpd.values[:, int(idx)].flatten()) for idx in cpd_indices ] return data这是相关代码片段请帮我解决上述问题
最新发布
04-03
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值