本文已参与「新人创作礼」活动,一起开启掘金创作之路。
样例输入:
2
3
2 1 2
9 9 10
1 2 10
1 3 11
3
1 1 2
9 9 10
1 2 10
1 3 11
样例输出:
3
2
题意:给定一棵有n个节点的树,树上每个节点有一个k和w,w是这个节点的点权,k是修改点权的代价系数,然后给出n-1条边,每条边有一个边权,我们要对点权进行修改,第x个点的点权是w[x],那么将其点权修改为y的代价是c[x]*|y-w[x]|,最后要保证对于每条边,都要满足该边权小于等于该边两端点的点权的最大值,以及要满足大于等于该边两端点的点权的最小值,问我们最少的代价是多少?
分析:这一道题我做的挺坎坷的,接下来我会沿着我的思考历程一点点讲解一下:
设f[i][0]代表处理完以i为根的子树并且当前节点的weight小于等于其父边权值的最小代价,f[i][1]代表处理完以i为根的子树并且当前节点的weight大于等于其父边权值的最小代价
知道了状态定义那么就容易想状态转移方程了,对于以x为根的子树,我们假设现在x的所有子节点j已经处理出来了f[j][0]和f[j][1],那么我们就用f[j][0/1]更新f[x][0/1],我们是来枚举x点的最优点权,不妨假设是y,那么更新过程就是假如d[x]j是小于y的,那么只需要加上以j为根的子树且满足j的点权小于等于父边权的最小代价即可,同理当d[x][j]是大于y的,只需要加上以j为根的子树且满足j的点权大于等于父边权的最小代价即可,而当d[x][j]等于y时我们需要加上二者的较小值,因为等于情况下,孩子节点的取值就任意了。这个应该不难理解,可是我们怎么求解x节点的点权呢?一个一个枚举吗?这肯定是不行的,这个时候我发现了一个问题,类似于货舱选址。就是说假如x有t个孩子节点,x节点与t个孩子节点连边的边权都直接看成数轴上的一个数,那么这t个数直接将数轴分成了t+1段区间,类似下图,其中v[x]代表x节点的点权位置,Y代表最优解所对应的点权
我们发现,x的值如果在某一个区间内移动,那么是不会影响子节点的值的选取的,因为他并不会突然超过某个边权使得子节点从原来取值为大于等于父边权的代价变为小于等于父边权的代价,也不会突然小于某个边权使得子节点从原来取值为小于等于父边权的代价变为大于等于父边权的代价。而唯一会影响的就是根节点x变为该权值的代价,假如v[x]是在我们所选取区间的右边,那么显然我们选取的这个区间内的值应该尽可能靠右,那么也就是区间右端点,因为这样权值变换代价最小,同理,如果v[x]是在我们所选取区间的左边,那么显然我们选取的这个区间内的值应该尽可能靠左,也就是区间左端点,所以这样我们就显然能够得到一个结论就是最优解一定是某个边权值,这是通过调整法得到的,我就这样写完发现了第一个问题,就是说,如果答案正好在v[x]所在的区间内,那么答案就是v[x],所以我们还需要单独考虑了一下v[x]不变的情况。然后我就开始枚举每一个边权值,然后扫一遍每一个孩子节点记录一下结果,发现这样会超时,因为这样极端情况下复杂度可能达到O(n^2),然后就开始想优化,优化倒不难想,就是对子节点按照父边权进行从小到大排序,然后记录一下前缀和presumf[i][0],和后缀sufsum[i][1],这样的好处就是我们在枚举第k大的边权的时候,我们就直接算一下第k大边权对应的孩子节点大于等于父边权和小于等于父边权的两种情况取一下最小值,然后加上presum[k-1]和sufsum[k+1][1],这表示前k-1个孩子结点都是取小于等于父边权的情况,而k+1之后的孩子节点都是取大于等于父边权的情况,但是后来我发现这样做也会存在一个问题,就是说如果有边权相同的时候可能会导致问题,就假如第k大边权和第k+1大边权值是相同的,那么我们在枚举第k大边权时就会默认把第k+1大边权取成f[k+1][1],但有可能是取f[k+1][0]的,所以这个问题我们还需要解决,我最后想的解决方法是把边权相同的边分到一块里面,里面不仅记录f[][0]的和以及f[][1]的和,还需要记录一个f[][0]和f[][1]的最小值的和,因为我们遍历到改边权时,如果子节点和该权值相同,那么子节点j取f[j][0]和f[j][1]的最小值即可,在代码中f1记录的每一块中f[][0]的加和,f2记录每一块中f[][0]和f[][1]较小值的加和,f3记录每一块中f[][1]的加和,然后分别枚举x点权为每一块点权即可,最后不要忘记枚举x点权不变的情况。
细节见代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<map>
#include<queue>
#include<vector>
#include<cmath>
using namespace std;
const int N=5e5+10;
int h[N],ne[N],e[N],w[N],idx;
long long v[N],k[N];
long long f[N][2];
//f[i][0]代表处理完以i为根的子树并且当前节点的weight小于等于其父边的最小代价
//f[i][1]代表处理完以i为根的子树并且当前节点的weight大于等于其父边的最小代价
long long presum[N],sufsum[N];
void add(int x,int y,int z)
{
e[idx]=y;
w[idx]=z;
ne[idx]=h[x];
h[x]=idx++;
}
struct node{
int id;//节点编号
int w;//与父亲节点之间的边的边权
};
struct block{
int w;//记录块内边权
long long f1;//块内f[][0]之和
long long f2;//块内f[][0/1]的最小值之和
long long f3;//块内f[][1]之和
};
bool cmp(node a,node b)
{
return a.w<b.w;
}
long long cal(int id,int x)//将第id个点权变为x的代价
{
if(x<v[id]) return 1ll*k[id]*(v[id]-x);
else return 1ll*k[id]*(x-v[id]);
}
void dfs(int x,int fa)
{
f[x][0]=f[x][1]=0x3f3f3f3f3f3f3f3f;
vector<node>son;//存储与儿子连边的信息
int fw=0;//记录x与其父亲节点之间的边权
for(int i=h[x];i!=-1;i=ne[i])
{
int j=e[i];
if(j==fa)
{
fw=w[i];
continue;
}
dfs(j,x);
son.push_back(node{j,w[i]});
}
if(!son.size())//没有孩子节点需要特判
{
if(v[x]<=fw)
{
f[x][0]=0;
f[x][1]=cal(x,fw);//将x点坐标变为父边权
}
else
{
f[x][0]=cal(x,fw);//将x点坐标变为父边权
f[x][1]=0;
}
return ;
}
//枚举父亲节点边权
sort(son.begin(),son.end(),cmp);
vector<block> tson;
for(int i=0;i<son.size();i++)
{
long long f1=f[son[i].id][0],f2=min(f[son[i].id][0],f[son[i].id][1]),f3=f[son[i].id][1];
while(i+1<son.size()&&son[i].w==son[i+1].w)//满足条件就将第i个节点和第i+1个节点合并
{
f1+=f[son[i+1].id][0];
f2+=min(f[son[i+1].id][0],f[son[i+1].id][1]);
f3+=f[son[i+1].id][1];
i++;
}
tson.push_back({son[i].w,f1,f2,f3});
}
presum[0]=tson[0].f1;
for(int i=1;i<tson.size();i++)
presum[i]=presum[i-1]+tson[i].f1;
sufsum[tson.size()-1]=tson[tson.size()-1].f3;
for(int i=tson.size()-2;i>=0;i--)
sufsum[i]=sufsum[i+1]+tson[i].f3;
for(int i=0;i<tson.size();i++)
{
long long val=tson[i].f2;
if(i) val+=presum[i-1];
if(i+1<tson.size()) val+=sufsum[i+1];
val+=cal(x,tson[i].w);//将x点权变为tson[i].w的代价
if(tson[i].w<=fw) f[x][0]=min(f[x][0],val);
if(tson[i].w>=fw) f[x][1]=min(f[x][1],val);
}
//考虑x点权值不变的情况
long long val=0;
for(int i=0;i<tson.size();i++)
{
if(tson[i].w<v[x]) val+=tson[i].f1;
else if(tson[i].w==v[x]) val+=tson[i].f2;
else val+=tson[i].f3;
}
if(v[x]<=fw) f[x][0]=min(f[x][0],val);
if(v[x]>=fw) f[x][1]=min(f[x][1],val);
//考虑将x节点值变为fw
val=cal(x,fw);
for(int i=0;i<tson.size();i++)
{
if(tson[i].w<fw) val+=tson[i].f1;
else if(tson[i].w==fw) val+=tson[i].f2;
else val+=tson[i].f3;
}
f[x][1]=min(f[x][1],val);
f[x][0]=min(f[x][0],val);
}
int main()
{
int T;
cin>>T;
while(T--)
{
int n;
scanf("%d",&n);
idx=0;
for(int i=1;i<=n;i++)
{
h[i]=-1;
scanf("%lld",&k[i]);
}
for(int i=1;i<=n;i++)
scanf("%lld",&v[i]);
for(int i=1;i<n;i++)
{
int u,v,z;
scanf("%d%d%d",&u,&v,&z);
add(u,v,z);add(v,u,z);
}
dfs(1,-1);
printf("%lld\n",min(f[1][0],f[1][1]));
}
return 0;
}