HDU5956 The Elder 下凸壳的维护

题意:

一棵树,1为根。每个点只能走向祖先,时间为距离的平方,另外走到祖先后会停留P单位时间。求每个点到达根的最小时间的最大值。

 

思路:

显然dp方程为:f[i]=(d[i]-d[j])^2+P+f[j]  ( j 是 i 的祖先)

转化成向量点积形式:f[i]=d[i]^2+P+(1 , 2*d[i] ) * ( f[j]+d[j]^2 , -d[j] )

注意到深度越大,向量(1 , 2*d[i] )会逆时针旋转

于是维护点集i的所有祖先的  ( f[j]+d[j]^2 , -d[j] ) 的下凸壳,就能求出f[i]的最小值

注意到向量( f[j]+d[j]^2 , -d[j] )只会添加在下凸壳右边,于是只要用个队列来保存就可以了

查询的时候只要一直扔掉左边的点,找到最大值为止,然后退出当前dfs的时候加回来就可以了

如果一个个保存删掉的点,再一个个加回来的话,在菊花图的情况下复杂度就会变成O(n^2),事实上目前网上的题解都是这么做的,估计出题人也没意识到复杂度的问题,因为这看起来很像O(n)。

我的做法是每次dfs记录当前队列的左端点,和右端点,操作完后,事实上只是替换了数组里的tail位置的元素,把head和tail还有这个元素事先保存下来,最后恢复head,tail和这个元素就可以了,这样复杂度就是完美的O(n)了。

也不知道这种做法叫什么,毕竟自己乱想出来的,就叫回溯队列好了233

后来一想这个做法还是n方。。。因为每次head和tail的删除操作都可能是O(n)的,所以要二分查找head和tail分别应该删到哪,这样算法复杂度就是O(nlogn)了

这里如果不用斜率,而把两边的dx都乘到对面去的话是有问题的,因为f[j]+d[j]^2的规模是1e14,2*d[i]的规模是1e7,乘起来理论上会爆,然而数据好像还是没卡掉这种做法

 

代码1:O(n^2)做法

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
struct P
{
	ll x,y;
	ll operator*(const P &b)const
	{
		return x*b.x+y*b.y;
	}
}stk[N];
int n,p;
int head,tail;
vector<P>link[N];
ll f[N],d[N];
long double slope(P a,P b)
{
	long double tmp=1.0;
	return tmp*(b.y-a.y)/(b.x-a.x);
}
void dfs(int u,int fa,int deep)
{
	int phead=head,ptail=tail;
	d[u]=deep;
	if (u==1) f[u]=-p;
	else
	{
		P tmp={1,2*d[u]};
		while(tail-head+1>=2&&stk[head]*tmp>=stk[head+1]*tmp) head++;
		f[u]=d[u]*d[u]+p+stk[head]*tmp;
	}
	P tmp={f[u]+d[u]*d[u],-d[u]};
	while(tail-head+1>=2&&
		slope(stk[tail-1],stk[tail])>=slope(stk[tail],tmp)) tail--;
	P mem=stk[tail+1];
	stk[++tail]=tmp;

	for (int i=0;i<link[u].size();i++)
	{
		int v=link[u][i].x;
		if (v==fa) continue;
		dfs(v,u,deep+link[u][i].y);
	}
	tail--;
	stk[tail+1]=mem;
	head=phead; tail=ptail;
}
int main()
{
	int t; scanf("%d",&t);
	while(t--)
	{
		scanf("%d%d",&n,&p);
		for (int i=1;i<n;i++)
		{
			ll x,y,z;
			scanf("%d%d%lld",&x,&y,&z);
			if (z<0) n/=0;
			link[x].push_back({y,z});
		}
		head=1; tail=0;
		dfs(1,1,0);
		ll ans=0;
		for (int i=1;i<=n;i++) ans=max(ans,f[i]);
		printf("%lld\n",ans);

		for (int i=1;i<=n;i++) link[i].clear();
	}
}

 

代码2:O(nlogn)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
struct P
{
    ll x,y;
    ll operator*(const P &b)const
    {
        return x*b.x+y*b.y;
    }
}stk[N];
int n,p;
int head,tail;
vector<P>link[N];
ll f[N],d[N];
long double slope(P a,P b)
{
    long double tmp=1.0;
    return tmp*(b.y-a.y)/(b.x-a.x);
}
int gethead(P v)
{
	int l=head,r=tail;
	while(l<r)
	{
		int mid=(l+r)/2;
		if (stk[mid]*v>=stk[mid+1]*v) l=mid+1;
		else r=mid;
	}
	return l;
}
int gettail(P v)
{
	int l=head,r=tail;
	while(l<r)
	{
		int mid=(l+r)/2;
		if (slope(stk[mid],stk[mid+1])<slope(stk[mid+1],v)) l=mid+1;
		else r=mid;
	}
	return l;
}
void dfs(int u,int fa,int deep)
{
    int phead=head,ptail=tail;
    d[u]=deep;
    if (u==1) f[u]=-p;
    else
    {
        P tmp={1,2*d[u]};
        head=gethead(tmp);

        //while(tail-head+1>=2&&stk[head]*tmp>=stk[head+1]*tmp) head++;
        f[u]=d[u]*d[u]+p+stk[head]*tmp;
    }
    P tmp={f[u]+d[u]*d[u],-d[u]};
    tail=gettail(tmp);
//    while(tail-head+1>=2&&
//        slope(stk[tail-1],stk[tail])>=slope(stk[tail],tmp)) tail--;
    P mem=stk[tail+1];
    stk[++tail]=tmp;

    for (int i=0;i<link[u].size();i++)
    {
        int v=link[u][i].x;
        if (v==fa) continue;
        dfs(v,u,deep+link[u][i].y);
    }
    tail--;
    stk[tail+1]=mem;
    head=phead; tail=ptail;
}
int main()
{
    int t; scanf("%d",&t);
    while(t--)
    {
        scanf("%d%d",&n,&p);
        for (int i=1;i<n;i++)
        {
            ll x,y,z;
            scanf("%d%d%lld",&x,&y,&z);
            link[x].push_back({y,z});
        }
        head=1; tail=0;
        dfs(1,1,0);
        ll ans=0;
        for (int i=1;i<=n;i++) ans=max(ans,f[i]);
        printf("%lld\n",ans);

        for (int i=1;i<=n;i++) link[i].clear();
    }
}

hdu4912-Paths on the tree(2014 Multi-University Training Contest 5)

题意:给你一棵树,和m条链,问最多能取多少条没有公共点的链。

这道题的加强版出现在15年的多校第1场,见hdu5293,加强版每条链还有权值。

因为我是先做了加强版(当时坑了一天),看到这道题,于是直接树形dp了。。。。(标算是贪心。。。)

每个节点u维护两个值:1.f[u]表示u所在子树最多有几条链   2.sum[u]表示u的子节点的f值的和

sum[u]很好搞,f[u]的转移方法如下:

假设有一条链x,y|lca(x,y)==u

f[u]=max(sum[u],sigma(sum[v])-sigma(f[v])+1)

其中v是链上的点

简单的说就是:

1.不要这条链,f[u]就是sum[u]

2.要这条链,那么f[u]=这条链下面的所有子树能够选择的链的个数+1

链上求和的话可以用树状数组或者线段树维护dfs序,每个点记录根到该点路径上的和,每修改一个点就是把其子树里点的值都加上f[u]或sum[u].

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=100005;
int base[N],vec[2*N],pre[2*N],tot;

struct pair{
	int x,y;
}to[N];
int head[N],nxt[N],total;

int fa[N][25],deep[N];
int dfn[N],time,l[N],r[N];
int cf[N],cs[N];
int f[N],sum[N];
int n,m,x,y,i;
void addf(int x,int k)
{
    for (int i=x;i<=n;i+=i&-i) cf[i]+=k;
}
void addsum(int x,int k)
{
    for (int i=x;i<=n;i+=i&-i) cs[i]+=k;
}
int getf(int x)
{
    int tmp=0;
    for (int i=x;i;i-=i&-i) tmp+=cf[i];
    return tmp;
}
int getsum(int x)
{
    int tmp=0;
    for (int i=x;i;i-=i&-i) tmp+=cs[i];
    return tmp;
}
void push(int u,int x,int y)
{
	to[++total]={x,y}; nxt[total]=head[u];  head[u]=total;
}
void link(int x,int y)
{
	vec[++tot]=y;  pre[tot]=base[x];  base[x]=tot;
}
int lca(int x,int y)
{
	if (deep[x]<deep[y]) swap(x,y);
	if (deep[x]>deep[y])
		{
			for (int j=20;j>=0;j--)
				if (deep[fa[x][j]]>deep[y]) x=fa[x][j];
			x=fa[x][0];
		}
	if (x==y) return x;
	for (int j=20;j>=0;j--)
		if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
	return fa[x][0];
}
void dfs1(int u)
{
	dfn[u]=++time;
	l[u]=time;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0])
		{
			fa[v][0]=u;
			deep[v]=deep[u]+1;
			dfs1(v);
		}
	}
	r[u]=time;
}
void dfs2(int u)
{
	sum[u]=0;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0])
		{
			dfs2(v);
			sum[u]+=f[v];
		}
	}
	f[u]=sum[u];
	for (int i=head[u];i;i=nxt[i])
	{
		int x=dfn[to[i].x],y=dfn[to[i].y];
		f[u]=max(f[u],getsum(x)+getsum(y)+sum[u]-getf(x)-getf(y)+1);
	}
	addf(l[u],f[u]);
	addf(r[u]+1,-f[u]);
	addsum(l[u],sum[u]);
	addsum(r[u]+1,-sum[u]);
}
int main()
{
	while(~scanf("%d%d",&n,&m))
	{
		time=0;
		tot=0;  total=0;
		memset(head,0,sizeof(head));
		memset(base,0,sizeof(base));
		memset(cf,0,sizeof(cf));
		memset(cs,0,sizeof(cs));
		for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x);
		fa[1][0]=1;  deep[1]=1;
		dfs1(1);
		for (int j=1;j<=20;j++)
			for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
		for (i=1;i<=m;i++)
		{
			scanf("%d%d",&x,&y);
			push(lca(x,y),x,y);
		}
		dfs2(1);
		printf("%d\n",f[1]);
	}
}

贪心做法:

将链按lca的深度排序,从深度大的开始,能取就取,取完后把该链下的子树全部标记,因此每次判断能否取就是O(1)的.(竟然跑的比上面的方法慢?!)

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=100005;
int base[N],vec[2*N],pre[2*N],tot;

struct chain{
	int x,y,u;
}a[N];
int fa[N][25],deep[N];
bool vis[N];
bool cmp(chain a,chain b) {return deep[a.u]>deep[b.u];}
int n,m,x,y,i;
void link(int x,int y)
{
	vec[++tot]=y;  pre[tot]=base[x];  base[x]=tot;
}
int lca(int x,int y)
{
	if (deep[x]<deep[y]) swap(x,y);
	if (deep[x]>deep[y])
		{
			for (int j=20;j>=0;j--)
				if (deep[fa[x][j]]>deep[y]) x=fa[x][j];
			x=fa[x][0];
		}
	if (x==y) return x;
	for (int j=20;j>=0;j--)
		if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
	return fa[x][0];
}
void dfs(int u)
{
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0])
		{
			fa[v][0]=u;
			deep[v]=deep[u]+1;
			dfs(v);
		}
	}
}
void clean(int u)
{
	if (vis[u]) return;
	vis[u]=1;
	for (int now=base[u];now;now=pre[now])
	{
		int v=vec[now];
		if (v!=fa[u][0]) clean(v);
	}
}
int main()
{
	while(~scanf("%d%d",&n,&m))
	{
		tot=0;
		memset(base,0,sizeof(base));
		memset(vis,0,sizeof(vis));
		for (i=1;i<n;i++) scanf("%d%d",&x,&y),link(x,y),link(y,x);
		fa[1][0]=1;  deep[1]=1;
		dfs(1);
		for (int j=1;j<=20;j++)
			for (i=1;i<=n;i++) fa[i][j]=fa[fa[i][j-1]][j-1];
		for (i=1;i<=m;i++)
		{
			scanf("%d%d",&a[i].x,&a[i].y);
			a[i].u=lca(a[i].x,a[i].y);
		}
		sort(a+1,a+m+1,cmp);
		int ans=0;
		for (i=1;i<=m;i++)
			if (!vis[a[i].x]&&!vis[a[i].y])
			{
				clean(a[i].u);
				ans++;
			}
		printf("%d\n",ans);
	}
}