emmm ,一开始,我想:对于 一个节点,跟它相连 的 距离为2 的点, 要么是他父亲的父亲 ,要么是他父亲的(除他自己和 他父亲 ) 儿子,写了一个 以为是 O(n)的程序 却 70 分,原来是个常数小的 O()啊,GG~~
70 分代码 :
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define LL long long
using namespace std;
const int mod = 10007 ;
inline int wread(){
char c(getchar ());int wans(0);
while (c<'0' || c>'9') c=getchar ();
while (c>='0' && c<='9'){wans=wans*10+c-'0'; c=getchar ();}
return wans;
}
int K,hed[200002];
struct node{int u,v,nxt;}e[400002];
void ad (int u,int v){e[++K]=(node){u,v,hed[u]};hed[u]=K;}
int n;
int f[200002];
int w[200002];
LL ans1;
int ans2;
void dfs (int x,int fa){
f[x]=fa;
ans1=max(ans1,(LL)w[f[fa]]*(LL)w[x]);
ans2 = ( ans2 + (LL)w[f[fa]]*(LL)w[x] % mod ) % mod ;
for (int i(hed[fa]);i;i=e[i].nxt){
int v(e[i].v);
if (v==x) continue;
ans1=max(ans1,(LL)w[v]*(LL)w[x]);
ans2 = ( ans2 + (LL)w[v]*(LL)w[x] % mod ) % mod ;
}
for (int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
if (v==fa) continue;
dfs(v,x);
}
}
int main(){
n=wread();
for (int i(1);i<n;++i){
int u(wread()),v(wread());
ad(u,v); ad(v,u);
}
for (int i(1);i<=n;++i)
w[i]=wread();
dfs(1,0);
printf("%lld %d\n",ans1,ans2);
return 0;
}
后来 看了看题解,没懂~
然后就想 优化 ???怎么优化!!!
我就想了个 dp 维护 最大值 ,开了个变量 统计 总和
十分 BT
(后有简便方法)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#define LL long long
using namespace std;
const int mod = 10007 ;
inline int wread(){
char c(getchar ());int wans(0);
while (c<'0' || c>'9') c=getchar ();
while (c>='0' && c<='9'){wans=wans*10+c-'0'; c=getchar ();}
return wans;
}
int K,hed[200002];
struct node{int u,v,nxt;}e[400002];
void ad (int u,int v){e[++K]=(node){u,v,hed[u]};hed[u]=K;}
int n;
int w[200002];
LL ans1;
int ans2;
int dp1[200002][3];//从前面开始 到 第 i 位
int dp2[200002][3];//从后面开始 到 第 j 位
vector <int> C;
void dfs (int x,int fa){
C.clear();
LL all_(0);
int num(0);
dp1[0][0]=dp1[0][1]=dp2[0][0]=dp2[0][1]=0;
for (int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
all_+=w[v];
dp1[++num][0] = max (dp1[num-1][0],dp1[num-1][1]) ;
dp1[num][1] = max (max (dp1[num-1][0],dp1[num-1][1]) , w[v] );
C.push_back(v);
}
for (int i(C.size());i>1;--i){
int v(C[i-1]);
dp2[i][0] = max (dp2[i-1][0],dp2[i-1][1]) ;
dp2[i][1] = max (max (dp2[i-1][0],dp2[i-1][1]) , w[v] );
}
for (int i(0);i<C.size();++i){
ans1=max( ans1, (LL)w[ C[i] ] * (LL)max(dp1[i+1][0], dp2[C.size()-i][0]));
ans2=(ans2 + w[ C[i] ] * (LL) (all_ - w[ C[i] ]) % mod ) % mod ;
}
for (int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
if (v==fa) continue;
dfs(v,x);
}
}
int main(){
n=wread();
for (int i(1);i<n;++i){
int u(wread()),v(wread());
ad(u,v); ad(v,u);
}
for (int i(1);i<=n;++i)
w[i]=wread();
dfs(1,0);
printf("%lld %d\n",ans1,ans2);
return 0;
}
后来 , 题解告诉我:
最大值 可以直接 求啊!!!
遍历每个点 的时候,对他的 周围的点(包括父亲) 维护 最大值 和 次大值 ,最后相乘 ,过程中不断取 Max
多么简便啊~~~~
贪心的妙用吧
(可惜我想不到啊)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <vector>
#define LL long long
using namespace std;
const int mod = 10007 ;
inline int wread(){
char c(getchar ());int wans(0);
while (c<'0' || c>'9') c=getchar ();
while (c>='0' && c<='9'){wans=wans*10+c-'0'; c=getchar ();}
return wans;
}
int K,hed[200002];
struct node{int u,v,nxt;}e[400002];
void ad (int u,int v){e[++K]=(node){u,v,hed[u]};hed[u]=K;}
int n;
int w[200002];
LL ans1;
int ans2;
void dfs (int x,int fa){
LL all_(0);
int maxn1(0),maxn2(0);//最大值 次大值
for (int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
int zhi(w[v]);
if (zhi>maxn1) maxn1=zhi;
else if (zhi>maxn2) maxn2=zhi;
}
ans1=max(ans1,(LL)maxn1*(LL)maxn2);
for (int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
all_+=w[v];
}//统计总和
for(int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
ans2=(ans2 + (LL) w[v] * (LL) (all_ - w[ v ]) % mod ) % mod ;//计算:自己推
}
for (int i(hed[x]);i;i=e[i].nxt){
int v(e[i].v);
if (v==fa) continue;
dfs(v,x);
}
}
int main(){
n=wread();
for (int i(1);i<n;++i){
int u(wread()),v(wread());
ad(u,v); ad(v,u);
}
for (int i(1);i<=n;++i)
w[i]=wread();
dfs(1,0);
printf("%lld %d\n",ans1,ans2);
return 0;
}