题意:给你一棵树,和m条链,问最多能取多少条没有公共点的链。
这道题的加强版出现在15年的多校第1场,见hdu5293,加强版每条链还有权值。
因为我是先做了加强版(当时坑了一天),看到这道题,于是直接树形dp了。。。。(标算是贪心。。。)
每个节点u维护两个值:1.f[u]表示u所在子树最多有几条链 2.sum[u]表示u的子节点的f值的和
sum[u]很好搞,f[u]的转移方法如下:
假设有一条链x,y|lca(x,y)==u
f[u]=max(sum[u],sigma(sum[v])-sigma(f[v])+1)
其中v是链上的点
简单的说就是:
1.不要这条链,f[u]就是sum[u]
2.要这条链,那么f[u]=这条链下面的所有子树能够选择的链的个数+1
链上求和的话可以用树状数组或者线段树维护dfs序,每个点记录根到该点路径上的和,每修改一个点就是把其子树里点的值都加上f[u]或sum[u].
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=100005; int base[N],vec[2*N],pre[2*N],tot; struct pair{ int x,y; }to[N]; int head[N],nxt[N],total; int fa[N][25],deep[N]; int dfn[N],time,l[N],r[N]; int cf[N],cs[N]; int f[N],sum[N]; int n,m,x,y,i; void addf(int x,int k) { for (int i=x;i<=n;i+=i&-i) cf[i]+=k; } void addsum(int x,int k) { for (int i=x;i<=n;i+=i&-i) cs[i]+=k; } int getf(int x) { int tmp=0; for (int i=x;i;i-=i&-i) tmp+=cf[i]; return tmp; } int getsum(int x) { int tmp=0; for (int i=x;i;i-=i&-i) tmp+=cs[i]; return tmp; } void push(int u,int x,int y) { to[++total]={x,y}; nxt[total]=head[u]; head[u]=total; } void link(int x,int y) { vec[++tot]=y; pre[tot]=base[x]; base[x]=tot; } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); if (deep[x]>deep[y]) { for (int j=20;j>=0;j--) if (deep[fa[x][j]]>deep[y]) x=fa[x][j]; x=fa[x][0]; } if (x==y) return x; for (int j=20;j>=0;j--) if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j]; return fa[x][0]; } void dfs1(int u) { dfn[u]=++time; l[u]=time; for (int now=base[u];now;now=pre[now]) { int v=vec[now]; if (v!=fa[u][0]) { fa[v][0]=u; deep[v]=deep[u]+1; dfs1(v); } } r[u]=time; } void dfs2(int u) { sum[u]=0; for (int now=base[u];now;now=pre[now]) { int v=vec[now]; if (v!=fa[u][0]) { dfs2(v); sum[u]+=f[v]; } } f[u]=sum[u]; for (int i=head[u];i;i=nxt[i]) { int x=dfn[to[i].x],y=dfn[to[i].y]; f[u]=max(f[u],getsum(x)+getsum(y)+sum[u]-getf(x)-getf(y)+1); } addf(l[u],f[u]); addf(r[u]+1,-f[u]); addsum(l[u],sum[u]); addsum(r[u]+1,-sum[u]); } int main() { while(~scanf("%d%d",&n,&m)) { time=0; tot=0; total=0; memset(head,0,sizeof(head)); memset(base,0,sizeof(base)); memset(cf,0,sizeof(cf)); memset(cs,0,sizeof(cs)); for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x); fa[1][0]=1; deep[1]=1; dfs1(1); for (int j=1;j<=20;j++) for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; for (i=1;i<=m;i++) { scanf("%d%d",&x,&y); push(lca(x,y),x,y); } dfs2(1); printf("%d\n",f[1]); } }
贪心做法:
将链按lca的深度排序,从深度大的开始,能取就取,取完后把该链下的子树全部标记,因此每次判断能否取就是O(1)的.(竟然跑的比上面的方法慢?!)
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=100005; int base[N],vec[2*N],pre[2*N],tot; struct chain{ int x,y,u; }a[N]; int fa[N][25],deep[N]; bool vis[N]; bool cmp(chain a,chain b) {return deep[a.u]>deep[b.u];} int n,m,x,y,i; void link(int x,int y) { vec[++tot]=y; pre[tot]=base[x]; base[x]=tot; } int lca(int x,int y) { if (deep[x]<deep[y]) swap(x,y); if (deep[x]>deep[y]) { for (int j=20;j>=0;j--) if (deep[fa[x][j]]>deep[y]) x=fa[x][j]; x=fa[x][0]; } if (x==y) return x; for (int j=20;j>=0;j--) if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j]; return fa[x][0]; } void dfs(int u) { for (int now=base[u];now;now=pre[now]) { int v=vec[now]; if (v!=fa[u][0]) { fa[v][0]=u; deep[v]=deep[u]+1; dfs(v); } } } void clean(int u) { if (vis[u]) return; vis[u]=1; for (int now=base[u];now;now=pre[now]) { int v=vec[now]; if (v!=fa[u][0]) clean(v); } } int main() { while(~scanf("%d%d",&n,&m)) { tot=0; memset(base,0,sizeof(base)); memset(vis,0,sizeof(vis)); for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x); fa[1][0]=1; deep[1]=1; dfs(1); for (int j=1;j<=20;j++) for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1]; for (i=1;i<=m;i++) { scanf("%d%d",&a[i].x,&a[i].y); a[i].u=lca(a[i].x,a[i].y); } sort(a+1,a+m+1,cmp); int ans=0; for (i=1;i<=m;i++) if (!vis[a[i].x]&&!vis[a[i].y]) { clean(a[i].u); ans++; } printf("%d\n",ans); } }