题目链接:
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=5807
题意:
给一棵n个节点的红黑树,每条边有距离。树上有m个点是红的,每个红点的cost是0,每个黑点的cost是离它最近的红点祖先到它的距离。
q个询问,每次给k个点,现在允许你将树上某个点变红,是这k个点的cost的最大值最小,输出这个值。(数据规模都是10w)
思路1:
把每个红点与它父亲断开,这样就有很多棵根节点红色其他节点黑色的树(红根树)。
每次询问k个点,我们只需要考虑它们所在的红根树当中,max(cost)最大和次大的两棵树。
现在要选某个点变红,肯定是选在最大的红根树里。将max(cost)最小化后,和次大的树比一下就行了。
问题转化成,在红根树里选一个点变红,使k'个点的max(cost)最小。那就把k'个点按cost从小到大排序,求k'个后缀LCA就行了。
(可惜比赛的时候一句话写错了,搞半天还以为思路错了)
复杂度:O(n*log(n))
#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int,ll> P; const int N=100005; int red[N]; int redfa[N]; vector<P>link[N]; ll cost[N]; int deep[N]; int sz[N],son[N],fa[N]; void dfs(int u,int _fa,int _red,ll _cost,int _deep){ fa[u]=_fa; deep[u]=_deep; cost[u]=_cost; redfa[u]=_red; sz[u]=1; son[u]=-1; for (P p:link[u]){ int v=p.first; ll w=p.second; if (v==_fa) continue; if (red[v]) dfs(v,u,v,0,_deep+1); else dfs(v,u,_red,_cost+w,_deep+1); sz[u]+=sz[v]; if (son[u]==-1||sz[v]>sz[son[u]]) son[u]=v; } } int top[N],id[N],T; void dfs2(int u,int _fa,int p){ top[u]=p; id[u]=++T; if (son[u]!=-1) dfs2(son[u],u,p); for (P p:link[u]){ int v=p.first; ll w=p.second; if (v==_fa||v==son[u]) continue; dfs2(v,u,v); } } int getlca(int x,int y){ if (x==0) return y; int f1=top[x],f2=top[y]; while(f1!=f2){ if (deep[f1]<deep[f2]) swap(f1,f2),swap(x,y); x=fa[f1],f1=top[x]; } if (deep[x]<deep[y]) swap(x,y); return y; } int stk[N]; int que[N]; ll f[N]; vector<int>VS[N]; int vis[N]; bool cmp(int i,int j){ return f[i]>f[j]; } bool cmp2(int i,int j){ return cost[i]<cost[j]; } int main(){ int t; scanf("%d",&t); while(t--){ int n,m,q; scanf("%d%d%d",&n,&m,&q); for (int i=1;i<=m;i++){ int x; scanf("%d",&x); red[x]=1; } for (int i=1;i<n;i++){ int x,y; ll w; scanf("%d%d%lld",&x,&y,&w); link[x].push_back({y,w}); link[y].push_back({x,w}); } dfs(1,0,1,0,1); dfs2(1,0,1); for (int T=1;T<=q;T++){ int num; scanf("%d",&num); int top=0; for (int i=1;i<=num;i++){ scanf("%d",&que[i]); int _red=redfa[que[i]]; VS[_red].emplace_back(que[i]); if (vis[_red]!=T) stk[++top]=_red,f[_red]=cost[que[i]]; else f[_red]=max(f[_red],cost[que[i]]); vis[_red]=T; } sort(stk+1,stk+top+1,cmp); int _red=stk[1]; sort(VS[_red].begin(),VS[_red].end(),cmp2); int tot=VS[_red].size(); ll ret = cost[VS[_red][tot - 1]]; int lca=VS[_red][tot - 1]; for (int i=tot-1;i>=0;i--){ lca=getlca(lca,VS[_red][i]); ll tmp; if (i!=0) tmp=cost[VS[_red][i-1]]; else tmp=0; ret=min(ret,max(tmp,f[_red]-cost[lca])); } if (top==1) printf("%lld\n",ret); else printf("%lld\n",max(ret,f[stk[2]])); for (int i=1;i<=top;i++) VS[stk[i]].clear(); } for (int i=1;i<=n;i++){ link[i].clear(); red[i]=0; vis[i]=0; } } }
思路2:
二分答案,不满足答案的LCA合并起来再看是否满足。(大家基本是这么做的)
复杂度:O(n*log(1e14)*log(n)) 但是树链剖分求LCA实在是太快了,也能过
求LCA的地方还可以转RMQ用ST表做到O(1),就是写起来比较麻烦
(试了下倍增求LCA,T掉了,果然树剖大法好!)
#include<bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int,ll> P; const int N=100005; int red[N]; int redfa[N]; vector<P>link[N]; ll cost[N]; int deep[N]; int que[N]; int sz[N],son[N],fa[N]; void dfs(int u,int _fa,int _red,ll _cost,int _deep){ fa[u]=_fa; deep[u]=_deep; cost[u]=_cost; redfa[u]=_red; sz[u]=1; son[u]=-1; for (P p:link[u]){ int v=p.first; ll w=p.second; if (v==_fa) continue; if (red[v]) dfs(v,u,v,0,_deep+1); else dfs(v,u,_red,_cost+w,_deep+1); sz[u]+=sz[v]; if (son[u]==-1||sz[v]>sz[son[u]]) son[u]=v; } } int top[N],id[N],T; void dfs2(int u,int _fa,int p){ top[u]=p; id[u]=++T; if (son[u]!=-1) dfs2(son[u],u,p); for (P p:link[u]){ int v=p.first; ll w=p.second; if (v==_fa||v==son[u]) continue; dfs2(v,u,v); } } int getlca(int x,int y){ if (x==0) return y; int f1=top[x],f2=top[y]; while(f1!=f2){ if (deep[f1]<deep[f2]) swap(f1,f2),swap(x,y); x=fa[f1],f1=top[x]; } if (deep[x]<deep[y]) swap(x,y); return y; } bool check(ll len,int num){ int lca=0; for (int i=1;i<=num;i++) if (cost[que[i]]>len) lca=getlca(lca,que[i]); for (int i=1;i<=num;i++) if (cost[que[i]]>len){ if (deep[redfa[que[i]]]>=deep[lca]) return 0; if (cost[que[i]]-cost[lca]>len) return 0; } return 1; } int main(){ int t; scanf("%d",&t); while(t--){ int n,m,q; scanf("%d%d%d",&n,&m,&q); for (int i=1;i<=m;i++){ int x; scanf("%d",&x); red[x]=1; } for (int i=1;i<n;i++){ int x,y; ll w; scanf("%d%d%lld",&x,&y,&w); link[x].push_back({y,w}); link[y].push_back({x,w}); } dfs(1,0,1,0,1); dfs2(1,0,1); for (int T=1;T<=q;T++){ int num; scanf("%d",&num); for (int i=1;i<=num;i++) scanf("%d",&que[i]); ll l=0,r=1e14; while(l<r){ ll mid=(l+r)/2; if (!check(mid,num)) l=mid+1; else r=mid; } printf("%lld\n",l); } for (int i=1;i<=n;i++){ link[i].clear(); red[i]=0; deep[i]=0; } } }