[CF592D]Super M

This post describes a tree-DP approach to the following problem: in a tree, find the minimum total distance needed to visit all marked nodes. It explains how the DFS order and the idea of the tree's diameter are used to obtain the answer.

592D: Super M

Problem Summary

Given a tree with n nodes, m of which are marked.
The distance between adjacent nodes in the tree is 1.
Starting from any node of your choice, you must visit all the marked nodes; find the minimum total distance travelled.

Constraints

1 ≤ m ≤ n ≤ 123456

Approach

Tree DP.
The optimal route always follows a DFS order of the marked nodes.
Whenever the DFS reaches a marked node, every edge on the path back up toward the previous marked node is counted twice; in other words, each edge of the minimal subtree connecting all marked nodes contributes 2 to the answer.
Since we do not have to return to the starting point, we do further DFS passes to find the longest distance between two marked nodes and subtract it from the answer.
This step can be done with two DFS passes, just like computing the diameter of a tree.
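
For example, if the tree is the path 1–2–3–4–5 and nodes 1, 3, 5 are marked, the minimal subtree spanning the marked nodes is the whole path with 4 edges, so a closed tour costs 2·4 = 8; the longest distance between two marked nodes is dist(1, 5) = 4, so the best open route costs 8 − 4 = 4 (start at node 1 or node 5 and walk straight across).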

Code

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
struct Node{
    int s,t,next;
}e[250010];
int head[130010],cnt;
// add an undirected edge (s,t) to the adjacency list
void addedge(int s,int t)
{
    e[cnt].s=s;e[cnt].t=t;e[cnt].next=head[s];head[s]=cnt++;
    e[cnt].s=t;e[cnt].t=s;e[cnt].next=head[t];head[t]=cnt++;
}
bool pd[130010],lable[130010]; // pd[x]: subtree of x contains a marked node; lable[x]: x is marked
int dis[130010];
int n,m,u,v,pos,ans,tmp;
// first DFS: every edge whose lower subtree contains a marked node
// must be walked twice, so it contributes 2 to the answer
void dfs(int node,int lastfa)
{
    pd[node]=lable[node];
    for (int i=head[node];i!=-1;i=e[i].next)
        if (e[i].t!=lastfa)
        {
            dfs(e[i].t,node);
            if (pd[e[i].t])
                ans+=2;
            pd[node]|=pd[e[i].t];
        }
}
// second DFS: record the distance from the chosen root to every node
void dfs2(int node,int lastfa,int sum)
{
    dis[node]=sum;
    for (int i=head[node];i!=-1;i=e[i].next)
        if (e[i].t!=lastfa)
            dfs2(e[i].t,node,sum+1);
}
int main()
{
    scanf("%d%d",&n,&m);
    memset(head,0xff,sizeof(head));
    cnt=0;
    for (int i=1;i<n;i++)
    {
        scanf("%d%d",&u,&v);
        addedge(u,v);
    }
    for (int i=1;i<=m;i++)
    {
        scanf("%d",&u);
        lable[u]=1;
    }
    dfs(u,u);                   // root the tree at a marked node (the last one read)
    dis[0]=-1;                  // sentinel so any marked node beats index 0 below
    dfs2(u,u,0);
    for (int i=1;i<=n;i++)      // farthest marked node from u: one endpoint of the longest path
        if (lable[i]&&dis[i]>dis[pos])
            pos=i;
    tmp=pos;
    dfs2(pos,pos,0);
    pos=0;
    for (int i=1;i<=n;i++)      // farthest marked node from that endpoint: the other endpoint
        if (lable[i]&&dis[i]>dis[pos])
            pos=i;
    ans=ans-dis[pos];           // the longest marked-to-marked path is walked only once
    pos=min(pos,tmp);           // report the smaller-indexed endpoint as the starting node
    printf("%d\n%d\n",pos,ans);
    return 0;
}
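
Each DFS visits every node once, so the whole solution runs in O(n) time and O(n) memory.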