torch.reshape输入参数为-1的含义

由于未提供博客具体内容,无法给出包含关键信息的摘要。

在这里插入图片描述

def train_encoder_at_dB(dB, report=False):
    """Train the encoder, decoder and mapper at a given SNR (in dB).

    Jointly optimizes a learned probabilistic constellation shaping (PCS)
    scheme over an AWGN channel and returns the final training loss
    (negative mutual-information surrogate, in nats).

    Args:
        dB: channel SNR in decibels.
        report: if True, print the loss every 500 epochs.

    Returns:
        The final scalar loss tensor (still attached to the graph; callers
        detach it themselves).
    """
    SNR_dB = dB
    SNR = 10 ** (SNR_dB / 10)
    sigma2 = 1 / SNR  # noise variance for unit average signal power

    # Optimizer over encoder, decoder and mapper parameters jointly.
    opt = optim.Adam(
        list(enc.parameters()) + list(dec.parameters()) + list(mapper.parameters()),
        lr=lr
    )

    for j in range(nepochs):
        # Encoder outputs the log-probabilities (logits) of each symbol.
        logits = enc(torch.tensor([1], dtype=torch.float))

        # Sample symbol indices via straight-through Gumbel-softmax.
        onehot = nn.functional.gumbel_softmax(logits.expand(n, -1), tau=10, hard=True)
        indices = torch.argmax(onehot, dim=1)

        # Modulation: map each of the M symbols to a constellation point.
        # NOTE(review): the mapper apparently outputs complex values
        # (torch.abs below) — that is what triggered the original error.
        alphabet_t = mapper(nn.functional.one_hot(torch.arange(M), M).float()).squeeze()
        probs = nn.functional.softmax(logits, -1)

        # Normalization factor so the probability-weighted average power is 1.
        norm_factor = torch.rsqrt(torch.sum(torch.pow(torch.abs(alphabet_t), 2) * probs))
        alphabet_norm = alphabet_t * norm_factor

        # FIX: torch.matmul requires both operands to share a dtype.
        # `onehot` is real Float while `alphabet_norm` is ComplexFloat,
        # which raised "RuntimeError: expected scalar type Float but found
        # ComplexFloat". Cast the one-hot selector to the alphabet's dtype;
        # this is exact (0/1 values) and keeps gradients flowing.
        symbols = torch.matmul(
            onehot.to(alphabet_norm.dtype),
            torch.transpose(input=alphabet_norm.reshape(1, -1), dim0=0, dim1=1)
        )

        # Transmit through the AWGN channel.
        y = AWGN_channel(symbols, sigma2)

        # Decode. NOTE(review): if `y` is complex, `dec` must accept complex
        # input (or you should feed torch.view_as_real(y) instead) — confirm
        # against the decoder's definition.
        ll = dec(y.reshape(-1, 1))

        # Loss: minimize -(H(X) - CE), equivalent to maximizing the
        # mutual information MI = H(X) - CE.
        loss = -(torch.sum(-probs * torch.log(probs)) - loss_fn(ll, indices.detach()))

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Periodic progress report (loss converted to bits via log 2).
        if report and j % 500 == 0:
            print(f'epoch {j}: Loss = {loss.detach().numpy() / np.log(2) :.4f}')

    # Plot the learned constellation (no gradients needed here).
    with torch.no_grad():
        alphabet_t = mapper(nn.functional.one_hot(torch.arange(M), M).float()).squeeze()
        probs = nn.functional.softmax(logits, -1)
        norm_factor = torch.rsqrt(torch.sum(torch.pow(torch.abs(alphabet_t), 2) * probs))
        alphabet_norm = alphabet_t * norm_factor
        plot_constellation(alphabet_norm, probs, SNR_dB)

    return loss
最新发布
08-19
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值