题意:
一棵树,1为根。每个点只能走向祖先,时间为距离的平方,另外走到祖先后会停留P单位时间。求每个点到达根的最小时间的最大值。
思路:
显然dp方程为:f[i]=(d[i]-d[j])^2+P+f[j] ( j 是 i 的祖先)
转化成向量点积形式:f[i]=d[i]^2+P+(1 , 2*d[i] ) * ( f[j]+d[j]^2 , -d[j] )
注意到深度越大,向量(1 , 2*d[i] )会逆时针旋转
于是维护点集i的所有祖先的 ( f[j]+d[j]^2 , -d[j] ) 的下凸壳,就能求出f[i]的最小值
注意到向量( f[j]+d[j]^2 , -d[j] )只会添加在下凸壳右边,于是只要用个队列来保存就可以了
查询的时候只要一直扔掉左边的点,找到最大值为止,然后退出当前dfs的时候加回来就可以了
如果一个个保存删掉的点,再一个个加回来的话,在菊花图的情况下复杂度就会变成O(n^2),事实上目前网上的题解都是这么做的,估计出题人也没意识到复杂度的问题,因为这看起来很像O(n)。
我的做法是每次dfs记录当前队列的左端点,和右端点,操作完后,事实上只是替换了数组里的tail位置的元素,把head和tail还有这个元素事先保存下来,最后恢复head,tail和这个元素就可以了,这样复杂度就是完美的O(n)了。
也不知道这种做法叫什么,毕竟自己乱想出来的,就叫回溯队列好了233
后来一想这个做法还是n方。。。因为每次head和tail的删除操作都可能是O(n)的,所以要二分查找head和tail分别应该删到哪,这样算法复杂度就是O(nlogn)了
这里如果不用斜率,而把两边的dx都乘到对面去的话是有问题的,因为f[j]+d[j]^2的规模是1e14,2*d[i]的规模是1e7,乘起来理论上会爆,然而数据好像还是没卡掉这种做法
代码1:O(n^2)做法
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
struct P
{
ll x,y;
ll operator*(const P &b)const
{
return x*b.x+y*b.y;
}
}stk[N];
int n,p;
int head,tail;
vector<P>link[N];
ll f[N],d[N];
long double slope(P a,P b)
{
long double tmp=1.0;
return tmp*(b.y-a.y)/(b.x-a.x);
}
void dfs(int u,int fa,int deep)
{
int phead=head,ptail=tail;
d[u]=deep;
if (u==1) f[u]=-p;
else
{
P tmp={1,2*d[u]};
while(tail-head+1>=2&&stk[head]*tmp>=stk[head+1]*tmp) head++;
f[u]=d[u]*d[u]+p+stk[head]*tmp;
}
P tmp={f[u]+d[u]*d[u],-d[u]};
while(tail-head+1>=2&&
slope(stk[tail-1],stk[tail])>=slope(stk[tail],tmp)) tail--;
P mem=stk[tail+1];
stk[++tail]=tmp;
for (int i=0;i<link[u].size();i++)
{
int v=link[u][i].x;
if (v==fa) continue;
dfs(v,u,deep+link[u][i].y);
}
tail--;
stk[tail+1]=mem;
head=phead; tail=ptail;
}
int main()
{
int t; scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&p);
for (int i=1;i<n;i++)
{
ll x,y,z;
scanf("%d%d%lld",&x,&y,&z);
if (z<0) n/=0;
link[x].push_back({y,z});
}
head=1; tail=0;
dfs(1,1,0);
ll ans=0;
for (int i=1;i<=n;i++) ans=max(ans,f[i]);
printf("%lld\n",ans);
for (int i=1;i<=n;i++) link[i].clear();
}
}
代码2:O(nlogn)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
struct P
{
ll x,y;
ll operator*(const P &b)const
{
return x*b.x+y*b.y;
}
}stk[N];
int n,p;
int head,tail;
vector<P>link[N];
ll f[N],d[N];
long double slope(P a,P b)
{
long double tmp=1.0;
return tmp*(b.y-a.y)/(b.x-a.x);
}
int gethead(P v)
{
int l=head,r=tail;
while(l<r)
{
int mid=(l+r)/2;
if (stk[mid]*v>=stk[mid+1]*v) l=mid+1;
else r=mid;
}
return l;
}
int gettail(P v)
{
int l=head,r=tail;
while(l<r)
{
int mid=(l+r)/2;
if (slope(stk[mid],stk[mid+1])<slope(stk[mid+1],v)) l=mid+1;
else r=mid;
}
return l;
}
void dfs(int u,int fa,int deep)
{
int phead=head,ptail=tail;
d[u]=deep;
if (u==1) f[u]=-p;
else
{
P tmp={1,2*d[u]};
head=gethead(tmp);
//while(tail-head+1>=2&&stk[head]*tmp>=stk[head+1]*tmp) head++;
f[u]=d[u]*d[u]+p+stk[head]*tmp;
}
P tmp={f[u]+d[u]*d[u],-d[u]};
tail=gettail(tmp);
// while(tail-head+1>=2&&
// slope(stk[tail-1],stk[tail])>=slope(stk[tail],tmp)) tail--;
P mem=stk[tail+1];
stk[++tail]=tmp;
for (int i=0;i<link[u].size();i++)
{
int v=link[u][i].x;
if (v==fa) continue;
dfs(v,u,deep+link[u][i].y);
}
tail--;
stk[tail+1]=mem;
head=phead; tail=ptail;
}
int main()
{
int t; scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&p);
for (int i=1;i<n;i++)
{
ll x,y,z;
scanf("%d%d%lld",&x,&y,&z);
link[x].push_back({y,z});
}
head=1; tail=0;
dfs(1,1,0);
ll ans=0;
for (int i=1;i<=n;i++) ans=max(ans,f[i]);
printf("%lld\n",ans);
for (int i=1;i<=n;i++) link[i].clear();
}
}