题意:
一棵树,1为根。每个点只能走向祖先,时间为距离的平方,另外走到祖先后会停留P单位时间。求每个点到达根的最小时间的最大值。
思路:
显然dp方程为:f[i]=(d[i]-d[j])^2+P+f[j] ( j 是 i 的祖先)
转化成向量点积形式:f[i]=d[i]^2+P+(1 , 2*d[i] ) * ( f[j]+d[j]^2 , -d[j] )
注意到深度越大,向量(1 , 2*d[i] )会逆时针旋转
于是维护点集i的所有祖先的 ( f[j]+d[j]^2 , -d[j] ) 的下凸壳,就能求出f[i]的最小值
注意到向量( f[j]+d[j]^2 , -d[j] )只会添加在下凸壳右边,于是只要用个队列来保存就可以了
查询的时候只要一直扔掉左边的点,找到最大值为止,然后退出当前dfs的时候加回来就可以了
如果一个个保存删掉的点,再一个个加回来的话,在菊花图的情况下复杂度就会变成O(n^2),事实上目前网上的题解都是这么做的,估计出题人也没意识到复杂度的问题,因为这看起来很像O(n)。
我的做法是每次dfs记录当前队列的左端点,和右端点,操作完后,事实上只是替换了数组里的tail位置的元素,把head和tail还有这个元素事先保存下来,最后恢复head,tail和这个元素就可以了,这样复杂度就是完美的O(n)了。
也不知道这种做法叫什么,毕竟自己乱想出来的,就叫回溯队列好了233
后来一想这个做法还是n方。。。因为每次head和tail的删除操作都可能是O(n)的,所以要二分查找head和tail分别应该删到哪,这样算法复杂度就是O(nlogn)了
这里如果不用斜率,而把两边的dx都乘到对面去的话是有问题的,因为f[j]+d[j]^2的规模是1e14,2*d[i]的规模是1e7,乘起来理论上会爆,然而数据好像还是没卡掉这种做法
代码1:O(n^2)做法
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=100005; struct P { ll x,y; ll operator*(const P &b)const { return x*b.x+y*b.y; } }stk[N]; int n,p; int head,tail; vector<P>link[N]; ll f[N],d[N]; long double slope(P a,P b) { long double tmp=1.0; return tmp*(b.y-a.y)/(b.x-a.x); } void dfs(int u,int fa,int deep) { int phead=head,ptail=tail; d[u]=deep; if (u==1) f[u]=-p; else { P tmp={1,2*d[u]}; while(tail-head+1>=2&&stk[head]*tmp>=stk[head+1]*tmp) head++; f[u]=d[u]*d[u]+p+stk[head]*tmp; } P tmp={f[u]+d[u]*d[u],-d[u]}; while(tail-head+1>=2&& slope(stk[tail-1],stk[tail])>=slope(stk[tail],tmp)) tail--; P mem=stk[tail+1]; stk[++tail]=tmp; for (int i=0;i<link[u].size();i++) { int v=link[u][i].x; if (v==fa) continue; dfs(v,u,deep+link[u][i].y); } tail--; stk[tail+1]=mem; head=phead; tail=ptail; } int main() { int t; scanf("%d",&t); while(t--) { scanf("%d%d",&n,&p); for (int i=1;i<n;i++) { ll x,y,z; scanf("%d%d%lld",&x,&y,&z); if (z<0) n/=0; link[x].push_back({y,z}); } head=1; tail=0; dfs(1,1,0); ll ans=0; for (int i=1;i<=n;i++) ans=max(ans,f[i]); printf("%lld\n",ans); for (int i=1;i<=n;i++) link[i].clear(); } }
代码2:O(nlogn)
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int N=100005; struct P { ll x,y; ll operator*(const P &b)const { return x*b.x+y*b.y; } }stk[N]; int n,p; int head,tail; vector<P>link[N]; ll f[N],d[N]; long double slope(P a,P b) { long double tmp=1.0; return tmp*(b.y-a.y)/(b.x-a.x); } int gethead(P v) { int l=head,r=tail; while(l<r) { int mid=(l+r)/2; if (stk[mid]*v>=stk[mid+1]*v) l=mid+1; else r=mid; } return l; } int gettail(P v) { int l=head,r=tail; while(l<r) { int mid=(l+r)/2; if (slope(stk[mid],stk[mid+1])<slope(stk[mid+1],v)) l=mid+1; else r=mid; } return l; } void dfs(int u,int fa,int deep) { int phead=head,ptail=tail; d[u]=deep; if (u==1) f[u]=-p; else { P tmp={1,2*d[u]}; head=gethead(tmp); //while(tail-head+1>=2&&stk[head]*tmp>=stk[head+1]*tmp) head++; f[u]=d[u]*d[u]+p+stk[head]*tmp; } P tmp={f[u]+d[u]*d[u],-d[u]}; tail=gettail(tmp); // while(tail-head+1>=2&& // slope(stk[tail-1],stk[tail])>=slope(stk[tail],tmp)) tail--; P mem=stk[tail+1]; stk[++tail]=tmp; for (int i=0;i<link[u].size();i++) { int v=link[u][i].x; if (v==fa) continue; dfs(v,u,deep+link[u][i].y); } tail--; stk[tail+1]=mem; head=phead; tail=ptail; } int main() { int t; scanf("%d",&t); while(t--) { scanf("%d%d",&n,&p); for (int i=1;i<n;i++) { ll x,y,z; scanf("%d%d%lld",&x,&y,&z); link[x].push_back({y,z}); } head=1; tail=0; dfs(1,1,0); ll ans=0; for (int i=1;i<=n;i++) ans=max(ans,f[i]); printf("%lld\n",ans); for (int i=1;i<=n;i++) link[i].clear(); } }