#Codeforces 1059E Split the Tree
传送门:codeforces.com/contest/105…
大概题意:给你一个树,每个节点都有权值,再给你S和L,让你把树切成多条树枝的组合,要求每条树枝权值和小于S,每条树枝结点数小于L,并且树枝中前后节点的关系一定是前结点是后结点的父亲,求最少的树枝分组
根据题意,树枝不能分叉。
我们首先可以贪心地想这个问题:如果每个作为起点的结点都能获得最长树枝,那么结果一定是最优的。
而题干中给出两个限制条件:S和L,所以我们就根据S和L判断以某一个点作为起点向祖先结点生成树枝时,能生成的最长树枝的另一端结点的编号。
我们可以利用倍增(ST),在O(nlogn)时间内获得这个编号
所以每次我们只需要找到没有子节点的点,生成最长树枝,并更新其他点的子节点。循环这个过程直到所有点都被遍历
#include<bits/stdc++.h>
#define pb push_back
#define MID(a,b) ((a+b)>>1)
#define LL(a) (a<<1)
#define RR(a) (a<<1|1)
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
typedef pair<double, int> pdi;
typedef pair<ll,int>pli;
typedef pair<ll,ll>pll;
typedef pair<string,int>psi;
const int N = 1e5+5,M=100005;
const int inf=0x3f3f3f3f;
const ll INF=1000000000000000000ll;
const ll mod = 998244353;
const double pi=acos(-1.0);
const double eps=1e-6;
int n, L, fa[N], out[N];
vector<int> e[N];
ll w[N], st[N][21], S;
int stid[N][21];
bool vis[N], used[N];
void init(){
for(int i=1;i<=n;i++) e[i].clear();
memset(out, 0, sizeof(out));
memset(vis, false, sizeof(vis));
memset(used,false, sizeof(used));
}
// st[id][i]指以id为起始点(不包括i)的2^i各点的权值和
// stid[id][i]指以id为起始点(不包括i)的第2^i个点的编号
void getst(int id){
int siz = e[id].size();
if(id != 1)
st[id][0] = w[fa[id]];
else st[id][0] = 0;
stid[id][0]=fa[id];
for(int i=1; i<=20; i++){
stid[id][i] = stid[stid[id][i-1]][i-1];
if(stid[id][i] != stid[id][i-1])
st[id][i] = st[id][i-1] + st[stid[id][i-1]][i-1];
else st[id][i] = st[id][i-1];
}
for(int i=0;i<siz;i++){
getst(e[id][i]);
}
}
int upfind(int id, int len, ll val){
if(id == 1){
return 1;
}
if(val == S)return id;
for(int i=0; i<=20; i++){
if((1<<i)+len == L) {
if(st[id][i]+val > S){
return id;
}
return stid[id][i];
}
if((1<<i)+len > L) {
return upfind(stid[id][i-1], (1<<(i-1))+len, val+st[id][i-1]);
}
if(val + st[id][i] == S) {
return stid[id][i];
}
if(i != 20 && val+st[id][i] < S && val+st[id][i+1] > S){
return upfind(stid[id][i], (1<<i) + len, val + st[id][i]);
}
if(i == 20 && val+st[id][i] < S){
return 1;
}
}
}
int solve(){
int ans = 0;
while(1){
int cnt = 0;
for(int i=1;i<=n;i++){
if(!vis[i] && !out[i]){
cnt++;
vis[i] = true;
int id = upfind(i, 1, w[i]);
vis[id] = true;
if(fa[id] != id && !used[id]){
out[fa[id]]--;
used[id] = true;
}
ans++;
}
}
if(!cnt){
break;
}
}
return ans;
}
int main(){
while(~scanf("%d%d%lld", &n, &L, &S)){
init();
bool ju = false;
for(int i=1;i<=n;i++){
scanf("%lld", &w[i]);
if(w[i] > S){
ju = true;
}
}
int x;
for(int i=2;i<=n;i++){
scanf("%d", &x);
e[x].pb(i);
out[x] ++;
fa[i] = x;
}
if(ju){
printf("-1\n");
continue;
}
if(L == 1){
printf("%d\n", n);
continue;
}
fa[1] = 1;
getst(1);
printf("%d\n", solve());
}
}