def train_encoder_at_dB(dB, report=False):
    """Train the encoder, decoder and mapper at a given channel SNR.

    Jointly learns the symbol probabilities (probabilistic shaping, via the
    encoder) and the constellation geometry (via the mapper) by maximizing
    an estimate of the mutual information over an AWGN channel.

    Args:
        dB: channel SNR in decibels; converted to the noise variance
            ``sigma2 = 1 / 10**(dB/10)`` assuming unit average signal power.
        report: if True, print the loss (in bits) every 500 epochs.

    Returns:
        The scalar loss tensor of the final epoch, i.e. the negative
        estimated mutual information -(H(X) - CE) in nats.

    NOTE(review): relies on module-level objects defined elsewhere in the
    notebook: enc, dec, mapper, M, n, lr, nepochs, loss_fn, AWGN_channel,
    plot_constellation, torch, nn, optim, np.
    """
    SNR_dB = dB
    SNR = 10 ** (SNR_dB / 10)
    sigma2 = 1 / SNR  # noise variance for unit average transmit power
    # Optimize encoder, decoder and mapper parameters together.
    opt = optim.Adam(
        list(enc.parameters()) + list(dec.parameters()) + list(mapper.parameters()),
        lr=lr
    )
    for j in range(nepochs):
        # Encoder outputs unnormalized log-probabilities over the M symbols.
        logits = enc(torch.tensor([1], dtype=torch.float))
        # Draw n one-hot symbol samples with straight-through Gumbel-softmax
        # so sampling stays differentiable.
        onehot = nn.functional.gumbel_softmax(logits.expand(n, -1), tau=10, hard=True)
        indices = torch.argmax(onehot, dim=1)
        # Modulation: the mapper turns each one-hot symbol into a
        # constellation point.
        alphabet_t = mapper(nn.functional.one_hot(torch.arange(M), M).float()).squeeze()
        probs = nn.functional.softmax(logits, -1)
        # Normalization factor so the constellation has unit average power
        # under the current symbol distribution.
        norm_factor = torch.rsqrt(torch.sum(torch.pow(torch.abs(alphabet_t), 2) * probs))
        alphabet_norm = alphabet_t * norm_factor
        # BUG FIX: the constellation is complex-valued while the one-hot
        # samples are float32; torch.matmul requires matching dtypes and
        # raised "expected scalar type Float but found ComplexFloat".
        # Cast the one-hot matrix to the constellation's dtype first.
        # (reshape(-1, 1) replaces the equivalent transpose-of-row-vector.)
        symbols = torch.matmul(
            onehot.to(alphabet_norm.dtype),
            alphabet_norm.reshape(-1, 1)
        )
        # Transmit over the AWGN channel.
        # NOTE(review): y is complex if the alphabet is; if dec has real
        # weights it may need torch.view_as_real(y) instead — confirm
        # against the decoder's definition.
        y = AWGN_channel(symbols, sigma2)
        # Decode: per-symbol logits from the noisy channel output.
        ll = dec(y.reshape(-1, 1))
        # Loss: minimize -(H(X) - CE), which maximizes the mutual
        # information MI = H(X) - CE.
        loss = -(torch.sum(-probs * torch.log(probs)) - loss_fn(ll, indices.detach()))
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Periodically report progress (loss converted from nats to bits).
        if report and j % 500 == 0:
            print(f'epoch {j}: Loss = {loss.detach().numpy() / np.log(2) :.4f}')
    # Plot the learned constellation; no gradients needed here.
    with torch.no_grad():
        alphabet_t = mapper(nn.functional.one_hot(torch.arange(M), M).float()).squeeze()
        probs = nn.functional.softmax(logits, -1)
        norm_factor = torch.rsqrt(torch.sum(torch.pow(torch.abs(alphabet_t), 2) * probs))
        alphabet_norm = alphabet_t * norm_factor
        plot_constellation(alphabet_norm, probs, SNR_dB)
    return loss
运行之后显示以下错误信息:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_10992/2075127474.py in <module>
23
24 # 计算学习到的GeoPCS的互信息
---> 25 mi_gcs.append(-2 * (train_encoder_at_dB(snrdB, report=True)).detach().numpy().tolist() / np.log(2))
26
27 # 预定义的参考结果(Learnt PCS)
/tmp/ipykernel_10992/2871997838.py in train_encoder_at_dB(dB, report)
44 norm_factor = torch.rsqrt(torch.sum(torch.pow(torch.abs(alphabet_t), 2) * probs))
45 alphabet_norm = alphabet_t * norm_factor
---> 46 symbols = torch.matmul(
47 onehot,
48 torch.transpose(input=alphabet_norm.reshape(1, -1), dim0=0, dim1=1)
RuntimeError: expected scalar type Float but found ComplexFloat
原因分析:`mapper` 输出的星座点 `alphabet_norm` 是复数张量(ComplexFloat),而 `gumbel_softmax` 产生的 `onehot` 是实数张量(Float),`torch.matmul` 要求两个操作数 dtype 一致,因此在 `symbols = torch.matmul(...)` 处报错。解决办法:在相乘前把 one-hot 矩阵转换为星座点的 dtype,例如 `torch.matmul(onehot.to(alphabet_norm.dtype), alphabet_norm.reshape(-1, 1))`。