问题描述
给定一个仅由小写字母组成的字符串 S,需要计算其中有多少个子序列的首尾字符相同。子序列的定义是从原字符串中按原顺序选取若干字符(可以不连续)组成的新字符串。最终结果需对 取模。
示例:
-
输入:
"arcaea"
输出:28 -
输入:
"abcabc"
输出:18 -
输入:
"aaaaa"
输出:31
解题思路
1. 子序列的定义与性质
- 子序列是从原字符串中按顺序选取若干字符组成的新字符串,字符可以不连续。
- 要求子序列的首尾字符相同,即子序列的第一个字符和最后一个字符相同。
2. 问题转化
我们需要统计所有满足以下条件的子序列数量:
- 子序列长度至少为1(单个字符本身也是一个子序列)。
- 子序列的首字符和尾字符相同。
3. 动态规划与字符统计
利用动态规划、结合字符统计来高效计算。
主要思路:
-
统计每个字符的位置:首先遍历字符串,记录每个字符出现的位置。
-
计算以某字符为首尾的子序列数量:
对于每个字符
c,假设它在字符串中出现了k次,位置分别为p0, p1, ..., pk-1。对于任意一对位置(pi, pj),其中pi <= pj,则以c为首尾的子序列可以通过在pi和pj之间任意选取字符来构成。- 当
i == j时,对应的子序列只有一个,即单个字符c。 - 当
i < j时,pi和pj之间有pj - pi - 1个字符,每个字符可以选择是否包含,因此有2^(pj - pi - 1)种子序列。
- 当
-
优化计算:
为了避免重复计算和提高效率,可以预先计算所有可能的
2^k mod 998244353,并利用前缀和或其他技巧加速计算。
4. 实现步骤
-
预计算幂次:
计算并存储
pow2[i] = 2^i mod 998244353,用于快速查询2^(pj - pi -1)。 -
统计字符位置:
使用哈希表(如
defaultdict)记录每个字符出现的所有位置。 -
遍历每个字符:
对于每个字符
c,遍历其所有出现的位置,计算以c为首尾的子序列数量,并累加到总数中。 -
结果取模:
最终结果对
998244353取模输出。
代码实现
MOD = 998244353
def solution(s: str) -> int:
n = len(s)
pow2 = [1] * (n + 1)
for i in range(1, n + 1):
pow2[i] = pow2[i-1] * 2 % MOD
from collections import defaultdict
char_positions = defaultdict(list)
for idx, c in enumerate(s):
char_positions[c].append(idx)
total = 0
for c, positions in char_positions.items():
k = len(positions)
# 对于每个字符c,计算以c为首尾的子序列数量
# 单个字符的子序列数量为k
# 两个及以上字符的子序列数量为 Σ(2^(pj - pi -1)),其中 pj > pi
# 总计为k + Σ_{i < j} 2^(pj - pi -1)
total_c = k # 单个字符的子序列
for i in range(k):
for j in range(i+1, k):
total_c = (total_c + pow2[positions[j] - positions[i] -1]) % MOD
total = (total + total_c) % MOD
return total
if __name__ == '__main__':
print(solution("arcaea") == 28)
print(solution("abcabc") == 18)
print(solution("aaaaa") == 31)
代码解释
-
预计算幂次:
pow2 = [1] * (n + 1) for i in range(1, n + 1): pow2[i] = pow2[i-1] * 2 % MOD- 计算
2^i mod 998244353,存储在pow2数组中,方便后续快速查询。
- 计算
-
统计字符位置:
from collections import defaultdict char_positions = defaultdict(list) for idx, c in enumerate(s): char_positions[c].append(idx)- 使用
defaultdict记录每个字符在字符串中的所有出现位置。
- 使用
-
计算每个字符的贡献:
for c, positions in char_positions.items(): k = len(positions) total_c = k # 单个字符的子序列 for i in range(k): for j in range(i+1, k): total_c = (total_c + pow2[positions[j] - positions[i] -1]) % MOD total = (total + total_c) % MOD- 对于每个字符
c,首先将其作为单个字符的子序列数量k加入total_c。 - 然后,对于每一对不同的位置
(i, j),计算2^(pj - pi -1)并累加到total_c中。 - 最后,将每个字符
c的贡献累加到总结果total中。
- 对于每个字符
复杂度分析
-
时间复杂度:
- 统计字符位置:
- 对于每个字符,遍历其所有出现位置的两两组合:假设每个字符最多出现 次,则总时间复杂度为 。在最坏情况下,当所有字符相同时,时间复杂度为 。
-
空间复杂度:
- 主要用于存储
pow2数组和字符位置,空间复杂度为 。
- 主要用于存储
总结
通过预先计算幂次和统计每个字符的位置,我们能够高效地计算出所有以相同字符为首尾的子序列数量。尽管在最坏情况下时间复杂度为 ,但对于大多数实际情况而言,效率仍然可以接受。