题目链接:
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=5807
题意:
给一棵n个节点的红黑树,每条边有距离。树上有m个点是红的,每个红点的cost是0,每个黑点的cost是离它最近的红点祖先到它的距离。
q个询问,每次给k个点,现在允许你将树上某个点变红,是这k个点的cost的最大值最小,输出这个值。(数据规模都是10w)
思路1:
把每个红点与它父亲断开,这样就有很多棵根节点红色其他节点黑色的树(红根树)。
每次询问k个点,我们只需要考虑它们所在的红根树当中,max(cost)最大和次大的两棵树。
现在要选某个点变红,肯定是选在最大的红根树里。将max(cost)最小化后,和次大的树比一下就行了。
问题转化成,在红根树里选一个点变红,使k'个点的max(cost)最小。那就把k'个点按cost从小到大排序,求k'个后缀LCA就行了。
(可惜比赛的时候一句话写错了,搞半天还以为思路错了)
复杂度:O(n*log(n))
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,ll> P;
const int N=100005;
int red[N];
int redfa[N];
vector<P>link[N];
ll cost[N];
int deep[N];
int sz[N],son[N],fa[N];
void dfs(int u,int _fa,int _red,ll _cost,int _deep){
fa[u]=_fa;
deep[u]=_deep;
cost[u]=_cost;
redfa[u]=_red;
sz[u]=1;
son[u]=-1;
for (P p:link[u]){
int v=p.first; ll w=p.second;
if (v==_fa) continue;
if (red[v]) dfs(v,u,v,0,_deep+1);
else dfs(v,u,_red,_cost+w,_deep+1);
sz[u]+=sz[v];
if (son[u]==-1||sz[v]>sz[son[u]]) son[u]=v;
}
}
int top[N],id[N],T;
void dfs2(int u,int _fa,int p){
top[u]=p;
id[u]=++T;
if (son[u]!=-1) dfs2(son[u],u,p);
for (P p:link[u]){
int v=p.first; ll w=p.second;
if (v==_fa||v==son[u]) continue;
dfs2(v,u,v);
}
}
int getlca(int x,int y){
if (x==0) return y;
int f1=top[x],f2=top[y];
while(f1!=f2){
if (deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
x=fa[f1],f1=top[x];
}
if (deep[x]<deep[y]) swap(x,y);
return y;
}
int stk[N];
int que[N];
ll f[N];
vector<int>VS[N];
int vis[N];
bool cmp(int i,int j){
return f[i]>f[j];
}
bool cmp2(int i,int j){
return cost[i]<cost[j];
}
int main(){
int t; scanf("%d",&t);
while(t--){
int n,m,q; scanf("%d%d%d",&n,&m,&q);
for (int i=1;i<=m;i++){
int x; scanf("%d",&x);
red[x]=1;
}
for (int i=1;i<n;i++){
int x,y; ll w; scanf("%d%d%lld",&x,&y,&w);
link[x].push_back({y,w});
link[y].push_back({x,w});
}
dfs(1,0,1,0,1);
dfs2(1,0,1);
for (int T=1;T<=q;T++){
int num; scanf("%d",&num);
int top=0;
for (int i=1;i<=num;i++){
scanf("%d",&que[i]);
int _red=redfa[que[i]];
VS[_red].emplace_back(que[i]);
if (vis[_red]!=T) stk[++top]=_red,f[_red]=cost[que[i]];
else f[_red]=max(f[_red],cost[que[i]]);
vis[_red]=T;
}
sort(stk+1,stk+top+1,cmp);
int _red=stk[1];
sort(VS[_red].begin(),VS[_red].end(),cmp2);
int tot=VS[_red].size();
ll ret = cost[VS[_red][tot - 1]];
int lca=VS[_red][tot - 1];
for (int i=tot-1;i>=0;i--){
lca=getlca(lca,VS[_red][i]);
ll tmp;
if (i!=0) tmp=cost[VS[_red][i-1]];
else tmp=0;
ret=min(ret,max(tmp,f[_red]-cost[lca]));
}
if (top==1) printf("%lld\n",ret);
else printf("%lld\n",max(ret,f[stk[2]]));
for (int i=1;i<=top;i++) VS[stk[i]].clear();
}
for (int i=1;i<=n;i++){
link[i].clear();
red[i]=0;
vis[i]=0;
}
}
}
思路2:
二分答案,不满足答案的LCA合并起来再看是否满足。(大家基本是这么做的)
复杂度:O(n*log(1e14)*log(n)) 但是树链剖分求LCA实在是太快了,也能过
求LCA的地方还可以转RMQ用ST表做到O(1),就是写起来比较麻烦
(试了下倍增求LCA,T掉了,果然树剖大法好!)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,ll> P;
const int N=100005;
int red[N];
int redfa[N];
vector<P>link[N];
ll cost[N];
int deep[N];
int que[N];
int sz[N],son[N],fa[N];
void dfs(int u,int _fa,int _red,ll _cost,int _deep){
fa[u]=_fa;
deep[u]=_deep;
cost[u]=_cost;
redfa[u]=_red;
sz[u]=1;
son[u]=-1;
for (P p:link[u]){
int v=p.first; ll w=p.second;
if (v==_fa) continue;
if (red[v]) dfs(v,u,v,0,_deep+1);
else dfs(v,u,_red,_cost+w,_deep+1);
sz[u]+=sz[v];
if (son[u]==-1||sz[v]>sz[son[u]]) son[u]=v;
}
}
int top[N],id[N],T;
void dfs2(int u,int _fa,int p){
top[u]=p;
id[u]=++T;
if (son[u]!=-1) dfs2(son[u],u,p);
for (P p:link[u]){
int v=p.first; ll w=p.second;
if (v==_fa||v==son[u]) continue;
dfs2(v,u,v);
}
}
int getlca(int x,int y){
if (x==0) return y;
int f1=top[x],f2=top[y];
while(f1!=f2){
if (deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
x=fa[f1],f1=top[x];
}
if (deep[x]<deep[y]) swap(x,y);
return y;
}
bool check(ll len,int num){
int lca=0;
for (int i=1;i<=num;i++)
if (cost[que[i]]>len) lca=getlca(lca,que[i]);
for (int i=1;i<=num;i++)
if (cost[que[i]]>len){
if (deep[redfa[que[i]]]>=deep[lca]) return 0;
if (cost[que[i]]-cost[lca]>len) return 0;
}
return 1;
}
int main(){
int t; scanf("%d",&t);
while(t--){
int n,m,q; scanf("%d%d%d",&n,&m,&q);
for (int i=1;i<=m;i++){
int x; scanf("%d",&x);
red[x]=1;
}
for (int i=1;i<n;i++){
int x,y; ll w; scanf("%d%d%lld",&x,&y,&w);
link[x].push_back({y,w});
link[y].push_back({x,w});
}
dfs(1,0,1,0,1);
dfs2(1,0,1);
for (int T=1;T<=q;T++){
int num; scanf("%d",&num);
for (int i=1;i<=num;i++) scanf("%d",&que[i]);
ll l=0,r=1e14;
while(l<r){
ll mid=(l+r)/2;
if (!check(mid,num)) l=mid+1;
else r=mid;
}
printf("%lld\n",l);
}
for (int i=1;i<=n;i++){
link[i].clear();
red[i]=0;
deep[i]=0;
}
}
}