captum 分析的数据准备(一)

本文介绍了如何构建基线数据,通过在原始数据中填充0以确保等长,利用tokenizer获取填充字符、起始和结束标识。接着,对数据进行编码并计算长度,进一步使用encoder进行编码处理,最后调用construct_input_base函数完成预处理步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

构建基线数据,用pad=0填充构造与输入等长的数据

获取原数据的填充字符,起始,结束 使用tokenizer直接获取

PAD_ID = tokenizer.pad_token_id  # 补齐符号 pad 的 id
SEP_ID = tokenizer.sep_token_id  # 输入文本中句子的间隔符号 sep 的 id
# 置于文本初始位置的符号 cls 的 id,在 bert 预训练过程中其对应的输出用于预测句子是否相关
CLS_ID = tokenizer.cls_token_id

对原数据进行编码,用len获取长度

def construct_input_base(context, question):
    # question 指问题,context 指需要从中寻求问题答案的文本
    # 将 context 转化为 id
    context_ids = tokenizer.encode(context, add_special_tokens=False)
    
    # 将 question 转化为 id
    question_ids = tokenizer.encode(question, add_special_tokens=False)

    # 基线输入,与 input_ids 等长
    base_input_ids = [CLS_ID] + [PAD_ID] * len(question_ids) + [SEP_ID] + \
        [PAD_ID] * len(context_ids) + [SEP_ID]

    return torch.tensor([base_input_ids], device=DEVICE)
tokenizer.encode(question, add_special_tokens&#
import scipy.io import torch from torch_geometric.data import Data from torch_geometric.nn import GNNExplainer from torch_geometric.utils import to_networkx import matplotlib.pyplot as plt import networkx as nx from edge_index_to_adj_node_edge import edge_index_to_adj_node_edge # 加载 .mat 文件 mat_file = "D:/GNN/wy0303-62OLD/data/data0318/data0318.mat" # 替换为你的 .mat 文件路径 mat_data = scipy.io.loadmat(mat_file) # 提取数据 edge_index = mat_data['edge_index'] # 边的索引 n_s = mat_data['n_s'] # 源节点特征 n_d = mat_data['n_d'] # 目标节点特征 m_s = mat_data['m_s'] # 可能是边特征 m_d = mat_data['m_d'] # 可能是边特征 E = mat_data['E'] # 可能是些额外的图特征 sc = mat_data['sc'] # 可能是额外的图特征 adj_matrix = edge_index_to_adj_node_edge(edge_index) # 如果存在邻接矩阵 # 转换为 PyTorch 张量 edge_index_tensor = torch.tensor(edge_index, dtype=torch.long) node_features_tensor = torch.cat((torch.tensor(n_s, dtype=torch.float), torch.tensor(n_d, dtype=torch.float)), dim=0) # 处理节点特征(聚合特征维度,得到个二维张量) node_features_processed = node_features_tensor.mean(dim=(1, 2, 3)) # 你可以根据需要修改此方法 # 创建 PyTorch Geometric 数据对象 data = Data(x=node_features_processed, edge_index=edge_index_tensor) # 加载预训练的模型 model = torch.load('D:/GNN/wy0303-62OLD/best_model.pt') model.eval() # 设置为评估模式 # 选择个节点进行解释 node_idx = 0 # 选择图中的个节点进行解释 # 使用 GNNExplainer 进行解释 explainer = GNNExplainer(model, epochs=200) print(f"edge_index type: {type(data.edge_index)}") print(f"edge_index shape: {data.edge_index.shape}") # 在调用 explain_node 时传递所有额外的输入(包括 edge_index) # 这里我们明确传递 edge_index 给模型和 explainer # n_s, n_d, m_s, m_d, E, sc,edge_index,adj_matrix explanation = explainer.explain_node( node_idx, data.x, data.edge_index, m_d=m_d, m_s=m_s, E=E, sc=sc, adj_matrix=adj_matrix, ) # 可视化结果 G = to_networkx(data) # 转换为 NetworkX 图对象 pos = nx.spring_layout(G) # 获取图的布局 # 可视化图 plt.figure(figsize=(8, 8)) nx.draw(G, pos, with_labels=True, node_color='lightblue', node_size=500) nx.draw_networkx_edges(G, pos, edgelist=explanation.edge_mask, edge_color='r'
最新发布
04-03
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值