bzoj-1036
终于学会了树链剖分。以前以为它很难,现在看来其实挺简单:两次 DFS(第一次求父结点、深度、子树大小和重儿子,第二次划分重链并给每个结点分配编号),再用线段树维护即可。第一次提交时,单点修改忘了把结点 x 换成它对应的线段树编号 id[x],WA 了一次,改正后 AC。
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<queue>
#include<functional>
#include<climits>

// Heavy-light decomposition over a tree with a segment tree on top.
// Supports three operations, selected by the second character of the command
// word: 'M' -> maximum vertex weight on the path a..b, 'S' -> sum of vertex
// weights on the path a..b, 'H' -> set the weight of vertex a to b.
// (Presumably the commands are QMAX / QSUM / CHANGE as in BZOJ 1036 -- only
// the second character is inspected.)

const int N = 30005;               // array capacity; vertices are numbered 1..n, n < N

// Adjacency lists as linked edge arrays: base[u] is the newest edge of u,
// pre[e] the previous edge of the same vertex, 0 terminates the list.
int base[N], vec[2 * N], pre[2 * N], tot;

// Named sz rather than size: a global `size` clashes with std::size (C++17)
// as soon as namespace std is pulled into scope.
int sz[N];                         // subtree size
int son[N];                        // heavy child (0 if none)
int f[N];                          // parent (0 for the root)
int deep[N];                       // depth, root has depth 1
int top[N];                        // topmost vertex of the heavy chain containing the vertex
int id[N], idn;                    // position of the vertex in the segment tree, and its counter
int w[N];                          // weight indexed by vertex
int c[N];                          // weight indexed by segment-tree position
bool vis[N];

struct node {
    int sum, mx;                   // sum / maximum over [l, r]
    int l, r;                      // covered range of positions
} tr[4 * N];

// Adds the directed edge x -> y to the adjacency list of x.
void add(int x, int y)
{
    ++tot;
    vec[tot] = y;
    pre[tot] = base[x];
    base[x] = tot;
}

// First pass: fills f, deep, sz and son for the subtree of u.
// deep[u] must already be set by the caller.
void dfs1(int u)
{
    vis[u] = true;
    sz[u] = 1;
    for (int e = base[u]; e; e = pre[e]) {
        const int v = vec[e];
        if (vis[v]) continue;
        f[v] = u;
        deep[v] = deep[u] + 1;
        dfs1(v);
        sz[u] += sz[v];
        // son[u] starts at 0 and sz[0] is 0, so the first child always wins.
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Second pass: assigns segment-tree positions and chain tops.
// The heavy child is visited first so that every heavy chain occupies a
// contiguous range of positions. top[u] must already be set by the caller.
void dfs2(int u)
{
    id[u] = ++idn;
    vis[u] = true;
    if (son[u]) {
        top[son[u]] = top[u];
        dfs2(son[u]);
    }
    for (int e = base[u]; e; e = pre[e]) {
        const int v = vec[e];
        if (!vis[v]) {
            top[v] = v;            // a light child starts a new chain
            dfs2(v);
        }
    }
}

// Recomputes node i from its two children.
static void pull(int i)
{
    tr[i].sum = tr[i * 2].sum + tr[i * 2 + 1].sum;
    tr[i].mx = std::max(tr[i * 2].mx, tr[i * 2 + 1].mx);
}

// Builds node i covering positions [l, r] from the array c.
void build(int i, int l, int r)
{
    tr[i].l = l;
    tr[i].r = r;
    if (l == r) {
        tr[i].sum = tr[i].mx = c[l];
        return;
    }
    const int mid = (l + r) / 2;
    build(i * 2, l, mid);
    build(i * 2 + 1, mid + 1, r);
    pull(i);
}

// Maximum over positions [lo, hi]; requires lo <= hi and the range to
// intersect the range of node i.
int querymax(int i, int lo, int hi)
{
    if (lo <= tr[i].l && tr[i].r <= hi) return tr[i].mx;
    const int mid = (tr[i].l + tr[i].r) / 2;
    if (hi <= mid) return querymax(i * 2, lo, hi);
    if (lo > mid) return querymax(i * 2 + 1, lo, hi);
    return std::max(querymax(i * 2, lo, hi), querymax(i * 2 + 1, lo, hi));
}

// Sum over positions [lo, hi]; same preconditions as querymax.
int querysum(int i, int lo, int hi)
{
    if (lo <= tr[i].l && tr[i].r <= hi) return tr[i].sum;
    const int mid = (tr[i].l + tr[i].r) / 2;
    if (hi <= mid) return querysum(i * 2, lo, hi);
    if (lo > mid) return querysum(i * 2 + 1, lo, hi);
    return querysum(i * 2, lo, hi) + querysum(i * 2 + 1, lo, hi);
}

// Sets the value at position pos to val. pos is a segment-tree position
// (id[vertex]), not a vertex number.
void change(int i, int pos, int val)
{
    if (tr[i].l == tr[i].r) {
        tr[i].sum = tr[i].mx = val;
        return;
    }
    const int mid = (tr[i].l + tr[i].r) / 2;
    change(pos <= mid ? i * 2 : i * 2 + 1, pos, val);
    pull(i);
}

// Returns the maximum vertex weight on the tree path between a and b.
int solvemax(int a, int b)
{
    int best = INT_MIN;
    while (top[a] != top[b]) {
        // Always lift the endpoint whose chain top is deeper.
        if (deep[top[a]] < deep[top[b]]) std::swap(a, b);
        best = std::max(best, querymax(1, id[top[a]], id[a]));
        a = f[top[a]];
    }
    // Same chain now: the shallower vertex has the smaller position.
    if (deep[a] > deep[b]) std::swap(a, b);
    return std::max(best, querymax(1, id[a], id[b]));
}

// Returns the sum of vertex weights on the tree path between a and b.
int solvesum(int a, int b)
{
    int total = 0;
    while (top[a] != top[b]) {
        if (deep[top[a]] < deep[top[b]]) std::swap(a, b);
        total += querysum(1, id[top[a]], id[a]);
        a = f[top[a]];
    }
    if (deep[a] > deep[b]) std::swap(a, b);
    return total + querysum(1, id[a], id[b]);
}

// Reads the tree (n, then n-1 edges, then n weights), builds the
// decomposition rooted at vertex 1, then answers q commands from stdin.
// Stops quietly on truncated input. Vertex numbers in the input are trusted
// to lie in 1..n.
int main()
{
    int n;
    if (scanf("%d", &n) != 1 || n < 1 || n >= N) return 0;

    for (int i = 1; i < n; i++) {
        int x, y;
        if (scanf("%d%d", &x, &y) != 2) return 0;
        add(x, y);
        add(y, x);
    }
    for (int i = 1; i <= n; i++)
        if (scanf("%d", &w[i]) != 1) return 0;

    deep[1] = 1;
    dfs1(1);
    memset(vis, 0, sizeof(vis));   // dfs2 reuses the visited marks
    top[1] = 1;
    dfs2(1);

    for (int i = 1; i <= n; i++) c[id[i]] = w[i];
    build(1, 1, n);

    int q;
    if (scanf("%d", &q) != 1) return 0;

    char chr[10];
    for (int i = 1; i <= q; i++) {
        int x, y;
        // %9s bounds the read to the buffer (9 characters + terminator).
        if (scanf("%9s%d%d", chr, &x, &y) != 3) break;
        switch (chr[1]) {
        case 'M':
            printf("%d\n", solvemax(x, y));
            break;
        case 'H':
            change(1, id[x], y);   // update by position, not by vertex number
            break;
        case 'S':
            printf("%d\n", solvesum(x, y));
            break;
        default:
            break;
        }
    }
    return 0;
}