bzoj-1036

zjhl2 posted @ 2015年8月05日 04:34 with tags 树链剖分 线段树 , 350 阅读

终于学会了树链剖分。以前以为它很难，现在看来其实也挺简单的：就是两次 dfs，再用线段树维护一下。第一次提交时忘了修改操作要作用在 x 对应的编号 id[x] 上（而不是 x 本身），WA 了一次，改正后 AC。

#include<cstdio>
#include<cstring>
#include<climits>
#include<iostream>
#include<string>
#include<algorithm>
#include<queue>
#include<functional>
using namespace std;
// Capacity for vertex-indexed arrays (vertices are numbered from 1).
const int N = 30005;

// ---- Heavy-light decomposition, indexed by vertex ----
// NOTE(review): with `using namespace std;` and C++17, unqualified uses of
// `size` can be ambiguous with std::size; consider renaming file-wide (e.g. sz).
int size[N];   // subtree size, filled by dfs1
int top[N];    // head vertex of the heavy chain containing the vertex (dfs2)
int son[N];    // heavy child (largest subtree), 0 if none (dfs1)
int f[N];      // parent vertex (dfs1)
int id[N];     // position of the vertex in decomposition order (dfs2)
int c[N];      // weights indexed by position id[] (segment-tree leaves)
int w[N];      // weights indexed by vertex, as read from input
int deep[N];   // depth; main() sets the root's depth to 1

// ---- Scalars ----
int n;         // vertex count
int i;         // loop counter used by main()
int q;         // query count
int tot;       // number of directed edges stored so far (add)
int idn;       // last position handed out by dfs2
int x, y;      // shared operands: query range [x, y] for querymax/querysum,
               // or (position, new value) for change

// ---- Adjacency list (forward-star): 2 directed edges per tree edge ----
int base[N];   // head edge index per vertex, 0 = empty list
int vec[2 * N];  // endpoint of each edge
int pre[2 * N];  // next edge in the same vertex's list

// ---- Segment tree over positions 1..n ----
struct node {
    int sum;   // sum of weights in [l, r]
    int mx;    // maximum weight in [l, r]
    int l;     // left end of the covered range
    int r;     // right end of the covered range
    int mid;   // (l + r) / 2, split point between children
};
node tr[4 * N];

char chr[10];  // command word read by main()
bool vis[N];   // visited flags, reused by dfs1 and dfs2
// Append the directed edge x -> y to x's forward-star adjacency list.
// The new edge becomes the head of the list; edge indices start at 1.
void add(int x, int y)
{
    const int e = ++tot;
    vec[e] = y;        // endpoint
    pre[e] = base[x];  // link to the previous head
    base[x] = e;       // publish as new head
}
// First DFS of the heavy-light decomposition, over the subtree rooted at u.
// For each descendant v it records the parent f[v] and depth deep[v]
// (deep[u] must already be set by the caller). It also computes size[] and
// picks son[u] as the child with the largest subtree (strictly larger wins,
// so ties keep the earlier child). vis[] stops it from walking back to the
// parent.
void dfs1(int u)
{
    vis[u] = true;
    size[u] = 1;
    for (int e = base[u]; e != 0; e = pre[e])
    {
        const int v = vec[e];
        if (vis[v]) continue;  // parent (already visited)
        f[v] = u;
        deep[v] = deep[u] + 1;
        dfs1(v);
        size[u] += size[v];
        // size[0] is 0, so the first child always beats an unset son[u].
        if (size[son[u]] < size[v]) son[u] = v;
    }
}
// Second DFS of the heavy-light decomposition. It gives u the next position
// id[u] and then visits the heavy child first, so every heavy chain gets
// consecutive positions. top[v] is the chain head: the heavy child inherits
// top[u], and each light child starts its own chain. The caller must set
// top[root] and clear vis[] before the first call.
void dfs2(int u)
{
    id[u] = ++idn;
    vis[u] = true;

    const int heavy = son[u];
    if (heavy != 0)
    {
        top[heavy] = top[u];
        dfs2(heavy);
    }

    for (int e = base[u]; e; e = pre[e])
    {
        const int v = vec[e];
        if (vis[v]) continue;  // parent or the already-handled heavy child
        top[v] = v;
        dfs2(v);
    }
}
// Build segment-tree node i over positions [l, r], with leaves taken from c[].
// Children of node i are 2*i and 2*i+1. Each node stores its range, the split
// point, and the sum and max of its range.
void build(int i, int l, int r)
{
    const int m = (l + r) / 2;
    tr[i].l = l;
    tr[i].r = r;
    tr[i].mid = m;
    if (l == r)
    {
        tr[i].mx = c[l];
        tr[i].sum = c[l];
        return;
    }
    const int lc = 2 * i;
    const int rc = 2 * i + 1;
    build(lc, l, m);
    build(rc, m + 1, r);
    tr[i].sum = tr[lc].sum + tr[rc].sum;
    tr[i].mx = std::max(tr[lc].mx, tr[rc].mx);
}
// Maximum weight over positions [x, y] (globals), searched from node i.
// Assumes [x, y] intersects node i's range; callers start at the root (i = 1)
// with 1 <= x <= y <= n.
int querymax(int i)
{
    const int lo = tr[i].l;
    const int hi = tr[i].r;
    const int split = tr[i].mid;
    if (lo >= x && hi <= y) return tr[i].mx;      // node fully inside the query
    if (split >= y) return querymax(2 * i);       // query lies in the left half
    if (split < x) return querymax(2 * i + 1);    // query lies in the right half
    return std::max(querymax(2 * i), querymax(2 * i + 1));
}
// Sum of weights over positions [x, y] (globals), searched from node i.
// Assumes [x, y] intersects node i's range; callers start at the root (i = 1).
int querysum(int i)
{
    const int left = 2 * i;
    const int right = 2 * i + 1;
    if (tr[i].l >= x && tr[i].r <= y) return tr[i].sum;  // fully covered
    if (y <= tr[i].mid) return querysum(left);            // only left half
    if (x > tr[i].mid) return querysum(right);            // only right half
    const int partial = querysum(left);
    return partial + querysum(right);
}
// Point assignment: set the leaf at position x (global) to value y (global),
// then refresh sum and max on the way back up. Start from the root (i = 1).
void change(int i)
{
    if (tr[i].l != tr[i].r)
    {
        const int child = (x <= tr[i].mid) ? 2 * i : 2 * i + 1;
        change(child);
        tr[i].sum = tr[2 * i].sum + tr[2 * i + 1].sum;
        tr[i].mx = std::max(tr[2 * i].mx, tr[2 * i + 1].mx);
        return;
    }
    tr[i].mx = y;
    tr[i].sum = y;
}
// Print the maximum vertex weight on the tree path a..b (both ends included).
// It repeatedly lifts whichever endpoint has the deeper chain head, querying
// that chain's contiguous position range [id[head], id[endpoint]]. Once both
// endpoints share a chain, it queries the range between them. The segment
// tree is queried through the globals x and y.
// The running maximum starts at INT_MIN rather than a magic bound, so the
// answer is correct for any int weights. The old -50000 sentinel was wrong
// whenever every weight on the path was below it.
void slovemax(int a,int b)
{
    int f1 = top[a];
    int f2 = top[b];
    int best = INT_MIN;
    while (f1 != f2)
    {
        // Always lift the endpoint whose chain head is deeper.
        if (deep[f1] < deep[f2])
        {
            std::swap(f1, f2);
            std::swap(a, b);
        }
        x = id[f1];
        y = id[a];
        best = std::max(best, querymax(1));
        a = f[f1];   // jump above the chain head
        f1 = top[a];
    }
    // Same chain: positions run from the shallower endpoint to the deeper one.
    if (deep[a] > deep[b]) std::swap(a, b);
    x = id[a];
    y = id[b];
    best = std::max(best, querymax(1));
    printf("%d\n", best);
}
void slovesum(int a,int b)
{
    int f1=top[a]; int f2=top[b];int tmp=0;
    while(f1!=f2)
    {
        if (deep[f1]<deep[f2]) swap(f1,f2),swap(a,b);
        x=id[f1];y=id[a];
        tmp+=querysum(1);
        a=f[f1]; f1=top[a];
    }
    if (deep[a]>deep[b]) swap(a,b);
    x=id[a]; y=id[b];
    tmp+=querysum(1);
    printf("%d\n",tmp);
}
int main()
{
    scanf("%d",&n);
    for (i=1;i<n;i++) scanf("%d%d",&x,&y),add(x,y),add(y,x);
    for (i=1;i<=n;i++) scanf("%d",&w[i]);
    deep[1]=1; dfs1(1);
    memset(vis,0,sizeof(vis));
    top[1]=1; dfs2(1);
    for (i=1;i<=n;i++) c[id[i]]=w[i];
    build(1,1,n);
    scanf("%d",&q);
    for (i=1;i<=q;i++)
    {
        scanf("%s%d%d",chr,&x,&y);
        if (chr[1]=='M') slovemax(x,y);
        if (chr[1]=='H') x=id[x],change(1);
        if (chr[1]=='S') slovesum(x,y);
    }
}

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter