import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import tikzplotlib

# 8-ASK alphabet: equally spaced amplitude levels {-7, -5, ..., 5, 7}
alphabet = np.arange(-(8 - 1), 8, 2)
# alphabet = np.array([-10, -6, -3, -1, 1, 3, 6, 10])  # <-- Non-uniform alphabet (alternative)
alphabet = alphabet / np.sqrt(np.mean(alphabet**2))  # normalize to unit average power
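Quick sanity check (my addition, not in the original): after normalization the alphabet should have unit average power.

assert np.isclose(np.mean(alphabet**2), 1.0)  # E[x^2] = 1 after normalization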
class Receiver(nn.Module):
    """Demapper: maps a received sample y to M logits, one per symbol."""
    def __init__(self, M):
        super().__init__()
        self.lin1 = nn.Linear(1, M)

    def forward(self, y):
        return self.lin1(y)
class Encoder(nn.Module):
    """Learns the symbol probabilities P(X): a constant input of 1 is mapped to M softmax outputs."""
    def __init__(self, M):
        super().__init__()
        self.lin1 = nn.Linear(1, M, bias=False)
        nn.init.constant_(self.lin1.weight, 1 / M)  # start from the uniform distribution
        self.out = nn.Softmax(dim=0)

    def forward(self, y):
        return self.out(self.lin1(y))
class Mapper(nn.Module):
    """Maps a one-hot symbol index to a one-dimensional constellation point."""
    def __init__(self, M):
        super().__init__()
        self.lin1 = nn.Linear(M, 1)
        self.lin1.weight = nn.Parameter(torch.Tensor([[i for i in alphabet]]))  # set weights equal to alphabet

    def forward(self, y):
        return self.lin1(y)
def sampler(prob, n):
    """Deterministically generates ~n symbol indices with relative frequencies prob, then shuffles them."""
    samples = torch.empty(0, dtype=torch.long)
    for idx, p in enumerate(prob):
        occurrences = int(torch.round(n * p))
        samples = torch.cat((samples, torch.ones(occurrences, dtype=torch.long) * idx))
    indexes = torch.randperm(samples.shape[0])
    return samples[indexes]
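A small usage sketch of the sampler (illustrative names, not from the original): with a uniform distribution over 4 symbols and n = 8, each index should appear exactly twice.

p_demo = torch.full((4,), 0.25)               # hypothetical uniform distribution
idx_demo = sampler(p_demo, 8)
print(torch.bincount(idx_demo, minlength=4))  # expected: tensor([2, 2, 2, 2])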
def gradient_correction_factor(app, idx, prob, M):
    """Correction factor for the gradient of the loss w.r.t. the symbol probabilities."""
    (n_samples, M) = app.shape
    cf = torch.zeros(M)
    for j in range(M):
        tmp = app[:, j]
        # tmp[idx == j] selects the posteriors of those (x, y) pairs which belong to the current symbol j
        cf[j] = torch.sum(torch.log(tmp[idx == j])) / (n_samples * prob[j])
    return cf
def AWGN_channel(x, sigma2):
    # Linear AWGN channel: y = x + n, with n ~ N(0, sigma2)
    noise_t = np.sqrt(sigma2) * torch.randn(x.shape)
    return x + noise_t

def tanh_channel(x, sigma2):
    # Nonlinear channel: y = tanh(x) + n, with n ~ N(0, sigma2)
    noise_t = np.sqrt(sigma2) * torch.randn(x.shape)
    return torch.tanh(x) + noise_t
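As a quick check (assumed setup, my addition): with unit-power input and sigma2 = 1/SNR, the empirical SNR at the AWGN channel output should match the target.

x_demo = torch.randn(100_000)                 # unit-power test input
y_demo = AWGN_channel(x_demo, 10**(-7/10))    # sigma2 corresponding to 7 dB SNR
snr_emp = (x_demo.var() / (y_demo - x_demo).var()).item()
print(10 * np.log10(snr_emp))                 # should be close to 7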
M = 8
n = 10_000
SNR_dB = 7
SNR = 10**(SNR_dB/10)
sigma2 = 1/SNR
nepochs = 4000
dec = Receiver(M)
enc = Encoder(M)
mapper = Mapper(M)
loss_fn = nn.CrossEntropyLoss()
alphabet = np.arange(-(M-1),M,2)
alphabet = np.array([-10, -6, -3, -1, 1, 3, 6, 10])  # <-- Non-uniform alphabet
alphabet = alphabet / np.sqrt(np.mean(alphabet**2))
alphabet_t = torch.tensor(alphabet).float()
# Print the initial constellation points produced by the mapper
for i in range(0, M):
    i_onehot = nn.functional.one_hot(torch.tensor(i), M).float()
    print(mapper(i_onehot))
enc.lin1._parameters['weight'].shape
lr = 0.1
# Joint optimization of encoder, decoder and mapper (geometric + probabilistic shaping):
opt = optim.Adam(list(enc.parameters()) + list(dec.parameters()) + list(mapper.parameters()), lr=lr)
# Overridden: optimize only encoder and decoder, i.e. the constellation (mapper) stays fixed
opt = optim.Adam(list(enc.parameters()) + list(dec.parameters()), lr=lr)
torch.arange(1, M+1).squeeze()
mapper(nn.functional.one_hot(torch.arange(M), M).float()).squeeze().shape
for j in range(nepochs):
    # logits = enc(torch.tensor([1], dtype=torch.float))
    probs = enc(torch.tensor([1], dtype=torch.float))
    # probs = nn.functional.softmax(logits, -1)
    probs.retain_grad()

    # Sample indexes
    indices = sampler(probs, n)
    indices = indices.type(torch.LongTensor)

    # Modulation
    alphabet_t = mapper(nn.functional.one_hot(torch.arange(M), M).float()).squeeze()
    norm_factor = torch.rsqrt(torch.sum(torch.pow(torch.abs(alphabet_t), 2) * probs))
    alphabet_norm = alphabet_t * norm_factor
    onehot = nn.functional.one_hot(indices, M).float()
    symbols = torch.matmul(onehot, torch.transpose(input=alphabet_norm.reshape(1, -1), dim0=0, dim1=1))

    # Channel
    # y = AWGN_channel(symbols, sigma2)
    y = tanh_channel(symbols, sigma2)

    # Demodulator
    ll = dec(y.reshape(-1, 1).float())
    app = nn.functional.softmax(ll, 1)  # Q(X|Y)

    # Loss
    loss = -(torch.sum(-probs * torch.log(probs)) - loss_fn(ll, indices))  # -(H(X) - CE(P,Q)); gradient descent minimizes, therefore we minimize the opposite to maximize the MI in the end
    opt.zero_grad()
    loss.backward(retain_graph=True)

    # Correction factor
    cf = -(gradient_correction_factor(app, indices, probs, M) - torch.log(probs)).detach()
    # if j % 500 == 0:
    #     print('missing factors: ', cf.detach().numpy())
    #     print('current grad: ', probs.grad.detach().numpy())
    probs.grad = cf
    probs.backward(torch.tensor([1., 1., 1., 1., 1., 1., 1., 1.]))
    # probs.grad += cf.detach()
    # enc.lin1._parameters['weight'].grad += cf.reshape(-1, 1).detach()
    opt.step()

    # Printout and visualization
    if j % 500 == 0:
        print(f'epoch {j}: Loss = {loss.detach().numpy() / np.log(2) :.4f}')
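To inspect the learned input distribution and its entropy after training, one can run a snippet like the following (my addition, not part of the original loop):

with torch.no_grad():
    p_learned = enc(torch.tensor([1], dtype=torch.float))
    H_bits = -(p_learned * torch.log2(p_learned)).sum()
    print('learned P(X):', p_learned.numpy())
    print('H(X) = %.3f bits (uniform 8-ASK gives 3 bits)' % H_bits.item())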
enc.lin1._parameters['weight']
probs
plt.rcParams['figure.figsize'] = [4, 4]
plt.hist(symbols.detach().numpy(), bins=100)
tikzplotlib.save("/home/ddeandres/Projects/internship_pcs/documentation/figs/aref_gcs_{}dB.tex".format(SNR_dB))
plt.savefig('/home/ddeandres/Projects/internship_pcs/documentation/figs/aref_gcs_{}dB.pgf'.format(SNR_dB))  # save before plt.show(), otherwise the saved figure is empty
plt.show()
Scatterplot
# Product distribution on a square 2-D constellation built from two independent 1-D symbols
pp = (probs.reshape(-1, 1) * probs.reshape(1, -1)).reshape(-1, 1).detach().numpy()
alph = alphabet_norm.detach().numpy()
a = []
for c in np.flip(alph):
    for d in alph:
        a.append(d + 1j * c)
plt.scatter(np.real(a), np.imag(a), pp * 2000)  # marker size proportional to symbol probability
plt.show()
def AWGN_channel_np(x, sigma2):
    # NumPy AWGN channel: y = x + n, with n ~ N(0, sigma2)
    noise = np.sqrt(sigma2) * np.random.randn(x.size)
    return x + noise

def AWGNdemapper(y, const, varN):
    # A posteriori probabilities Q(x|y) for a uniform prior and Gaussian likelihood with variance varN
    apps = np.exp(-np.abs(np.transpose([y]) - const)**2 / (2 * varN))
    return apps / np.transpose([np.sum(apps, 1)])
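Brief usage sketch of the demapper (variable names are illustrative, not from the original): at high SNR the posterior should concentrate on the transmitted symbol.

const_demo = np.array([-1.0, 1.0])              # hypothetical 2-point constellation
y_demo = AWGN_channel_np(np.array([1.0]), 0.01)
print(AWGNdemapper(y_demo, const_demo, 0.01))   # almost all mass on the +1 entry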
def xesmd(apps, idx):
    """
    Estimates the symbol-wise equivocation from reference symbol indices and a posteriori probabilities.
    """
    eq = -np.log(np.take_along_axis(apps, idx[:, None], axis=1) / np.transpose([np.sum(apps, 1)]))
    eq[eq == np.inf] = 1000  # clip infinite values
    return np.mean(eq)
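Sanity check (illustrative, my addition): with perfect posteriors the estimated equivocation is zero, so the mutual information estimate 2*(3 - xe/np.log(2)) used below reaches its maximum of 6 bits per complex channel use.

apps_demo = np.eye(4)                 # posterior puts all mass on the true symbol
idx_demo = np.arange(4)
print(xesmd(apps_demo, idx_demo))     # expected: 0.0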
n = 100_000
SNR_dBs = np.arange(5,22)
M = 8
alphabet = np.arange(-7,8,2)
alphabet = alphabet / np.sqrt(np.mean(alphabet**2))
indices = np.random.choice(np.arange(M), n)
symbols2 = alphabet[indices]
mi_64 = []
for snrdB in SNR_dBs:
    sigma2 = 1 / (10**(snrdB / 10))
    y = AWGN_channel_np(symbols2, sigma2)
    apps = AWGNdemapper(y, alphabet, sigma2)
    xe = xesmd(apps, indices)
    mi_64.append(2 * (3 - xe / np.log(2)))  # I(X;Y) = H(X) - H(X|Y), doubled for two real dimensions (64-QAM)
print((-2*loss.detach()/np.log(2)).detach().numpy())
Plot
plt.rcParams['figure.figsize'] = [8, 6]
plt.plot(SNR_dBs, mi_64, label='64QAM')
plt.plot(SNR_dBs, np.log2(1 + 10**(SNR_dBs/10)), color='black', label='Capacity')
plt.plot(SNR_dB, -2 * loss.detach() / np.log(2), color='red', marker='o', markersize=3)
xy = (SNR_dB, (-2 * loss.detach() / np.log(2)).detach().numpy())
plt.annotate('(%s, %s)' % xy, xy=xy, textcoords='data')
plt.legend()
plt.grid()
SNR_dBs = np.arange(0,20)
plt.plot(SNR_dBs, np.log2(1 + 10**(SNR_dBs/10)), color='C0', label=r'$C(P/\sigma^2)$')
plt.plot(SNR_dBs, np.log2(1 + 10**(SNR_dBs/10)) - 0.5 * np.log2((np.pi * np.e) / 6), linestyle='dashed', color='C1', label=r'$C(P/\sigma^2) - \frac{1}{2}\log_2\frac{\pi e}{6}$')
plt.grid()
plt.ylabel('bits per channel use')
plt.xlabel('SNR in dB')
plt.xlim([0, 20])
plt.ylim([0, 5])
plt.title('AWGN channel capacity gap')
plt.legend()
def one_hot(a, M):
    onehot = np.zeros(M)
    onehot[a] = 1
    return onehot
Data for the plots
a_plot = np.arange(M)
onehot_plot = np.array([one_hot(a_plot[i], M) for i in range(M)])
learned_x = mapper(torch.tensor(onehot_plot).float())
yy_plot = torch.tanh(learned_x)
Plot
plt.scatter(np.real(learned_x.detach().numpy()), np.imag(learned_x.detach().numpy()))
plt.title('Learned constellation')
plt.grid()
plt.scatter(np.real(yy_plot.detach().numpy()), np.imag(yy_plot.detach().numpy()))
plt.title('Constellation after tanh')
plt.grid()