Bytedance Tree 问题 | 豆包MarsCode AI刷题

274 阅读4分钟

问题描述

众所周知,每两周的周三是字节跳动的活动日。作为活动组织者的小 A,在这次活动日上布置了一棵 Bytedance Tree。

Bytedance Tree 由 n 个结点构成,每个结点的编号分别为 1,2,3......n,有 n - 1 条边将它们连接起来,根结点为 1。而且为了观赏性,小 A 给 M 个结点挂上了 K 种礼物(0 ≤ K ≤ M ≤ N, 且保证一个结点只有一个礼物)。

现在小 A 希望将 Bytedance Tree 划分为 K 个 Special 连通分块,送给参加活动日活动的同学们,请问热心肠的你能帮帮他,告诉小 A 一共有多少种划分方式吗?

一个 Special 连通分块应该具有以下特性:

  • Special 连通分块里只有一种礼物(该种类礼物的数量不限)
  • Special 连通分块可以包含任意数量的未挂上礼物的结点

由于方案数可能过大,对结果取模 998244353

输入格式

第一行输入两个整数 n 和 k,分别表示 n 个结点和 k 种装饰品。

接下来一行,输入 n 个整数 a0, a1,...an,表示第 i 个结点挂着的礼物种类为 ai(0 表示没有挂礼物)

接下来 n - 1 行,输入两个整数 u 和 v,分别表示结点 u 和结点 v 之间有一条边。

输出格式

一行

输出方案数即可

数据范围

2 ≤ n ≤ 1000000(1e6)

2 ≤ k ≤ n

样例 1

INPUT

7 3

1 0 0 0 0 2 3

1 7

3 7

2 1

3 5

5 6

6 4

OUTPUT

3

样例 2

INPUT

5 2

1 0 1 0 2

1 2

1 5

2 4

3 5

OUTPUT

0

示例1

      1
     / \
    2   7
        |
        3
        |
        5
        |
        6
        |
        4
  • 划分方式 1:

    • 类型1: 节点1, 节点2
    • 类型2: 节点6, 节点4
    • 类型3: 节点7, 节点3, 节点5
  • 划分方式 2:

    • 类型1: 节点1, 节点2
    • 类型2: 节点6, 节点4, 节点5
    • 类型3: 节点7, 节点3
  • 划分方式 3:

    • 类型1: 节点1, 节点2
    • 类型2: 节点6, 节点4, 节点5, 节点3
    • 类型3: 节点7

思路

  • 解析输入:

    • 读取礼物分布列表,确定哪些结点挂有礼物,并记录每种礼物对应的结点。
    • 构建树的邻接表表示。
  • 选择根结点:

    • 从礼物结点中任选一个作为根结点。
  • 广度优先搜索(BFS):

    • 从根结点出发,使用BFS遍历整个树,并记录每个结点的父结点。
  • 计算切分方式:

    • 对于每个非根的礼物结点,沿着路径向上遍历,直到遇到另一个礼物结点(根结点)。
    • 记录路径上可切分的位置数,并将这些数相乘,得到最终的方案数。
  • 处理特殊情况:

    • 如果某种礼物类型对应多个结点,当前实现假设每种类型只有一个结点。如果存在多结点对应同一类型的情况,程序将返回0。

代码

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

public class Main {
    public static int solution(int nodes, int decorations, List<List<Integer>> tree) {
        // 模拟输入的第一行是礼物分布
        List<Integer> decorationsList = tree.get(0);
        
        // 找到所有挂有礼物的结点
        List<Integer> decorationNodes = new ArrayList<>();
        for(int i = 0; i < decorationsList.size(); i++) {
            if(decorationsList.get(i) != 0){
                decorationNodes.add(i+1); // 1-based indexing
            }
        }
        
        // 检查礼物的数量是否与K一致
        if(decorationNodes.size() != decorations){
            return 0;
        }
        
        // 检查每种礼物类型是否只对应一个结点
        int[] typeCount = new int[decorations +1];
        for(int i = 0; i < decorationsList.size(); i++){
            if(decorationsList.get(i) != 0){
                typeCount[decorationsList.get(i)]++;
            }
        }
        for(int t = 1; t <= decorations; t++){
            if(typeCount[t] !=1){
                return 0;
            }
        }
        
        // 选择第一个礼物结点作为根结点
        int root = decorationNodes.get(0);
        
        // 构建树的邻接表
        List<List<Integer>> adj = new ArrayList<>(nodes +1);
        for(int i = 0; i <= nodes; i++){
            adj.add(new ArrayList<>());
        }
        for(int i =1; i < tree.size(); i++){
            List<Integer> edge = tree.get(i);
            int u = edge.get(0);
            int v = edge.get(1);
            adj.get(u).add(v);
            adj.get(v).add(u);
        }
        
        // BFS 来确定每个结点的父结点
        int[] parent = new int[nodes +1];
        boolean[] isDecoration = new boolean[nodes +1];
        for(int node : decorationNodes){
            isDecoration[node] = true;
        }
        parent[root] = -1;
        Queue<Integer> queue = new LinkedList<>();
        queue.add(root);
        while(!queue.isEmpty()){
            int u = queue.poll();
            for(int v : adj.get(u)){
                if(parent[v] == 0 && v != root){
                    parent[v] = u;
                    queue.add(v);
                }
            }
        }
        
        // 计算方案数
        long mod = 998244353;
        long total =1;
        for(int i =1; i < decorationNodes.size(); i++){
            int node = decorationNodes.get(i);
            int count =0;
            int current = node;
            while(parent[current] != -1){
                int p = parent[current];
                count++;
                if(isDecoration[p]){
                    break;
                }
                current = p;
            }
            total = (total * count) % mod;
        }
        
        return (int) total;
    }

    public static void main(String[] args) {
    
        List<List<Integer>> testTree1 = new ArrayList<>();
        testTree1.add(List.of(1, 0, 0, 0, 0, 2, 3)); 
        testTree1.add(List.of(1, 7));
        testTree1.add(List.of(3, 7));
        testTree1.add(List.of(2, 1));
        testTree1.add(List.of(3, 5));
        testTree1.add(List.of(5, 6));
        testTree1.add(List.of(6, 4));

        List<List<Integer>> testTree2 = new ArrayList<>();
        testTree2.add(List.of(1, 0, 1, 0, 2)); 
        testTree2.add(List.of(1, 2));
        testTree2.add(List.of(1, 5));
        testTree2.add(List.of(2, 4));
        testTree2.add(List.of(3, 5));


        // 运行并验证结果
        System.out.println(solution(7, 3, testTree1) == 3); 
        System.out.println(solution(5, 2, testTree2) == 0); 
    }
}