Codeforces 1059E Split the Tree

160 阅读3分钟

#Codeforces 1059E Split the Tree

传送门:codeforces.com/contest/105…

大概题意:给你一个树,每个节点都有权值,再给你S和L,让你把树切成多条树枝的组合,要求每条树枝权值和小于S,每条树枝结点数小于L,并且树枝中前后节点的关系一定是前结点是后结点的父亲,求最少的树枝分组

根据题意,树枝不能分叉。

我们首先可以贪心地想这个问题:如果每个作为起点的结点都能获得最长树枝,那么结果一定是最优的。

而题干中给出两个限制条件:S和L,所以我们就根据S和L判断以某一个点作为起点向祖先结点生成树枝时,能生成的最长树枝的另一端结点的编号。

我们可以利用倍增(ST),在O(nlogn)时间内获得这个编号

所以每次我们只需要找到没有子节点的点,生成最长树枝,并更新其他点的子节点。循环这个过程直到所有点都被遍历

#include<bits/stdc++.h>
#define pb push_back
#define MID(a,b) ((a+b)>>1)
#define LL(a) (a<<1)
#define RR(a) (a<<1|1)
using namespace std;
typedef long long ll;
typedef pair<int, int>pii;
typedef pair<double, int> pdi;
typedef pair<ll,int>pli;
typedef pair<ll,ll>pll;
typedef pair<string,int>psi;
const int N = 1e5+5,M=100005;
const int inf=0x3f3f3f3f;
const ll INF=1000000000000000000ll;
const ll mod = 998244353;
const double pi=acos(-1.0);
const double eps=1e-6;
int n, L, fa[N], out[N];
vector<int> e[N];
ll w[N], st[N][21], S;
int stid[N][21];
bool vis[N], used[N];
void init(){
	for(int i=1;i<=n;i++) e[i].clear();
	memset(out, 0, sizeof(out));
	memset(vis, false, sizeof(vis));
	memset(used,false, sizeof(used));
}
// st[id][i]指以id为起始点(不包括i)的2^i各点的权值和
// stid[id][i]指以id为起始点(不包括i)的第2^i个点的编号
void getst(int id){
	int siz = e[id].size();
	if(id != 1)
		st[id][0] = w[fa[id]]; 
	else st[id][0] = 0;
	stid[id][0]=fa[id];
	for(int i=1; i<=20; i++){
		stid[id][i] = stid[stid[id][i-1]][i-1];
		if(stid[id][i] != stid[id][i-1])
			st[id][i] = st[id][i-1] + st[stid[id][i-1]][i-1];
		else st[id][i] = st[id][i-1];
		
	}
	for(int i=0;i<siz;i++){
		getst(e[id][i]);
	}
}
int upfind(int id, int len, ll val){
	if(id == 1){
		return 1;
	}
	if(val == S)return id;
	for(int i=0; i<=20; i++){
		if((1<<i)+len == L) { 
			if(st[id][i]+val > S){
				return id;
			}
			return stid[id][i];
		}
		if((1<<i)+len > L) {
			return upfind(stid[id][i-1], (1<<(i-1))+len, val+st[id][i-1]);
		}
		if(val + st[id][i] == S) {
			return stid[id][i];
		}
		if(i != 20 && val+st[id][i] < S && val+st[id][i+1] > S){
			return upfind(stid[id][i], (1<<i) + len, val + st[id][i]);
		}
		if(i == 20 && val+st[id][i] < S){
			return 1;
		} 
	}
}

int solve(){
	int ans = 0;	
	while(1){
		int cnt = 0;
		for(int i=1;i<=n;i++){
			if(!vis[i] && !out[i]){
				cnt++;
				vis[i] = true;
				int id = upfind(i, 1, w[i]);
				vis[id] = true;
				if(fa[id] != id && !used[id]){
					out[fa[id]]--;
					used[id] = true;
				} 
				ans++;
			}
		}
		if(!cnt){
			break;
		}
	}
	return ans;
}

int main(){
	while(~scanf("%d%d%lld", &n, &L, &S)){
		init();
		bool ju = false;
		for(int i=1;i<=n;i++){
			scanf("%lld", &w[i]);
			if(w[i] > S){
				ju = true;
			}
		}
		int x; 
		for(int i=2;i<=n;i++){
			scanf("%d", &x);
			e[x].pb(i);
			out[x] ++;
			fa[i] = x;
		}
		if(ju){
			printf("-1\n");
			continue;
		}
		if(L == 1){
			printf("%d\n", n);
			continue;
		}
		fa[1] = 1;
		getst(1);
		printf("%d\n", solve());
	}
}