torch.argsort 代码测试

本文详细解析了PyTorch中的torch.argsort函数的工作原理,通过实例演示了如何使用该函数对张量沿指定维度进行升序排序,并返回排序后的索引。
部署运行你感兴趣的模型镜像

1.官方解释

torch.argsort

  • 解释
    Returns the indices that sort a tensor along a given dimension in ascending order by value.
    返回沿着给定维数按值升序对张量排序的索引。
  • 重点
    是按照值的顺序排列

2. 举例说明

a = torch.randn(4,4)
a = tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
            [ 0.1598,  0.0788, -0.0745, -1.2700],
            [ 1.2208,  1.0722, -0.7064,  1.2564],
            [ 0.0669, -0.2318, -0.8229, -0.9280]])
b = torch.argsort(a,dim=1)
b = tensor([[2, 0, 3, 1],
            [3, 2, 1, 0],
            [2, 1, 0, 3],
            [3, 2, 1, 0]])      

我们来分析下,b 为什么是这样的。
起初我们感官的认为,当a的第一行值为 [ 0.0785, 1.5267, -0.8521, 0.4065] 的时候,我们排序应该为如下:
在这里插入图片描述
按照我们的感觉应该得出 b 为 [1,3,0,2]才行,但是最后输出的结果居然是[2,0,3,1];居然跟我们设想的不一样,那为啥不对呢,主要原因是我们得出来的值是按照序号排列的,而官方文档说的是按照值来排序的。

  • 正确的操作:
    在这里插入图片描述
    说明:
    第一步是将 a 按照顺序值进行排序得到新的序列:
    [-0.8521, 0.0785, 0.4065, 1.5267]
    第二步是如果才能通过序列a来得到升序的序列,
    我们发现
    a[2]=-0.8521,a[0]=0.0785,a[3]=0.4065,a[1]=1.5267
    所以 b 返回的是序列值:[2,0,3,1];这样我们就可以通过这个序列[2,0,3,1]直接将数据按照升序进行排列了。

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

import pandas as pd import numpy as np import torch from sklearn.preprocessing import StandardScaler from pyDOE import lhs def load_battery_data(file_path, input_dim=7, n_train=1000, n_test=100): """ 加载电池RUL数据,使用LHS方法抽取训练和测试样本(提取第2-8列作为特征) """ # 读取CSV文件 data = pd.read_csv(file_path) # 提取第2列到第8列作为特征(索引1到7),第9列作为RUL标签 features = data.iloc[:, 1:8].values # 第二列到第八列 target = data.iloc[:, 8].values # 第九列RUL # 标准化特征 scaler = StandardScaler() features_scaled = scaler.fit_transform(features) # 使用LHS方法生成样本索引 n_samples = len(features) sample_indices = lhs(n_samples, samples=n_train + n_test) sample_indices = np.round(sample_indices * (n_samples - 1)).astype(int) # 确保索引唯一 sample_indices = np.unique(sample_indices) if len(sample_indices) < n_train + n_test: print("警告: LHS生成了重复索引,正在重新生成...") return load_battery_data(file_path, input_dim, n_train, n_test) # 划分训练和测试索引 train_indices = sample_indices[:n_train] test_indices = sample_indices[n_train:n_train + n_test] # 提取训练和测试数据 X_train = torch.tensor(features_scaled[train_indices], dtype=torch.float32) y_train = torch.tensor(target[train_indices], dtype=torch.float32).squeeze() X_test = torch.tensor(features_scaled[test_indices], dtype=torch.float32) y_test = torch.tensor(target[test_indices], dtype=torch.float32).squeeze() return X_train, y_train, X_test, y_test逐行逐句解释代码,并用小白能懂的话解释问题本质
05-28
这是我构建的异构图的函数代码(注意特征维度是有区别的,课题最终任务是通过每个月更新的节点特征和边特征,预测每个月两个省份节点之间的物理联络线上的传输电量和线损电量): def build_heterogeneous_graph(month): # 初始化异构图 graph = HeteroData() # ================== 1. 读取节点数据 ================== # 省份节点 province_df = pd.read_excel("province_node_features.xlsx") province_month = province_df[province_df['月份'] == month].set_index('节点ID') # 通道节点 channel_df = pd.read_excel("channel_node_features.xlsx") channel_month = channel_df[channel_df['月份'] == month].set_index('节点ID') # ================== 2. 构建节点映射 ================== all_nodes = list(province_month.index) + list(channel_month.index) node_to_idx = {node: idx for idx, node in enumerate(all_nodes)} # ================== 3. 添加节点特征 ================== # 省份节点特征 graph['province'].x = torch.tensor( province_month[['跨区跨省总送出电量', '跨区送出电量', '跨省送出电量', '跨区跨省总受入电量', '跨区受入电量', '跨省受入电量']].values, dtype=torch.float32 ) # 通道节点特征 graph['channel'].x = torch.tensor( channel_month[['总送出电量', '总受入电量']].values, dtype=torch.float32 ) # ================== 4. 处理边数据 ================== # 物理联络线:包含省份-省份、省份-通道、通道-省份 physical_edges = pd.read_excel("physical_edge_features.xlsx") # 4.1 省份-省份物理边 prov_to_prov = physical_edges[ physical_edges['源节点'].isin(province_month.index) & physical_edges['目标节点'].isin(province_month.index) ] if not prov_to_prov.empty: edge_index = torch.tensor([ [node_to_idx[src] for src in prov_to_prov['源节点']], [node_to_idx[dst] for dst in prov_to_prov['目标节点']] ], dtype=torch.long) graph['province', 'physical_p2p', 'province'].edge_index = edge_index graph['province', 'physical_p2p', 'province'].edge_attr = torch.tensor( prov_to_prov[['电阻', '电抗', '线路单端电纳']].values, dtype=torch.float32 ) # 4.2 省份-通道物理边 prov_to_chan = physical_edges[ physical_edges['源节点'].isin(province_month.index) & physical_edges['目标节点'].isin(channel_month.index) ] if not prov_to_chan.empty: edge_index = torch.tensor([ [node_to_idx[src] for src in prov_to_chan['源节点']], [node_to_idx[dst] for dst in prov_to_chan['目标节点']] ], dtype=torch.long) graph['province', 'physical_p2c', 'channel'].edge_index = edge_index graph['province', 'physical_p2c', 'channel'].edge_attr = torch.tensor( prov_to_chan[['电阻', '电抗', '线路单端电纳']].values, dtype=torch.float32 ) # 4.3 通道-省份物理边 chan_to_prov = physical_edges[ physical_edges['源节点'].isin(channel_month.index) & physical_edges['目标节点'].isin(province_month.index) ] if not chan_to_prov.empty: edge_index = torch.tensor([ [node_to_idx[src] for src in chan_to_prov['源节点']], [node_to_idx[dst] for dst in chan_to_prov['目标节点']] ], dtype=torch.long) graph['channel', 'physical_c2p', 'province'].edge_index = edge_index graph['channel', 'physical_c2p', 'province'].edge_attr = torch.tensor( chan_to_prov[['电阻', '电抗', '线路单端电纳']].values, dtype=torch.float32 ) # 虚拟交易边 virtual_edges = pd.read_excel("virtual_edge_features.xlsx") virtual_edges_month = virtual_edges[virtual_edges['月份'] == month] # 4.4 省份-省份虚拟边 virt_p2p = virtual_edges_month[ virtual_edges_month['源节点'].isin(province_month.index) & virtual_edges_month['目标节点'].isin(province_month.index) ] if not virt_p2p.empty: edge_index = torch.tensor([ [node_to_idx[src] for src in virt_p2p['源节点']], [node_to_idx[dst] for dst in virt_p2p['目标节点']] ], dtype=torch.long) graph['province', 'virtual_p2p', 'province'].edge_index = edge_index graph['province', 'virtual_p2p', 'province'].edge_attr = torch.tensor( virt_p2p[['交易电量']].values, dtype=torch.float32 ) # 4.5 通道-省份虚拟边 virt_c2p = virtual_edges_month[ virtual_edges_month['源节点'].isin(channel_month.index) & virtual_edges_month['目标节点'].isin(province_month.index) ] if not virt_c2p.empty: edge_index = torch.tensor([ [node_to_idx[src] for src in virt_c2p['源节点']], [node_to_idx[dst] for dst in virt_c2p['目标节点']] ], dtype=torch.long) graph['channel', 'virtual_c2p', 'province'].edge_index = edge_index graph['channel', 'virtual_c2p', 'province'].edge_attr = torch.tensor( virt_c2p[['交易电量']].values, dtype=torch.float32 ) # 4.6 省份-通道虚拟边 virt_c2p = virtual_edges_month[ virtual_edges_month['源节点'].isin(province_month.index) & virtual_edges_month['目标节点'].isin(channel_month.index) ] if not virt_c2p.empty: edge_index = torch.tensor([ [node_to_idx[src] for src in virt_c2p['源节点']], [node_to_idx[dst] for dst in virt_c2p['目标节点']] ], dtype=torch.long) graph['province', 'virtual_c2p', 'channel'].edge_index = edge_index graph['province', 'virtual_c2p', 'channel'].edge_attr = torch.tensor( virt_c2p[['交易电量']].values, dtype=torch.float32 ) # ================== 5. 验证数据 ================== print(f"\n构建完成 {month}月异构图:") print(f"节点数量:") print(f"- 省份: {len(province_month)}") print(f"- 通道: {len(channel_month)}") print(f"\n边数量:") for edge_type in graph.edge_types: print(f"- {edge_type}: {graph[edge_type].edge_index.shape[1]}") return graph 后续代码怎么编写?
最新发布
08-20
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值