hdu4912-Paths on the tree(2014 Multi-University Training Contest 5)

题意:给你一棵树,和m条链,问最多能取多少条没有公共点的链。

这道题的加强版出现在15年的多校第1场,见hdu5293,加强版每条链还有权值。

因为我是先做了加强版(当时坑了一天),看到这道题,于是直接树形dp了。。。。(标算是贪心。。。)

每个节点u维护两个值:1.f[u]表示u所在子树最多有几条链   2.sum[u]表示u的子节点的f值的和

sum[u]很好搞,f[u]的转移方法如下:

假设有一条链x,y|lca(x,y)==u

f[u]=max(sum[u],sigma(sum[v])-sigma(f[v])+1)

其中v是链上的点

简单的说就是:

1.不要这条链,f[u]就是sum[u]

2.要这条链,那么f[u]=这条链下面的所有子树能够选择的链的个数+1

链上求和的话可以用树状数组或者线段树维护dfs序,每个点记录根到该点路径上的和,每修改一个点就是把其子树里点的值都加上f[u]或sum[u].

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=100005;
int base[N],vec[2*N],pre[2*N],tot;

struct pair{
	int x,y;
}to[N];
int head[N],nxt[N],total;

int fa[N][25],deep[N];
int dfn[N],time,l[N],r[N];
int cf[N],cs[N];
int f[N],sum[N];
int n,m,x,y,i;
void addf(int x,int k)
{
    for (int i=x;i<=n;i+=i&-i) cf[i]+=k;
}
void addsum(int x,int k)
{
    for (int i=x;i<=n;i+=i&-i) cs[i]+=k;
}
int getf(int x)
{
    int tmp=0;
    for (int i=x;i;i-=i&-i) tmp+=cf[i];
    return tmp;
}
int getsum(int x)
{
    int tmp=0;
    for (int i=x;i;i-=i&-i) tmp+=cs[i];
    return tmp;
}
void push(int u,int x,int y)
{
	to[++total]={x,y}; nxt[total]=head[u];  head[u]=total;
}
void link(int x,int y)
{
	vec[++tot]=y;  pre[tot]=base[x];  base[x]=tot;
}
int lca(int x,int y)
{
	if (deep[x]<deep[y]) swap(x,y);
	if (deep[x]>deep[y])
		{
			for (int j=20;j>=0;j--)
				if (deep[fa[x][j]]>deep[y]) x=fa[x][j];
			x=fa[x][0];
		}
	if (x==y) return x;
	for (int j=20;j>=0;j--)
		if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
	return fa[x][0];
}
void dfs1(int u)
{
	dfn[u]=++time;
	l[u]=time;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0])
		{
			fa[v][0]=u;
			deep[v]=deep[u]+1;
			dfs1(v);
		}
	}
	r[u]=time;
}
void dfs2(int u)
{
	sum[u]=0;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0])
		{
			dfs2(v);
			sum[u]+=f[v];
		}
	}
	f[u]=sum[u];
	for (int i=head[u];i;i=nxt[i])
	{
		int x=dfn[to[i].x],y=dfn[to[i].y];
		f[u]=max(f[u],getsum(x)+getsum(y)+sum[u]-getf(x)-getf(y)+1);
	}
	addf(l[u],f[u]);
	addf(r[u]+1,-f[u]);
	addsum(l[u],sum[u]);
	addsum(r[u]+1,-sum[u]);
}
int main()
{
	while(~scanf("%d%d",&n,&m))
	{
		time=0;
		tot=0;  total=0;
		memset(head,0,sizeof(head));
		memset(base,0,sizeof(base));
		memset(cf,0,sizeof(cf));
		memset(cs,0,sizeof(cs));
		for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x);
		fa[1][0]=1;  deep[1]=1;
		dfs1(1);
		for (int j=1;j<=20;j++)
			for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
		for (i=1;i<=m;i++)
		{
			scanf("%d%d",&x,&y);
			push(lca(x,y),x,y);
		}
		dfs2(1);
		printf("%d\n",f[1]);
	}
}

贪心做法:

将链按lca的深度排序,从深度大的开始,能取就取,取完后把该链下的子树全部标记,因此每次判断能否取就是O(1)的.(竟然跑的比上面的方法慢?!)

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=100005;
int base[N],vec[2*N],pre[2*N],tot;

struct chain{
	int x,y,u;
}a[N];
int fa[N][25],deep[N];
bool vis[N];
bool cmp(chain a,chain b) {return deep[a.u]>deep[b.u];}
int n,m,x,y,i;
void link(int x,int y)
{
	vec[++tot]=y;  pre[tot]=base[x];  base[x]=tot;
}
int lca(int x,int y)
{
	if (deep[x]<deep[y]) swap(x,y);
	if (deep[x]>deep[y])
		{
			for (int j=20;j>=0;j--)
				if (deep[fa[x][j]]>deep[y]) x=fa[x][j];
			x=fa[x][0];
		}
	if (x==y) return x;
	for (int j=20;j>=0;j--)
		if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
	return fa[x][0];
}
void dfs(int u)
{
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0])
		{
			fa[v][0]=u;
			deep[v]=deep[u]+1;
			dfs(v);
		}
	}
}
void clean(int u)
{
	if (vis[u]) return;
	vis[u]=1;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0]) clean(v);
	}
}
int main()
{
	while(~scanf("%d%d",&n,&m))
	{
		tot=0;
		memset(base,0,sizeof(base));
		memset(vis,0,sizeof(vis));
		for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x);
		fa[1][0]=1;  deep[1]=1;
		dfs(1);
		for (int j=1;j<=20;j++)
			for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
		for (i=1;i<=m;i++)
		{
			scanf("%d%d",&a[i].x,&a[i].y);
			a[i].u=lca(a[i].x,a[i].y);
		}
		sort(a+1,a+m+1,cmp);
		int ans=0;
		for (i=1;i<=m;i++)
			if (!vis[a[i].x]&&!vis[a[i].y])
			{
				clean(a[i].u);
				ans++;
			}
		printf("%d\n",ans);
	}
}