题意:
一棵树,一开始所有边都是白的,3种操作。
1.把a到b路径上的边变色(白变黑,黑变白)
2.把与所有(a到b路径上的点)相连的边都改变颜色(不包括路径上的边)
3.询问a到b路径上有多少黑边
做法:
先树链剖分
如果只有操作1,就是普通的线段树维护重链的做法。
但是有操作2,暴力修改所有儿子的话每次都是O(n),于是想到把要不要变色的信息存在父节点,查询的时候查一下父节点是否要求它变色。
具体怎么修改就是轻重边分开考虑。先把所有路径上的点都在point线段树里修改。
对于路径上的轻边,因为它父亲被改变了而这条轻边是不用修改的,那就把这条轻边在edge线段树里修改来抵消父亲对它的影响。
对于路径外需要修改的重边,直接在edge线段树上修改。(这样的重边只有logn条)
查询的时候,对于路径上的重边,直接在边的线段树上求答案,对于路径上的轻边,还要结合它的父亲来判断它是否是黑的。
代码应该还是比较好懂的。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=100005;
int size[N],top[N],son[N],fa[N],id[N],rd[N],deep[N],n,idn;
int base[N],vec[2*N],pre[2*N],tot;
struct SegTree{
int f[4*N],sum[4*N];
bool lazy[4*N];
void build(int i,int l,int r)
{
if (l==r) {sum[i]=1; f[i]=0; lazy[i]=0; return;}
int mid=(l+r)/2;
build(i*2,l,mid);
build(i*2+1,mid+1,r);
sum[i]=sum[i*2]+sum[i*2+1];
f[i]=0; lazy[i]=0;
}
void pd(int i)
{
if (lazy[i])
{
lazy[i]=0;
f[i*2]=sum[i*2]-f[i*2];
f[i*2+1]=sum[i*2+1]-f[i*2+1];
lazy[i*2]^=1;
lazy[i*2+1]^=1;
}
}
void up(int i)
{
f[i]=f[i*2]+f[i*2+1];
}
void add(int i,int l,int r,int x,int y)
{
if (x>y) return;
if (x<=l&&r<=y)
{
f[i]=sum[i]-f[i];
lazy[i]^=1;
return ;
}
pd(i);
int mid=(l+r)/2;
if (x<=mid) add(i*2,l,mid,x,y);
if (y>mid) add(i*2+1,mid+1,r,x,y);
up(i);
}
int get(int i,int l,int r,int x,int y)
{
if (x>y) return 0;
if (x<=l&&r<=y) return f[i];
pd(i);
int tmp=0;
int mid=(l+r)/2;
if (x<=mid) tmp+=get(i*2,l,mid,x,y);
if (y>mid) tmp+=get(i*2+1,mid+1,r,x,y);
return tmp;
}
}edg,pnt;
void init()
{
memset(top,0,sizeof(top));
memset(son,-1,sizeof(son));
memset(base,0,sizeof(base));
tot=idn=0;
}
void link(int x,int y)
{
tot++;
vec[tot]=y; pre[tot]=base[x]; base[x]=tot;
}
void dfs1(int u,int p)
{
fa[u]=p;
size[u]=1;
son[u]=-1;
for (int now=base[u];now;now=pre[now])
{
int v=vec[now];
if (v==p) continue;
deep[v]=deep[u]+1;
dfs1(v,u);
size[u]+=size[v];
if (son[u]==-1||size[v]>size[son[u]]) son[u]=v;
}
}
void dfs2(int u,int p)
{
top[u]=p;
id[u]=++idn;
rd[idn]=u;
if (son[u]!=-1) dfs2(son[u],p);
for (int now=base[u];now;now=pre[now])
{
int v=vec[now];
if (v==son[u]||v==fa[u]) continue;
dfs2(v,v);
}
}
void cg1(int a,int b)
{
int f1=top[a]; int f2=top[b];
while(f1!=f2)
{
if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
int x=id[f1], y=id[a];
edg.add(1,1,n,x,y);
a=fa[f1]; f1=top[a];
}
if (deep[a]>deep[b]) swap(a,b);
if (a==b) return;
int x=id[a], y=id[b];
edg.add(1,1,n,x+1,y);
}
void cg2(int a,int b)
{
int f1=top[a]; int f2=top[b];
while(f1!=f2)
{
if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
int x=id[f1], y=id[a];
pnt.add(1,1,n,x,y);
edg.add(1,1,n,x,x);
if (son[a]!=-1) edg.add(1,1,n,id[son[a]],id[son[a]]);
a=fa[f1]; f1=top[a];
}
if (deep[a]>deep[b]) swap(a,b);
int x=id[a], y=id[b];
pnt.add(1,1,n,x,y);
edg.add(1,1,n,x,x);
if (son[b]!=-1) edg.add(1,1,n,id[son[b]],id[son[b]]);
}
int query(int a,int b)
{
int ans=0;
int f1=top[a]; int f2=top[b];
while(f1!=f2)
{
if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
int x=id[f1], y=id[a];
ans+=edg.get(1,1,n,x+1,y);
int tmp1=edg.get(1,1,n,x,x);
int z=id[fa[f1]];
int tmp2=pnt.get(1,1,n,z,z);
ans+=tmp1^tmp2;
a=fa[f1]; f1=top[a];
}
if (deep[a]>deep[b]) swap(a,b);
int x=id[a], y=id[b];
ans+=edg.get(1,1,n,x+1,y);
return ans;
}
int main()
{
int t;
scanf("%d",&t);
while(t--)
{
scanf("%d",&n);
init();
edg.build(1,1,n);
pnt.build(1,1,n);
for (int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
dfs1(1,1);
dfs2(1,1);
int q;
scanf("%d",&q);
while(q--)
{
int tp,x,y;
scanf("%d%d%d",&tp,&x,&y);
if (tp==1)
{
cg1(x,y);
}
if (tp==2)
{
cg2(x,y);
pnt.get(1,1,n,1,1);
}
if (tp==3)
{
printf("%d\n",query(x,y));
}
}
}
}