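// Problem: a tree with n nodes, where every node carries a weight w.
// Process q operations:
//   1. set the weight of node x to y;
//   2. report the sum of the weights of all nodes whose distance from x is at most d.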
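// Approach: centroid decomposition. Instead of rerunning the divide-and-conquer
// step per query, keep the Fenwick trees (keyed by distance to each centroid)
// that the step's counting phase needs, and update them in place. Every node
// lies in O(log n) centroid components, so each query or update costs O(log^2 n).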
// Centroid decomposition in which every component keeps Fenwick trees indexed
// by distance to its centroid, so "sum of weights within distance d of x" and
// "set the weight of x" each cost O(log^2 n).
//
// Input (possibly several test cases until end of input):
//   n q / n weights / n-1 edges / q lines "op x y".
// op '?' asks for the weight sum within distance y of x; any other op sets the
// weight of x to y.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

// What one node records about one centroid level it belongs to.
struct Level {
  int whole = 0;  // Fenwick tree over the node's whole component at this level
  int part = -1;  // Fenwick tree of the centroid branch holding the node; -1 at the centroid
  int depth = 0;  // 1 + distance from the node to this level's centroid (1-based position)
};

class CentroidIndex {
 public:
  // `adj` is the adjacency list of a tree on nodes 1..n (entry 0 unused) and
  // `weight[v]` is the initial weight of node v, indexed the same way.
  CentroidIndex(std::vector<std::vector<int>> adj, std::vector<long long> weight)
      : adj_(std::move(adj)), weight_(std::move(weight)) {
    const int n = static_cast<int>(adj_.size()) - 1;
    // Removing a centroid leaves pieces of at most half the size, so level f
    // only holds components of at most (n >> f) nodes: bit_width(n) levels.
    while ((n >> levels_) > 0) ++levels_;
    const std::size_t slots = adj_.size();
    top_.assign(slots, 0);
    rec_.assign(slots * static_cast<std::size_t>(levels_), Level{});
    removed_.assign(slots, 0);
    parent_.assign(slots, -1);
    dist_.assign(slots, 0);
    sub_.assign(slots, 0);
    heavy_.assign(slots, 0);
    order_.reserve(slots);
    if (n >= 1) build(1, 0);
  }

  // Sum of the weights of all nodes at distance <= d from node x.
  long long query(int x, long long d) const {
    long long total = 0;
    for (int f = 0; f <= top_[x]; ++f) {
      const Level& lv = at(x, f);
      // Node u (depth du) is within reach through this centroid iff
      // (lv.depth - 1) + (du - 1) <= d.
      const long long reach = d + 2 - lv.depth;
      total += tree_prefix(lv.whole, reach);
      // Nodes in x's own branch do not reach x through the centroid.
      if (lv.part != -1) total -= tree_prefix(lv.part, reach);
    }
    return total;
  }

  // Sets the weight of node x to `value`.
  void set_weight(int x, long long value) {
    const long long delta = value - weight_[x];
    for (int f = 0; f <= top_[x]; ++f) {
      const Level& lv = at(x, f);
      tree_add(lv.whole, lv.depth, delta);
      if (lv.part != -1) tree_add(lv.part, lv.depth, delta);
    }
    weight_[x] = value;
  }

 private:
  Level& at(int u, int level) {
    return rec_[static_cast<std::size_t>(u) * levels_ + level];
  }
  const Level& at(int u, int level) const {
    return rec_[static_cast<std::size_t>(u) * levels_ + level];
  }

  // Appends a zeroed Fenwick tree for positions 1..max_pos and returns its id.
  int new_tree(int max_pos) {
    const int lo = hi_.empty() ? 0 : hi_.back();
    const int hi = lo + max_pos + 1;  // slot `lo` itself stays unused
    lo_.push_back(lo);
    hi_.push_back(hi);
    const std::size_t need = static_cast<std::size_t>(hi);
    if (fen_.size() < need) fen_.resize(std::max(need, 2 * fen_.size()), 0);
    return static_cast<int>(lo_.size()) - 1;
  }

  void tree_add(int id, int pos, long long delta) {
    const int lo = lo_[id];
    const int len = hi_[id] - lo;
    for (int i = pos; i < len; i += i & -i) fen_[lo + i] += delta;
  }

  // Sum of tree `id` over positions 1..pos; pos is clamped to the tree's range.
  long long tree_prefix(int id, long long pos) const {
    if (pos < 1) return 0;
    const int lo = lo_[id];
    const int len = hi_[id] - lo;
    int i = pos < len ? static_cast<int>(pos) : len - 1;
    long long sum = 0;
    for (; i > 0; i -= i & -i) sum += fen_[lo + i];
    return sum;
  }

  // Breadth-first walk over the live (not removed) nodes connected to `root`.
  // Leaves them in order_ (root first, non-decreasing distance) and fills
  // parent_ / dist_ with parent_[root] = -1 and dist_[root] = 0. It is
  // iterative, so long paths cannot exhaust the call stack.
  void walk(int root) {
    order_.clear();
    order_.push_back(root);
    parent_[root] = -1;
    dist_[root] = 0;
    for (std::size_t head = 0; head < order_.size(); ++head) {
      const int u = order_[head];
      for (const int v : adj_[u]) {
        if (v == parent_[u] || removed_[v]) continue;
        parent_[v] = u;
        dist_[v] = dist_[u] + 1;
        order_.push_back(v);
      }
    }
  }

  // Returns a node of the live component containing `root` whose removal
  // leaves pieces of at most half the component's size.
  int centroid(int root) {
    walk(root);
    const int total = static_cast<int>(order_.size());
    for (const int u : order_) {
      sub_[u] = 1;
      heavy_[u] = 0;
    }
    int best = root;
    int best_piece = total;
    // Reverse BFS order finishes every child before its parent.
    for (auto it = order_.rbegin(); it != order_.rend(); ++it) {
      const int u = *it;
      const int piece = std::max(heavy_[u], total - sub_[u]);
      if (piece < best_piece) {
        best = u;
        best_piece = piece;
      }
      const int p = parent_[u];
      if (p != -1) {
        sub_[p] += sub_[u];
        heavy_[p] = std::max(heavy_[p], sub_[u]);
      }
    }
    return best;
  }

  // Indexes the live component containing `root` at centroid level `level`,
  // then recurses into the pieces left once its centroid is removed.
  void build(int root, int level) {
    const int c = centroid(root);
    top_[c] = level;

    // One tree over the whole component, keyed by depth = dist(c, u) + 1.
    walk(c);
    const int whole = new_tree(dist_[order_.back()] + 1);  // BFS: last is farthest
    for (const int u : order_) {
      Level& lv = at(u, level);
      lv.whole = whole;
      lv.part = -1;
      lv.depth = dist_[u] + 1;
      tree_add(whole, lv.depth, weight_[u]);
    }

    // One tree per branch hanging off c, keyed by the same depth.
    removed_[c] = 1;
    for (const int v : adj_[c]) {
      if (removed_[v]) continue;
      walk(v);  // c is removed, so the walk stays inside v's branch
      const int part = new_tree(dist_[order_.back()] + 2);  // depth = dist(v, u) + 2
      for (const int u : order_) {
        Level& lv = at(u, level);
        lv.part = part;
        tree_add(part, lv.depth, weight_[u]);
      }
    }

    for (const int v : adj_[c]) {
      if (!removed_[v]) build(v, level + 1);
    }
  }

  std::vector<std::vector<int>> adj_;
  std::vector<long long> weight_;
  int levels_ = 0;
  std::vector<int> top_;        // level at which each node is itself the centroid
  std::vector<Level> rec_;      // rec_[v * levels_ + f]: node v at level f
  std::vector<long long> fen_;  // all Fenwick trees, stored back to back
  std::vector<int> lo_, hi_;    // tree id -> its slice [lo, hi) of fen_
  // Scratch state for walk() / centroid(), reused throughout the build.
  std::vector<char> removed_;
  std::vector<int> parent_, dist_, sub_, heavy_, order_;
};

}  // namespace

int main() {
  int n = 0;
  int q = 0;
  while (std::scanf("%d%d", &n, &q) == 2 && n >= 0) {
    std::vector<long long> weight(static_cast<std::size_t>(n) + 1, 0);
    for (int v = 1; v <= n; ++v) {
      int w = 0;
      if (std::scanf("%d", &w) != 1) return 0;
      weight[v] = w;
    }
    std::vector<std::vector<int>> adj(static_cast<std::size_t>(n) + 1);
    for (int e = 1; e < n; ++e) {
      int a = 0;
      int b = 0;
      if (std::scanf("%d%d", &a, &b) != 2) return 0;
      adj[a].push_back(b);
      adj[b].push_back(a);
    }
    CentroidIndex tree(std::move(adj), std::move(weight));
    while (q-- > 0) {
      char op = 0;
      int x = 0;
      int y = 0;
      if (std::scanf(" %c%d%d", &op, &x, &y) != 3) return 0;
      if (op == '?') {
        std::printf("%lld\n", tree.query(x, y));
      } else {
        tree.set_weight(x, y);
      }
    }
  }
  return 0;
}