问题分析与解题思路
问题描述
题目要求我们将一棵由n个结点构成的 Bytedance Tree 划分为K个 Special 连通分块,每个分块具有以下特性Special 连通分块里只有一种礼物(该种类礼物的数量不限)。Special 连通分块可以包含任意数量的未挂上礼物的结点。
最终需要计算满足条件的划分方式的总数,并对结果取模998244353。
输入参数 :nodes:树中结点的数量。decorations:不同礼物的种类数量K。tree:一个包含树的礼物信息和边信息的列表,其中tree[0]是每个结点的礼物编号列表,tree[1:]是树的边列表。
解题思路 1. 提取礼物信息和边信息:gifts数组存储每个结点的礼物编号。edges列表存储树的边。
2. 初始化辅助变量cut_edges数组用于记录哪些边被“剪断”。MOD为取模的值998244353。
3. 判断连通分块是否合法的函数is_valid_block:根据当前的cut_edges状态构建邻接表adj。遍历每个结点,检查其连接的所有结点是否具有相同的礼物编号,或者没有礼物。
4. 深度优先搜索函数dfs:递归尝试剪断边,并更新当前的剪边计数cut_count。当剪边数达到decorations - 1时,检查当前划分是否合法。递归结束条件:剪边数超过decorations - 1或者遍历完所有边。
5. 遍历所有可能的剪边组合:使用dfs函数遍历所有可能的剪边情况,并统计合法的划分方式。
代码详细分析
提取礼物信息和边信息
gifts = tree[0]
edges = tree[1:]
#确保 gifts 数组的长度与节点数一致
if len(gifts) < nodes:
gifts.extend([0] * (nodes - len(gifts)))
gifts数组存储每个结点的礼物编号,如果某个结点没有礼物,则默认编号为0。edges列表存储树的边,每条边是一个(u, v)的二元组,表示结点u和结点v之间有一条边。如果gifts数组的长度小于结点数量nodes,则补0,确保每个结点都有一个礼物编号。
辅助变量初始化
cut_edges = [False] * (nodes - 1)
MOD = 998244353cut_edges数组用于记录树的每条边是否被剪断,初始状态为False。 MOD是取模的值998244353,用于防止结果溢出。
判断连通分块是否合法的函数is_valid_block
def is_valid_block(cut_edges):
adj = [[] for _ in range(nodes + 1)]
for i in range(nodes - 1):
if not cut_edges[i]:
u, v = edges[i]
adj[u].append(v)
adj[v].append(u)
for i in range(nodes):
if gifts[i] != 0:
for node in adj[i + 1]:
if gifts[node - 1] != gifts[i] and gifts[node - 1] != 0:
return False
return True
构建邻接表adj,根据cut_edges的状态决定哪些边被包含在内。遍历每个结点,检查其邻接结点是否具有相同的礼物编号。如果某个邻接结点的礼物编号不同且不为0,则该划分不合法,返回False。如果所有结点都满足条件,返回True。
深度优先搜索函数dfs
def dfs(i, cut_count):
nonlocal count
if cut_count > decorations - 1 or i >= nodes - 1:
return 0
# 尝试剪断当前边
cut_edges[i] = True
if is_valid_block(cut_edges):
count = (count + 1) % MOD
dfs(i + 1, cut_count + 1)
# 恢复当前边
cut_edges[i] = False
dfs(i + 1, cut_count)
使用nonlocal关键字访问外部作用域的count变量。递归尝试剪断当前边,并检查是否合法,如果合法则更新count。递归调用dfs函数,尝试剪断下一条边。恢复当前边的状态,继续递归不剪断当前边的情况。
主逻辑
count = 0
dfs(0, 0)
print(count)
初始化count为0。从第0条边开始,尝试所有可能的剪边组合。最终输出合法的划分方式总数count。
该算法使用深度优先搜索(DFS)遍历所有可能的剪边组合,对于每种组合,使用is_valid_block函数检查是否满足条件。通过递归尝试剪断和恢复边的状态,能够枚举所有可能的划分方式,并统计合法的划分总数。算法的时间复杂度较高,但由于题目数据范围较小,因此能够在合理时间内得出结果。
代码展示:def solution(nodes, decorations, tree):
MOD = 998244353
# 提取礼物信息和边信息
gifts = tree[0]
edges = tree[1:]
# 确保 gifts 数组的长度与节点数一致
if len(gifts) < nodes:
gifts.extend([0] * (nodes - len(gifts)))
# 记录剪边的数组
cut_edges = [False] * (nodes- 1)
# 判断连通分块是否只包含一种礼物的函数
def is_valid_block(cut_edges):
# 初始化邻接表
adj = [[] for _ in range(nodes + 1)]
for i in range(nodes-1):
if cut_edges[i] == False:
u, v = edges[i]
adj_u_set = set(adj[u])
for w in adj[v]:
if w not in adj_u_set:
adj[u].append(w)
adj_u_set.add(w)
adj[u].append(v)
adj_v_set = set(adj[v])
for w in adj[u]:
if w not in adj_v_set:
adj[v].append(w)
adj_v_set.add(w)
adj[v].append(u)
for i in range(nodes):
# print(gifts[i],adj[i+1])
if gifts[i]!=0:
for node in adj[i+1]:
#print(gifts[node-1],gifts[i])
if gifts[node-1]!=gifts[i] and gifts[node-1]!=0:
return False
return True
count=0
# 深度优先搜索函数
def dfs(i, cut_count):
nonlocal count # 使用 nonlocal 关键字访问外部作用域的 count 变量
if cut_count > decorations - 1 or i>=nodes-1:
return 0
#如果剪掉这条边
if cut_count == decorations-1:
#print(i,cut_count)
if is_valid_block(cut_edges):
count+=1
#print (cut_edges,is_valid_block(cut_edges))
return 0
cut_edges[i]=True
dfs(i+1,cut_count+1)
cut_edges[i]=False
dfs(i+1,cut_count)
return 0
# 从根节点开始 DFS
dfs(0,0)
return count % MOD
if name == "main":
# You can add more test cases here
testTree1 = [[1,0,0,0,0,2,3],[1,7],[3,7],[2,1],[3,5],[5,6],[6,4]]
testTree2 = [[1,0,1,0,2],[1,2],[1,5],[2,4],[3,5]]
tree = [[1, 2, 0, 1, 0, 2], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]
print(solution(7, 3, testTree1) )
print(solution(6, 2, tree) )