hdu 4897(树剖好题)
题意:
一棵树,一开始所有边都是白的,3种操作。
1.把a到b路径上的边变色(白变黑,黑变白)
2.把与所有(a到b路径上的点)相连的边都改变颜色(不包括路径上的边)
3.询问a到b路径上有多少黑边
做法:
先树链剖分
如果只有操作1,就是普通的线段树维护重链的做法。
但是有操作2,暴力修改所有儿子的话每次都是O(n),于是想到把要不要变色的信息存在父节点,查询的时候查一下父节点是否要求它变色。
具体怎么修改就是轻重边分开考虑。先把所有路径上的点都在point线段树里修改。
对于路径上的轻边,因为它父亲被改变了而这条轻边是不用修改的,那就把这条轻边在edge线段树里修改来抵消父亲对它的影响。
对于路径外需要修改的重边,直接在edge线段树上修改。(这样的重边只有logn条)
查询的时候,对于路径上的重边,直接在边的线段树上求答案,对于路径上的轻边,还要结合它的父亲来判断它是否是黑的。
代码应该还是比较好懂的。
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; const int N=100005; int size[N],top[N],son[N],fa[N],id[N],rd[N],deep[N],n,idn; int base[N],vec[2*N],pre[2*N],tot; struct SegTree{ int f[4*N],sum[4*N]; bool lazy[4*N]; void build(int i,int l,int r) { if (l==r) {sum[i]=1; f[i]=0; lazy[i]=0; return;} int mid=(l+r)/2; build(i*2,l,mid); build(i*2+1,mid+1,r); sum[i]=sum[i*2]+sum[i*2+1]; f[i]=0; lazy[i]=0; } void pd(int i) { if (lazy[i]) { lazy[i]=0; f[i*2]=sum[i*2]-f[i*2]; f[i*2+1]=sum[i*2+1]-f[i*2+1]; lazy[i*2]^=1; lazy[i*2+1]^=1; } } void up(int i) { f[i]=f[i*2]+f[i*2+1]; } void add(int i,int l,int r,int x,int y) { if (x>y) return; if (x<=l&&r<=y) { f[i]=sum[i]-f[i]; lazy[i]^=1; return ; } pd(i); int mid=(l+r)/2; if (x<=mid) add(i*2,l,mid,x,y); if (y>mid) add(i*2+1,mid+1,r,x,y); up(i); } int get(int i,int l,int r,int x,int y) { if (x>y) return 0; if (x<=l&&r<=y) return f[i]; pd(i); int tmp=0; int mid=(l+r)/2; if (x<=mid) tmp+=get(i*2,l,mid,x,y); if (y>mid) tmp+=get(i*2+1,mid+1,r,x,y); return tmp; } }edg,pnt; void init() { memset(top,0,sizeof(top)); memset(son,-1,sizeof(son)); memset(base,0,sizeof(base)); tot=idn=0; } void link(int x,int y) { tot++; vec[tot]=y; pre[tot]=base[x]; base[x]=tot; } void dfs1(int u,int p) { fa[u]=p; size[u]=1; son[u]=-1; for (int now=base[u];now;now=pre[now]) { int v=vec[now]; if (v==p) continue; deep[v]=deep[u]+1; dfs1(v,u); size[u]+=size[v]; if (son[u]==-1||size[v]>size[son[u]]) son[u]=v; } } void dfs2(int u,int p) { top[u]=p; id[u]=++idn; rd[idn]=u; if (son[u]!=-1) dfs2(son[u],p); for (int now=base[u];now;now=pre[now]) { int v=vec[now]; if (v==son[u]||v==fa[u]) continue; dfs2(v,v); } } void cg1(int a,int b) { int f1=top[a]; int f2=top[b]; while(f1!=f2) { if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b); int x=id[f1], y=id[a]; edg.add(1,1,n,x,y); a=fa[f1]; f1=top[a]; } if (deep[a]>deep[b]) swap(a,b); if (a==b) return; int x=id[a], y=id[b]; edg.add(1,1,n,x+1,y); } void cg2(int a,int b) { int f1=top[a]; int f2=top[b]; while(f1!=f2) { if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b); int x=id[f1], y=id[a]; pnt.add(1,1,n,x,y); edg.add(1,1,n,x,x); if (son[a]!=-1) edg.add(1,1,n,id[son[a]],id[son[a]]); a=fa[f1]; f1=top[a]; } if (deep[a]>deep[b]) swap(a,b); int x=id[a], y=id[b]; pnt.add(1,1,n,x,y); edg.add(1,1,n,x,x); if (son[b]!=-1) edg.add(1,1,n,id[son[b]],id[son[b]]); } int query(int a,int b) { int ans=0; int f1=top[a]; int f2=top[b]; while(f1!=f2) { if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b); int x=id[f1], y=id[a]; ans+=edg.get(1,1,n,x+1,y); int tmp1=edg.get(1,1,n,x,x); int z=id[fa[f1]]; int tmp2=pnt.get(1,1,n,z,z); ans+=tmp1^tmp2; a=fa[f1]; f1=top[a]; } if (deep[a]>deep[b]) swap(a,b); int x=id[a], y=id[b]; ans+=edg.get(1,1,n,x+1,y); return ans; } int main() { int t; scanf("%d",&t); while(t--) { scanf("%d",&n); init(); edg.build(1,1,n); pnt.build(1,1,n); for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); link(x,y),link(y,x); } dfs1(1,1); dfs2(1,1); int q; scanf("%d",&q); while(q--) { int tp,x,y; scanf("%d%d%d",&tp,&x,&y); if (tp==1) { cg1(x,y); } if (tp==2) { cg2(x,y); pnt.get(1,1,n,1,1); } if (tp==3) { printf("%d\n",query(x,y)); } } } }