E. Tree Shuffling

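This post walks through a solution to Problem E of a Codeforces round. Each node of a rooted tree has a cost and a binary digit. By choosing nodes and shuffling digits within subtrees, every node must end up with its target digit at minimum total cost.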

Link: https://codeforces.ml/contest/1363/problem/E

Ashish has a tree consisting of $n$ nodes numbered $1$ to $n$, rooted at node $1$. The $i$-th node in the tree has a cost $a_i$, and binary digit $b_i$ is written in it. He wants to have binary digit $c_i$ written in the $i$-th node in the end.

To achieve this, he can perform the following operation any number of times:

  • Select any $k$ nodes from the subtree of any node $u$, and shuffle the digits in these nodes as he wishes, incurring a cost of $k \cdot a_u$. Here, he can choose $k$ ranging from $1$ to the size of the subtree of $u$.
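Because node $v$ belongs to the subtree of every node on the path from the root to $v$, any of those nodes can pay for shuffling $v$. The cheapest usable cost for $v$ is therefore the minimum $a$ on that path, which is what the first `dfs` in the code below computes in place. The following is a minimal iterative sketch of that step. It assumes a 1-indexed cost array and an undirected edge list, and the helper name `effectiveCosts` is hypothetical.

```cpp
#include <algorithm>
#include <queue>
#include <utility>
#include <vector>

using ll = long long;

// Hypothetical helper: eff[v] = min(a[u]) over every u on the path root(1) -> v.
// a is 1-indexed (a[0] unused); edges are undirected pairs (u, v).
std::vector<ll> effectiveCosts(int n, const std::vector<ll>& a,
                               const std::vector<std::pair<int, int>>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    std::vector<ll> eff(a);  // start from each node's own cost
    std::vector<bool> seen(n + 1, false);
    std::queue<int> q;
    q.push(1);
    seen[1] = true;
    while (!q.empty()) {  // BFS reaches every parent before its children
        int u = q.front();
        q.pop();
        for (int v : adj[u]) {
            if (seen[v]) continue;
            seen[v] = true;
            eff[v] = std::min(eff[v], eff[u]);  // inherit the cheaper ancestor cost
            q.push(v);
        }
    }
    return eff;
}
```

On sample 1 below, every effective cost becomes $1$ because the root is the cheapest node. On sample 2 the costs strictly decrease going down, so they stay unchanged.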

He wants to perform the operations in such a way that every node finally has the digit corresponding to its target.

Help him find the minimum total cost he needs to spend so that after all the operations, every node $u$ has digit $c_u$ written in it, or determine that it is impossible.

Input

The first line contains a single integer $n$ ($1 \le n \le 2 \cdot 10^5$) denoting the number of nodes in the tree.

The $i$-th line of the next $n$ lines contains 3 space-separated integers $a_i$, $b_i$, $c_i$ ($1 \le a_i \le 10^9$, $0 \le b_i, c_i \le 1$): the cost of the $i$-th node, its initial digit and its goal digit.

Each of the next $n-1$ lines contains two integers $u$, $v$ ($1 \le u, v \le n$, $u \ne v$), meaning that there is an edge between nodes $u$ and $v$ in the tree.

Output

Print the minimum total cost to make every node reach its target digit, or $-1$ if it is impossible.

Examples

input

5
1 0 1
20 1 0
300 0 1
4000 0 0
50000 1 0
1 2
2 3
2 4
1 5

output

4

input

5
10000 0 1
2000 1 0
300 0 1
40 0 0
1 1 0
1 2
2 3
2 4
1 5

output

24000

input

2
109 0 1
205 0 1
1 2

output

-1

Note

Samples 1 and 2 use the same tree: node 1 is the root with children 2 and 5, and node 2 has children 3 and 4.

In sample 1, we can choose node 1 and $k = 4$ for a cost of $4 \cdot 1 = 4$, select nodes $1, 2, 3, 5$, shuffle their digits and get the desired digits in every node.

In sample 2, we can choose node 1 and $k = 2$ for a cost of $10000 \cdot 2$, select nodes $1, 5$ and exchange their digits. Similarly, we can choose node 2 and $k = 2$ for a cost of $2000 \cdot 2$, select nodes $2, 3$ and exchange their digits to get the desired digits in every node.
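Both sample answers follow from one per-node accounting, the same term that the `ans` update in `dfss` (code below) accumulates. Let $a'_u$ be the minimum cost on the path from the root to $u$. Let $\text{cnt}_{0\to1}(u)$ and $\text{cnt}_{1\to0}(u)$ count the mismatched nodes in the subtree of $u$ (initial digit 0 with target 1, and the reverse) that were not already paired inside a child's subtree. Swapping one node of each kind at $u$ shuffles 2 nodes, so each pair costs $2a'_u$:

$$
\text{cost}(u) = 2\, a'_u \cdot \min\bigl(\text{cnt}_{0\to1}(u),\ \text{cnt}_{1\to0}(u)\bigr)
$$

In sample 2, node 2 ($a'_2 = 2000$) pairs nodes 2 and 3, and node 1 ($a'_1 = 10000$) pairs nodes 1 and 5:

$$
2 \cdot 2000 \cdot 1 + 2 \cdot 10000 \cdot 1 = 4000 + 20000 = 24000
$$

In sample 1, every $a'_u = 1$, and the same two pairs give $2 \cdot 1 + 2 \cdot 1 = 4$. This equals the single $k = 4$ shuffle described above.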

In sample 3, it is impossible to get the desired digits, because there is no node with digit $1$ initially.
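Sample 3 illustrates the only way the answer can be $-1$. A shuffle only permutes digits, so the number of 1s never changes, and the initial and target digits must contain equally many 1s. That condition is also sufficient, because the root's subtree is the whole tree: one operation at the root over all $n$ nodes can realize any arrangement with the right count. A minimal sketch of the check follows. The function name `feasible` and the plain (not 1-indexed) vectors are assumptions for illustration.

```cpp
#include <vector>

// Hypothetical helper: shuffles preserve the number of 1s, and the root can
// shuffle every node, so a solution exists iff b and c contain equally many 1s.
bool feasible(const std::vector<int>& b, const std::vector<int>& c) {
    long long ones_b = 0, ones_c = 0;
    for (int x : b) ones_b += x;
    for (int x : c) ones_c += x;
    return ones_b == ones_c;
}
```

The `s1 != s2` test in `main` below performs this same check before any traversal.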

Code:

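#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int MAXN = 200005;
ll a[MAXN];            // a[i]: cost of node i, later the minimum cost on the root->i path
int b[MAXN], c[MAXN];  // initial digit and target digit of each node
ll dp[MAXN][2];        // dp[i][d]: unpaired mismatched nodes in subtree(i) currently holding digit d
vector<int> p[MAXN];   // adjacency list
ll ans = 0;

// Any ancestor of i can shuffle i, so replace a[i] with the minimum cost
// on the path from the root to i.
void dfs(int i, int fa)
{
	for (int to : p[i])
	{
		if (to == fa) continue;
		a[to] = min(a[to], a[i]);
		dfs(to, i);
	}
}

// Post-order: gather mismatched nodes from the children, then pair as many
// 0->1 nodes with 1->0 nodes as possible here. a[i] <= a[parent] after dfs,
// so pairing at the deepest possible node is never more expensive.
void dfss(int i, int fa)
{
	if (b[i] != c[i]) dp[i][b[i]]++;
	for (int to : p[i])
	{
		if (to == fa) continue;
		dfss(to, i);
		dp[i][0] += dp[to][0];
		dp[i][1] += dp[to][1];
	}
	ll k = min(dp[i][0], dp[i][1]);  // number of pairs swapped at node i
	ans += k * a[i] * 2;             // each pair shuffles 2 nodes at cost a[i] each
	dp[i][0] -= k;
	dp[i][1] -= k;
}

int main()
{
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	int n;
	cin >> n;
	ll s1 = 0, s2 = 0;
	for (int i = 1; i <= n; i++)
	{
		cin >> a[i] >> b[i] >> c[i];
		s1 += b[i];
		s2 += c[i];
	}
	for (int i = 1; i < n; i++)
	{
		int u, v;
		cin >> u >> v;
		p[u].push_back(v);
		p[v].push_back(u);
	}
	// Shuffling never changes how many 1s there are, so the counts must match.
	if (s1 != s2)
	{
		cout << -1 << endl;
		return 0;
	}
	dfs(1, 0);
	dfss(1, 0);
	cout << ans << endl;
	return 0;
}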
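The recursive `dfs` and `dfss` above recurse about $2 \cdot 10^5$ levels deep when the tree is a path. That is usually fine on Codeforces, but it may overflow a default-sized stack on other judges or when testing locally. The following is a minimal sketch of an iterative variant of the same greedy. The function name `solve`, its signature, and the edge-list input are assumptions for illustration, not part of the original solution. A BFS order (parents before children) replaces the first DFS, and walking that order in reverse replaces the post-order of the second.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

using ll = long long;

// Hypothetical iterative version of the greedy used by dfs/dfss above.
// Nodes are 1-indexed: a, b, c have size n + 1 and index 0 is unused.
// Returns -1 if the target digits are unreachable.
ll solve(int n, std::vector<ll> a, const std::vector<int>& b,
         const std::vector<int>& c,
         const std::vector<std::pair<int, int>>& edges) {
    ll ones_b = 0, ones_c = 0;
    for (int i = 1; i <= n; i++) {
        ones_b += b[i];
        ones_c += c[i];
    }
    if (ones_b != ones_c) return -1;  // shuffling never changes the number of 1s

    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges) {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }

    // BFS from the root: every parent appears in `order` before its children.
    std::vector<int> order = {1};
    std::vector<int> parent(n + 1, 0);
    for (int idx = 0; idx < (int)order.size(); idx++) {
        int u = order[idx];
        for (int v : adj[u]) {
            if (v == parent[u]) continue;
            parent[v] = u;
            a[v] = std::min(a[v], a[u]);  // effective cost: min on the root->v path
            order.push_back(v);
        }
    }

    // need0/need1[v]: unpaired mismatched nodes in subtree(v) holding digit 0 / 1.
    std::vector<ll> need0(n + 1, 0), need1(n + 1, 0);
    ll ans = 0;
    // Reverse BFS order handles children before parents, like a post-order DFS.
    for (int idx = (int)order.size() - 1; idx >= 0; idx--) {
        int u = order[idx];
        if (b[u] != c[u]) {
            if (b[u] == 0) need0[u]++;
            else need1[u]++;
        }
        ll k = std::min(need0[u], need1[u]);  // pairs swapped at u
        ans += 2 * k * a[u];                  // each pair shuffles 2 nodes at cost a[u]
        need0[u] -= k;
        need1[u] -= k;
        if (parent[u] != 0) {  // hand the leftovers up to the parent
            need0[parent[u]] += need0[u];
            need1[parent[u]] += need1[u];
        }
    }
    return ans;
}
```

The sketch should return 4, 24000 and -1 on the three samples above. It trades the call stack for a few extra arrays of size $n + 1$, while keeping the $O(n)$ running time of the recursive version.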
    

 
