BNUOJ 53081 线段树区间合并

题目:

http://www.bnuoj.com/bnuoj/problem_show.php?pid=53081

 

思路:

将‘(’看作1,‘)’看作-1,每次询问:从x开始所有前缀的前缀和均非负,且和为0的最长前缀。

维护线段树节点代表区间的最小前缀和,每次询问处理方式为:

若当前区间[x,i]合并当前节点后最小前缀和非负,则合并,否则不合并且终止查询。

输出合并后的区间内最小前缀对应的长度即可。

 

#include<bits/stdc++.h>
using namespace std;

const int N=500005;

struct line
{
    int len,mi,id,sum;
    line operator+(const line &l)const
    {
        line tmp=*this;
        tmp.len+=l.len;
        if (sum+l.mi<=tmp.mi)
            tmp.mi=sum+l.mi,tmp.id=l.id;
        tmp.sum+=l.sum;
        return tmp;
    }
}a[N*4];

char s[N];

void build(int i,int l,int r)
{
    if (l==r)
    {
        if (s[l]=='(') a[i]={1,1,l,1};
        else a[i]={1,-1,l,-1};
        return;
    }
    int mid=(l+r)/2;
    build(i*2,l,mid);
    build(i*2+1,mid+1,r);
    a[i]=a[i*2]+a[i*2+1];
}
void cg(int i,int l,int r,int x)
{
    if (l==r)
    {
        if (s[l]=='(') a[i]={1,1,l,1};
        else a[i]={1,-1,l,-1};
        return;
    }
    int mid=(l+r)/2;
    if (x<=mid) cg(i*2,l,mid,x);
    else cg(i*2+1,mid+1,r,x);
    a[i]=a[i*2]+a[i*2+1];
}

line L;
bool get(int i,int l,int r,int x)
{
    if (x<=l)
    {
        if ((L+a[i]).mi>=0)
        {
            L=L+a[i];
            return 1;
        }
        else
        {
            if (l==r) return 0;
            int mid=(l+r)/2;
            if (get(i*2,l,mid,x))
                get(i*2+1,mid+1,r,x);
            return 0;
        }
    }
    int mid=(l+r)/2;
    if (x<=mid)
    {
        if (get(i*2,l,mid,x)&&get(i*2+1,mid+1,r,x)) return 1;
        return 0;
    }
    if (get(i*2+1,mid+1,r,x)) return 1;
    return 0;
}
int main()
{
    int t; scanf("%d",&t);
    while(t--)
    {
        int n,q; scanf("%d%d",&n,&q);
        scanf("%s",s+1);
        build(1,1,n);
        while(q--)
        {
            int tp,x; scanf("%d%d",&tp,&x);
            if (tp==1)
            {
                s[x]='('+')'-s[x];
                cg(1,1,n,x);
            }
            else
            {
                L={0,0,x-1,0};
                get(1,1,n,x);
                printf("%d\n",L.id-x+1);
            }
        }
    }
}

Codeforces 600E. Lomsat gelral (线段树合并)

为什么线段树合并复杂度是O(nlogn)?因为每次merge都会删除一个节点,一开始节点是nlogn个,所以总复杂度就是O(nlogn)。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=100005;
int a[N];

struct NODE
{
	int l,r,ls,rs,mx;
	ll cnt;
}tree[N*18];
int tot;
int root[N];
void init()
{
	tot=0;
}
int newnode(int l,int r)
{
	tot++;
	tree[tot]={l,r,0,0,0,0};
	return tot;
}
void up(int i)
{
	int ls=tree[i].ls,rs=tree[i].rs;
	tree[i].mx=max(tree[ls].mx,tree[rs].mx);
	tree[i].cnt=0;
	if (tree[ls].mx==tree[i].mx) tree[i].cnt+=tree[ls].cnt;
	if (tree[rs].mx==tree[i].mx) tree[i].cnt+=tree[rs].cnt;
}
int build(int l,int r,int x)
{
	int np=newnode(l,r);
	if (l==r)
	{
		tree[np].mx=1;
		tree[np].cnt=x;
		return np;
	}
	int mid=(l+r)/2;
	if (x<=mid) tree[np].ls=build(l,mid,x);
	else tree[np].rs=build(mid+1,r,x);
	up(np);
	return np;
}
int merge(int x,int y,int l,int r)
{
	if (x==0) return y;
	if (y==0) return x;
	if (l==r)
	{
		tree[x].mx+=tree[y].mx;
		return x;
	}
	int mid=(l+r)/2;
	tree[x].ls=merge(tree[x].ls,tree[y].ls,l,mid);
	tree[x].rs=merge(tree[x].rs,tree[y].rs,mid+1,r);
	up(x);
	return x;
}

vector<int>link[N];
ll f[N];
int n;
void dfs(int u,int fa)
{
	root[u]=build(1,n,a[u]);
	for (int i=0;i<link[u].size();i++)
	{
		int v=link[u][i];
		if (v==fa) continue;
		dfs(v,u);
		merge(root[u],root[v],1,n);
	}
	f[u]=tree[root[u]].cnt;
}
int main()
{
	scanf("%d",&n);
	for (int i=1;i<=n;i++) scanf("%d",&a[i]);
	for (int i=1;i<n;i++)
	{
		int x,y; scanf("%d%d",&x,&y);
		link[x].push_back(y);
		link[y].push_back(x);
	}
	init();
	dfs(1,1);
	for (int i=1;i<=n;i++) printf("%lld ",f[i]);
	for (int i=1;i<=n;i++) link[i].clear();

}
/*

http://codeforces.com/contest/600/problem/E

*/

HDU 5957 环套树+线段树维护bfs序

题意:

n个点n条边的连通图,边长度都为1,没有重边自环,两种操作:

1.把与x距离小于等于k的所有点的权值增加d  (k<=2)

2.问离x距离小于等于k的所有点的权值和 (k<=2)

 

思路:

首先这是个环套树,找到环后从环上的每个点开始bfs,处理出每个点向子树走k步的bfs序区间,一层层维护就好了。

分类讨论比较麻烦。

 

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
ll lazy[4*N],sum[4*N];
int n;
void init()
{
    memset(lazy,0,sizeof(lazy));
    memset(sum,0,sizeof(sum));
}
void up(int x)
{
	sum[x]=sum[x*2]+sum[x*2+1];
}
void down(int x,int l,int r)
{
	if (lazy[x])
	{
		int mid=(l+r)/2;
		sum[x*2]+=(mid-l+1)*lazy[x];
		sum[x*2+1]+=(r-mid)*lazy[x];
		lazy[x*2]+=lazy[x];
		lazy[x*2+1]+=lazy[x];
		lazy[x]=0;
	}
}
void add(int i,int l,int r,int x,int y,int w)
{
	if (x<=l&&r<=y)
	{
		sum[i]+=w*(r-l+1);
		lazy[i]+=w;
		return;
	}
	down(i,l,r);
	int mid=(l+r)/2;
	if (x<=mid) add(i*2,l,mid,x,y,w);
	if (y>mid) add(i*2+1,mid+1,r,x,y,w);
	up(i);
}
ll get(int i,int l,int r,int x,int y)
{
	if (x<=l&&r<=y) return sum[i];
	down(i,l,r);
	int mid=(l+r)/2;
	ll ret=0;
	if (x<=mid) ret+=get(i*2,l,mid,x,y);
	if (y>mid) ret+=get(i*2+1,mid+1,r,x,y);
	return ret;
}

void ad(int l,int r,int w)
{
	if (l>r) return;
	add(1,1,n,l,r,w);
}
ll gt(int l,int r)
{
	if (l>r) return 0;
	return get(1,1,n,l,r);
}

bool vis[N],on[N];
int l[N],r[N];
vector<int>link[N];
int dfs1(int u,int fa)
{
	vis[u]=1;
	for (int i=0;i<link[u].size();i++)
	{
		int v=link[u][i];
		if (v==fa) continue;
		l[v]=u;
		if (vis[v]) return v;
		int tmp=dfs1(v,u);
		if (tmp) return tmp;
	}
	return 0;
}
int tot;
int q[N],bl[N][3],br[N][3],fa[N];
void bfs(int u)
{
	vis[u]=1;
	q[++tot]=u;
	bl[u][0]=tot; bl[u][1]=bl[u][2]=N;
	br[u][0]=tot; br[u][1]=br[u][2]=0;
	int head=tot;
	while(head<=tot)
	{
		int u=q[head]; head++;
		for (int i=0;i<link[u].size();i++)
		{
			int v=link[u][i];
			if (vis[v]||on[v]) continue;
			q[++tot]=v;
			bl[v][0]=tot; bl[v][1]=bl[v][2]=N;
			br[v][0]=tot; br[v][1]=br[v][2]=0;
			fa[v]=u;
			vis[v]=1;
		}
	}
}

void look(int n)
{
	for (int i=1;i<=n;i++)
	{
		printf("%d %lld\n",i,get(1,1,n,bl[i][0],br[i][0]));
	}
}
char s[10];
int main()
{
	int t; scanf("%d",&t);
	while(t--)
	{
		scanf("%d",&n);
		for (int i=1;i<=n;i++)
		{
			int x,y; scanf("%d%d",&x,&y);
			link[x].push_back(y);
			link[y].push_back(x);
		}

		for (int i=1;i<=n;i++) vis[i]=0,on[i]=0;
		int root=dfs1(1,1);
		r[l[root]]=root; on[root]=1;
        for (int i=l[root];i!=root;i=l[i])  r[l[i]]=i,on[i]=1;

        tot=0;
        for (int i=1;i<=n;i++) vis[i]=0;
        bfs(root);
		for (int i=l[root];i!=root;i=l[i]) bfs(i);
        for (int k=1;k<=2;k++)
			for (int i=tot;i>=1;i--)
			{
				int v=q[i];
				if (on[v]) continue;
				int u=fa[v];
				bl[u][k]=min(bl[u][k],bl[v][k-1]);
				br[u][k]=max(br[u][k],br[v][k-1]);
			}

		init();
        int m; scanf("%d",&m);
		while(m--)
		{
			//look(n);
			scanf("%s",s);
			if (s[0]=='M')
			{
				int u,k,w; scanf("%d%d%d",&u,&k,&w);
				if (k==0) ad(bl[u][0],br[u][0],w);
				if (k==1)
				{
					ad(bl[u][0],br[u][0],w);
					ad(bl[u][1],br[u][1],w);//
					if (on[u])
					{
						int lv=l[u],rv=r[u];
						ad(bl[lv][0],br[lv][0],w);
						ad(bl[rv][0],br[rv][0],w);
					}
					else
					{
						int v=fa[u];
						ad(bl[v][0],br[v][0],w);
					}
				}
				if (k==2)
				{
					ad(bl[u][0],br[u][0],w);
					ad(bl[u][1],br[u][1],w);
					ad(bl[u][2],br[u][2],w);
					if (on[u])
					{
						int lv=l[u];
						ad(bl[lv][0],br[lv][0],w);
						ad(bl[lv][1],br[lv][1],w);
						int rv=r[u];
						ad(bl[rv][0],br[rv][0],w);
						ad(bl[rv][1],br[rv][1],w);
						lv=l[lv];
						if (lv!=rv)
							ad(bl[lv][0],br[lv][0],w);
						else lv=r[lv];
						rv=r[rv];
						if (rv!=lv)
							ad(bl[rv][0],br[rv][0],w);
					}
					else
					{
						int v=fa[u];
						ad(bl[v][0],br[v][0],w);
						ad(bl[v][1],br[v][1],w);
						ad(bl[u][0],br[u][0],-w);
						if (on[v])
						{
							int lv=l[v],rv=r[v];
							ad(bl[lv][0],br[lv][0],w);
							ad(bl[rv][0],br[rv][0],w);
						}
						else
						{
							v=fa[v];
							ad(bl[v][0],br[v][0],w);
						}
					}
				}
			}
			if (s[0]=='Q')
			{
				int u,k,w; scanf("%d%d",&u,&k);
				ll ret=0;
				if (k==0) ret+=gt(bl[u][0],br[u][0]);
				if (k==1)
				{
					ret+=gt(bl[u][0],br[u][0]);
					ret+=gt(bl[u][1],br[u][1]);
					if (on[u])
					{
						int lv=l[u],rv=r[u];
						ret+=gt(bl[lv][0],br[lv][0]);
						ret+=gt(bl[rv][0],br[rv][0]);
					}
					else
					{
						int v=fa[u];
						ret+=gt(bl[v][0],br[v][0]);
					}
				}
				if (k==2)
				{
					ret+=gt(bl[u][0],br[u][0]);
					ret+=gt(bl[u][1],br[u][1]);
					ret+=gt(bl[u][2],br[u][2]);
					if (on[u])
					{
						int lv=l[u];
						ret+=gt(bl[lv][0],br[lv][0]);
						ret+=gt(bl[lv][1],br[lv][1]);
						int rv=r[u];
						ret+=gt(bl[rv][0],br[rv][0]);
						ret+=gt(bl[rv][1],br[rv][1]);
						lv=l[lv];
						if (lv!=rv)
							ret+=gt(bl[lv][0],br[lv][0]);
						else lv=r[lv];
						rv=r[rv];
						if (rv!=lv)
							ret+=gt(bl[rv][0],br[rv][0]);
					}
					else
					{
						int v=fa[u];
						ret+=gt(bl[v][0],br[v][0]);
						ret+=gt(bl[v][1],br[v][1]);
						ret-=gt(bl[u][0],br[u][0]);
						if (on[v])
						{
							int lv=l[v],rv=r[v];
							ret+=gt(bl[lv][0],br[lv][0]);
							ret+=gt(bl[rv][0],br[rv][0]);
						}
						else
						{
							v=fa[v];
							ret+=gt(bl[v][0],br[v][0]);
						}
					}
				}
				printf("%lld\n",ret);
			}

		}

		for (int i=1;i<=n;i++) link[i].clear();
	}
}
/*
1
8
1 2
2 3
3 4
1 4
2 5
3 6
5 7
5 8
12
MODIFY 8 1 5
MODIFY 8 2 2
QUERY 5 0
QUERY 5 1
QUERY 5 2
MODIFY 7 2 2
QUERY 7 2
MODIFY 3 1 5
MODIFY 2 2 2
QUERY 6 1
MODIFY 4 1 -2
QUERY 2 2
*/

 

HDU 5052 树链剖分

题意:给你一棵树,每个点有点权,问从x走到y,先后取两个数a和b,b-a的最大值。

 

做法:树链剖分一下,然后线段树把路径区间都合并起来。

 

比赛的时候居然没想出来。。。(现在想想不是很简单嘛。。。)

 

#include<cstdio>
#include<algorithm>
using namespace std;
const int N=50005;
int a[N];
int base[N],vec[2*N],pre[2*N],tot;
int id[N],rd[N],son[N],deep[N],size[N],fa[N],top[N];
int T;
int n;
void link(int x,int y)
{
	vec[++tot]=y; pre[tot]=base[x]; base[x]=tot;
}

struct seg
{
	int ma,mi,fl,fr;
	seg operator+(const seg &t)const
	{
		if (ma==-1) return t;
		if (t.ma==-1) return *this;
		int maa=max(ma,t.ma);
		int mii=min(mi,t.mi);
		int ffl=max(fl,t.fl);
		ffl=max(ffl,ma-t.mi);
		int ffr=max(fr,t.fr);
		ffr=max(ffr,t.ma-mi);
		return {maa,mii,ffl,ffr};
	}
	void operator+=(const int t)
	{
		ma+=t; mi+=t;
	}
}s[4*N];
int lazy[4*N];
void build(int i,int l,int r)
{
	lazy[i]=0;
	if (l==r)
	{
		s[i]={a[rd[l]],a[rd[l]],0};
		return;
	}
	int mid=(l+r)/2;
	build(i*2,l,mid);
	build(i*2+1,mid+1,r);
	s[i]=s[i*2]+s[i*2+1];
}
void pd(int i)
{
	if (lazy[i])
	{
		lazy[i*2]+=lazy[i];
		lazy[i*2+1]+=lazy[i];
		s[i*2]+=lazy[i];
		s[i*2+1]+=lazy[i];
		lazy[i]=0;
	}
}
void add(int i,int l,int r,int x,int y,int z)
{
	if (x<=l&&r<=y)
	{
		s[i]+=z;
		lazy[i]+=z;
		return;
	}
	pd(i);
	int mid=(l+r)/2;
	if (x<=mid) add(i*2,l,mid,x,y,z);
	if (y>mid) add(i*2+1,mid+1,r,x,y,z);
	s[i]=s[i*2]+s[i*2+1];
}
seg get(int i,int l,int r,int x,int y)
{
	if (x<=l&&r<=y) return s[i];
	pd(i);
	int mid=(l+r)/2;
	seg tmp={-1,-1,-1,-1};
	if (x<=mid) tmp=tmp+get(i*2,l,mid,x,y);
	if (y>mid) tmp=tmp+get(i*2+1,mid+1,r,x,y);
	return tmp;
}


void dfs1(int u,int d,int p)
{
	fa[u]=p;
	deep[u]=d;
	son[u]=-1;
	size[u]=1;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v==p) continue;
		dfs1(v,d+1,u);
		size[u]+=size[v];
		if (son[u]==-1||size[v]>size[son[u]]) son[u]=v;
	}
}
void dfs2(int u,int p)
{
	top[u]=p;
	id[u]=++T;
	rd[T]=u;
	if (son[u]!=-1) dfs2(son[u],p);
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v==fa[u]||v==son[u]) continue;
		dfs2(v,v);
	}
}

int lca(int x,int y)
{
	int f1=top[x],f2=top[y];
	while(f1!=f2)
	{
		if (deep[f1]<deep[f2]) swap(f1,f2),swap(x,y);
		x=fa[f1],f1=top[x];
	}
	if (deep[x]<deep[y]) swap(x,y);
	return y;
}

seg getseg(int x,int LCA,int v)
{
	seg tmp={-1,-1,-1,-1};
	while(x!=LCA&&deep[top[x]]>deep[LCA])
	{
		tmp=get(1,1,n,id[top[x]],id[x])+tmp;
		add(1,1,n,id[top[x]],id[x],v);
		x=fa[top[x]];
	}
	tmp=get(1,1,n,id[LCA],id[x])+tmp;
	add(1,1,n,id[LCA],id[x],v);
	return tmp;
}
int solve(int x,int y,int v)
{
	int LCA=lca(x,y);

	seg tmp1=getseg(x,LCA,v);
	add(1,1,n,id[LCA],id[LCA],-v);
	seg tmp2=getseg(y,LCA,v);

	int ans=max(tmp1.fl,tmp2.fr);
	ans=max(ans,tmp2.ma-tmp1.mi);
	return ans;
}

int main()
{
	int t; scanf("%d",&t);
	while(t--)
	{
		scanf("%d",&n);
		for (int i=1;i<=n;i++) scanf("%d",&a[i]);
		for (int i=1;i<n;i++)
		{
			int x,y;
			scanf("%d%d",&x,&y);
			link(x,y);
			link(y,x);
		}

		dfs1(1,1,-1);
		dfs2(1,-1);
		build(1,1,n);

		int q;
		scanf("%d",&q);
		while(q--)
		{
			int x,y,v;
			scanf("%d%d%d",&x,&y,&v);
			printf("%d\n",solve(x,y,v));
		}
		tot=T=0;
		for (int i=1;i<=n;i++) base[i]=0;
	}
}

hdu 4897(树剖好题)

题意:

一棵树,一开始所有边都是白的,3种操作。

1.把a到b路径上的边变色(白变黑,黑变白)

2.把与所有(a到b路径上的点)相连的边都改变颜色(不包括路径上的边)

3.询问a到b路径上有多少黑边

 

做法:

先树链剖分

如果只有操作1,就是普通的线段树维护重链的做法。

但是有操作2,暴力修改所有儿子的话每次都是O(n),于是想到把要不要变色的信息存在父节点,查询的时候查一下父节点是否要求它变色。

 

具体怎么修改就是轻重边分开考虑。先把所有路径上的点都在point线段树里修改。

对于路径上的轻边,因为它父亲被改变了而这条轻边是不用修改的,那就把这条轻边在edge线段树里修改来抵消父亲对它的影响。

对于路径外需要修改的重边,直接在edge线段树上修改。(这样的重边只有logn条)

 

查询的时候,对于路径上的重边,直接在边的线段树上求答案,对于路径上的轻边,还要结合它的父亲来判断它是否是黑的。

 

代码应该还是比较好懂的。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=100005;
int size[N],top[N],son[N],fa[N],id[N],rd[N],deep[N],n,idn;
int base[N],vec[2*N],pre[2*N],tot;
struct SegTree{
	int f[4*N],sum[4*N];
	bool lazy[4*N];
	void build(int i,int l,int r)
	{
		if (l==r) {sum[i]=1; f[i]=0; lazy[i]=0; return;}
		int mid=(l+r)/2;
		build(i*2,l,mid);
		build(i*2+1,mid+1,r);
		sum[i]=sum[i*2]+sum[i*2+1];
		f[i]=0; lazy[i]=0;
	}
	void pd(int i)
	{
		if (lazy[i])
		{
			lazy[i]=0;
			f[i*2]=sum[i*2]-f[i*2];
			f[i*2+1]=sum[i*2+1]-f[i*2+1];
			lazy[i*2]^=1;
			lazy[i*2+1]^=1;
		}
	}
	void up(int i)
	{
		f[i]=f[i*2]+f[i*2+1];
	}
	void add(int i,int l,int r,int x,int y)
	{
		if (x>y) return;
		if (x<=l&&r<=y)
		{
			f[i]=sum[i]-f[i];
			lazy[i]^=1;
			return ;
		}
		pd(i);
		int mid=(l+r)/2;
		if (x<=mid) add(i*2,l,mid,x,y);
		if (y>mid) add(i*2+1,mid+1,r,x,y);
		up(i);
	}
	int get(int i,int l,int r,int x,int y)
	{
		if (x>y) return 0;
		if (x<=l&&r<=y) return f[i];
		pd(i);
		int tmp=0;
		int mid=(l+r)/2;
		if (x<=mid) tmp+=get(i*2,l,mid,x,y);
		if (y>mid) tmp+=get(i*2+1,mid+1,r,x,y);
		return tmp;
	}
}edg,pnt;
void init()
{
	memset(top,0,sizeof(top));
	memset(son,-1,sizeof(son));
	memset(base,0,sizeof(base));
	tot=idn=0;
}
void link(int x,int y)
{
    tot++;
    vec[tot]=y;  pre[tot]=base[x];  base[x]=tot;
}
void dfs1(int u,int p)
{
	fa[u]=p;
    size[u]=1;
    son[u]=-1;
    for (int now=base[u];now;now=pre[now])
	{
        int v=vec[now];
        if (v==p) continue;
		deep[v]=deep[u]+1;
		dfs1(v,u);
		size[u]+=size[v];
		if (son[u]==-1||size[v]>size[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int p)
{
	top[u]=p;
    id[u]=++idn;
	rd[idn]=u;
    if (son[u]!=-1) dfs2(son[u],p);
    for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v==son[u]||v==fa[u]) continue;
        dfs2(v,v);
    }
}
void cg1(int a,int b)
{
    int f1=top[a]; int f2=top[b];
    while(f1!=f2)
    {
        if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
        int x=id[f1], y=id[a];
        edg.add(1,1,n,x,y);
        a=fa[f1]; f1=top[a];
    }
    if (deep[a]>deep[b]) swap(a,b);
    if (a==b) return;
    int x=id[a], y=id[b];
	edg.add(1,1,n,x+1,y);
}

void cg2(int a,int b)
{
	int f1=top[a]; int f2=top[b];
    while(f1!=f2)
    {
        if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
        int x=id[f1], y=id[a];
        pnt.add(1,1,n,x,y);
        edg.add(1,1,n,x,x);
        if (son[a]!=-1) edg.add(1,1,n,id[son[a]],id[son[a]]);
        a=fa[f1]; f1=top[a];
    }
    if (deep[a]>deep[b]) swap(a,b);
    int x=id[a], y=id[b];
	pnt.add(1,1,n,x,y);
	edg.add(1,1,n,x,x);
	if (son[b]!=-1) edg.add(1,1,n,id[son[b]],id[son[b]]);
}

int query(int a,int b)
{
	int ans=0;
	int f1=top[a]; int f2=top[b];
    while(f1!=f2)
    {
        if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
        int x=id[f1], y=id[a];
        ans+=edg.get(1,1,n,x+1,y);
        int tmp1=edg.get(1,1,n,x,x);
        int z=id[fa[f1]];
        int tmp2=pnt.get(1,1,n,z,z);
        ans+=tmp1^tmp2;
        a=fa[f1]; f1=top[a];
    }
    if (deep[a]>deep[b]) swap(a,b);
    int x=id[a], y=id[b];
	ans+=edg.get(1,1,n,x+1,y);
	return ans;
}

int main()
{
	int t;
	scanf("%d",&t);
    while(t--)
	{
		scanf("%d",&n);
		init();
		edg.build(1,1,n);
		pnt.build(1,1,n);
    	for (int i=1;i<n;i++)
		{
			int x,y;
			scanf("%d%d",&x,&y);
			link(x,y),link(y,x);
		}
    	dfs1(1,1);
    	dfs2(1,1);

    	int q;
    	scanf("%d",&q);
    	while(q--)
		{
			int tp,x,y;
			scanf("%d%d%d",&tp,&x,&y);
			if (tp==1)
			{
				cg1(x,y);
			}
			if (tp==2)
			{
				cg2(x,y);
				pnt.get(1,1,n,1,1);
			}
			if (tp==3)
			{
				printf("%d\n",query(x,y));
			}
		}
	}
}

 

 

bzoj-1036

终于学会了树链剖分,以前以为它很难,现在看来其实也挺简单的,就两个dfs+线段树维护一下嘛。第一次交忘记修改的是x对应的编号,WA了一次,修改后AC。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<queue>
#include<functional>
using namespace std;
const int N=30005;
int size[N],top[N],son[N],f[N],id[N],c[N],w[N],deep[N],n,i,q,tot,idn,x,y;
int base[N],vec[2*N],pre[2*N];
struct node{
    int sum,mx,l,r,mid;
}tr[4*N];
char chr[10];
bool vis[N];
void add(int x,int y)
{
    tot++;
    vec[tot]=y;  pre[tot]=base[x];  base[x]=tot;
}
void dfs1(int u)
{
    vis[u]=true;
    int now=base[u];
    size[u]=1;
    while(now)
    {
        int v=vec[now];
        if (!vis[v])
        {
            f[v]=u;
            deep[v]=deep[u]+1;
            dfs1(v);
            size[u]+=size[v];
            if (size[v]>size[son[u]]) son[u]=v;
        }
        now=pre[now];
    }
}
void dfs2(int u)
{
    idn++; id[u]=idn;
    vis[u]=true;
    int now=base[u];
    if (son[u]) top[son[u]]=top[u],dfs2(son[u]);
    while(now)
    {
        int v=vec[now];
        if (!vis[v]) top[v]=v,dfs2(v);
        now=pre[now];
    }
}
void build(int i,int l,int r)
{
    tr[i].l=l; tr[i].r=r; tr[i].mid=(l+r)/2;
    if (l==r){tr[i].sum=tr[i].mx=c[l];return;}
    build(i*2,l,tr[i].mid);
    build(i*2+1,tr[i].mid+1,r);
    tr[i].sum=tr[i*2].sum+tr[i*2+1].sum;
    tr[i].mx=max(tr[i*2].mx,tr[i*2+1].mx);
}
int querymax(int i)
{
    if (x<=tr[i].l&&tr[i].r<=y) return tr[i].mx;
    if (y<=tr[i].mid) return querymax(i*2);
    if (x>tr[i].mid) return querymax(i*2+1);
    return max(querymax(i*2),querymax(i*2+1));
}
int querysum(int i)
{
    if (x<=tr[i].l&&tr[i].r<=y) return tr[i].sum;
    if (y<=tr[i].mid) return querysum(i*2);
    if (x>tr[i].mid) return querysum(i*2+1);
    return querysum(i*2)+querysum(i*2+1);
}
void change(int i)
{
    if (tr[i].l==tr[i].r) {tr[i].sum=y;tr[i].mx=y;return;}
    if (x<=tr[i].mid) change(i*2);else change(i*2+1);
    tr[i].sum=tr[i*2].sum+tr[i*2+1].sum;
    tr[i].mx=max(tr[i*2].mx,tr[i*2+1].mx);
}
void slovemax(int a,int b)
{
    int f1=top[a]; int f2=top[b];int tmp=-50000;
    while(f1!=f2)
    {
        if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
        x=id[f1];y=id[a];
        tmp=max(tmp,querymax(1));
        a=f[f1]; f1=top[a];
    }
    if (deep[a]>deep[b]) swap(a,b);
    x=id[a]; y=id[b];
    tmp=max(tmp,querymax(1));
    printf("%d\n",tmp);
}
void slovesum(int a,int b)
{
    int f1=top[a]; int f2=top[b];int tmp=0;
    while(f1!=f2)
    {
        if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
        x=id[f1];y=id[a];
        tmp+=querysum(1);
        a=f[f1]; f1=top[a];
    }
    if (deep[a]>deep[b]) swap(a,b);
    x=id[a]; y=id[b];
    tmp+=querysum(1);
    printf("%d\n",tmp);
}
int main()
{
    scanf("%d",&n);
    for (i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x);
    for (i=1;i<=n;i++) scanf("%d",&w[i]);
    deep[1]=1; dfs1(1);
    memset(vis,0,sizeof(vis));
    top[1]=1; dfs2(1);
    for (i=1;i<=n;i++) c[id[i]]=w[i];
    build(1,1,n);
    scanf("%d",&q);
    for (i=1;i<=q;i++)
    {
        scanf("%s%d%d",chr,&x,&y);
        if (chr[1]=='M') slovemax(x,y);
        if (chr[1]=='H') x=id[x],change(1);
        if (chr[1]=='S') slovesum(x,y);
    }
}