题意:给你一棵树,和m条链,问最多能取多少条没有公共点的链。
这道题的加强版出现在15年的多校第1场,见hdu5293,加强版每条链还有权值。
因为我是先做了加强版(当时坑了一天),看到这道题,于是直接树形dp了。。。。(标算是贪心。。。)
每个节点u维护两个值:1.f[u]表示u所在子树最多有几条链 2.sum[u]表示u的子节点的f值的和
sum[u]很好搞,f[u]的转移方法如下:
假设有一条链x,y|lca(x,y)==u
f[u]=max(sum[u],sigma(sum[v])-sigma(f[v])+1)
其中v是链上的点
简单的说就是:
1.不要这条链,f[u]就是sum[u]
2.要这条链,那么f[u]=这条链下面的所有子树能够选择的链的个数+1
链上求和的话可以用树状数组或者线段树维护dfs序,每个点记录根到该点路径上的和,每修改一个点就是把其子树里点的值都加上f[u]或sum[u].
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=100005;
int base[N],vec[2*N],pre[2*N],tot;
struct pair{
int x,y;
}to[N];
int head[N],nxt[N],total;
int fa[N][25],deep[N];
int dfn[N],time,l[N],r[N];
int cf[N],cs[N];
int f[N],sum[N];
int n,m,x,y,i;
void addf(int x,int k)
{
for (int i=x;i<=n;i+=i&-i) cf[i]+=k;
}
void addsum(int x,int k)
{
for (int i=x;i<=n;i+=i&-i) cs[i]+=k;
}
int getf(int x)
{
int tmp=0;
for (int i=x;i;i-=i&-i) tmp+=cf[i];
return tmp;
}
int getsum(int x)
{
int tmp=0;
for (int i=x;i;i-=i&-i) tmp+=cs[i];
return tmp;
}
void push(int u,int x,int y)
{
to[++total]={x,y}; nxt[total]=head[u]; head[u]=total;
}
void link(int x,int y)
{
vec[++tot]=y; pre[tot]=base[x]; base[x]=tot;
}
int lca(int x,int y)
{
if (deep[x]<deep[y]) swap(x,y);
if (deep[x]>deep[y])
{
for (int j=20;j>=0;j--)
if (deep[fa[x][j]]>deep[y]) x=fa[x][j];
x=fa[x][0];
}
if (x==y) return x;
for (int j=20;j>=0;j--)
if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
return fa[x][0];
}
void dfs1(int u)
{
dfn[u]=++time;
l[u]=time;
for (int now=base[u];now;now=pre[now])
{
int v=vec[now];
if (v!=fa[u][0])
{
fa[v][0]=u;
deep[v]=deep[u]+1;
dfs1(v);
}
}
r[u]=time;
}
void dfs2(int u)
{
sum[u]=0;
for (int now=base[u];now;now=pre[now])
{
int v=vec[now];
if (v!=fa[u][0])
{
dfs2(v);
sum[u]+=f[v];
}
}
f[u]=sum[u];
for (int i=head[u];i;i=nxt[i])
{
int x=dfn[to[i].x],y=dfn[to[i].y];
f[u]=max(f[u],getsum(x)+getsum(y)+sum[u]-getf(x)-getf(y)+1);
}
addf(l[u],f[u]);
addf(r[u]+1,-f[u]);
addsum(l[u],sum[u]);
addsum(r[u]+1,-sum[u]);
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
time=0;
tot=0; total=0;
memset(head,0,sizeof(head));
memset(base,0,sizeof(base));
memset(cf,0,sizeof(cf));
memset(cs,0,sizeof(cs));
for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x);
fa[1][0]=1; deep[1]=1;
dfs1(1);
for (int j=1;j<=20;j++)
for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
for (i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
push(lca(x,y),x,y);
}
dfs2(1);
printf("%d\n",f[1]);
}
}
贪心做法:
将链按lca的深度排序,从深度大的开始,能取就取,取完后把该链下的子树全部标记,因此每次判断能否取就是O(1)的.(竟然跑的比上面的方法慢?!)
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=100005;
int base[N],vec[2*N],pre[2*N],tot;
struct chain{
int x,y,u;
}a[N];
int fa[N][25],deep[N];
bool vis[N];
bool cmp(chain a,chain b) {return deep[a.u]>deep[b.u];}
int n,m,x,y,i;
void link(int x,int y)
{
vec[++tot]=y; pre[tot]=base[x]; base[x]=tot;
}
int lca(int x,int y)
{
if (deep[x]<deep[y]) swap(x,y);
if (deep[x]>deep[y])
{
for (int j=20;j>=0;j--)
if (deep[fa[x][j]]>deep[y]) x=fa[x][j];
x=fa[x][0];
}
if (x==y) return x;
for (int j=20;j>=0;j--)
if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
return fa[x][0];
}
void dfs(int u)
{
for (int now=base[u];now;now=pre[now])
{
int v=vec[now];
if (v!=fa[u][0])
{
fa[v][0]=u;
deep[v]=deep[u]+1;
dfs(v);
}
}
}
void clean(int u)
{
if (vis[u]) return;
vis[u]=1;
for (int now=base[u];now;now=pre[now])
{
int v=vec[now];
if (v!=fa[u][0]) clean(v);
}
}
int main()
{
while(~scanf("%d%d",&n,&m))
{
tot=0;
memset(base,0,sizeof(base));
memset(vis,0,sizeof(vis));
for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x);
fa[1][0]=1; deep[1]=1;
dfs(1);
for (int j=1;j<=20;j++)
for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
for (i=1;i<=m;i++)
{
scanf("%d%d",&a[i].x,&a[i].y);
a[i].u=lca(a[i].x,a[i].y);
}
sort(a+1,a+m+1,cmp);
int ans=0;
for (i=1;i<=m;i++)
if (!vis[a[i].x]&&!vis[a[i].y])
{
clean(a[i].u);
ans++;
}
printf("%d\n",ans);
}
}