刷题笔记-统计字符串中首尾字符相同的子序列数量 | 豆包MarsCode AI刷题

137 阅读4分钟

问题描述

www.marscode.cn/practice/vk…

给定一个仅由小写字母组成的字符串 S,需要计算其中有多少个子序列的首尾字符相同。子序列的定义是从原字符串中按原顺序选取若干字符(可以不连续)组成的新字符串。最终结果需对 998244353998244353 取模。

示例:

  • 输入:"arcaea"
    输出:28

  • 输入:"abcabc"
    输出:18

  • 输入:"aaaaa"
    输出:31

解题思路

1. 子序列的定义与性质

  • 子序列是从原字符串中按顺序选取若干字符组成的新字符串,字符可以不连续。
  • 要求子序列的首尾字符相同,即子序列的第一个字符和最后一个字符相同。

2. 问题转化

我们需要统计所有满足以下条件的子序列数量:

  • 子序列长度至少为1(单个字符本身也是一个子序列)。
  • 子序列的首字符和尾字符相同。

3. 动态规划与字符统计

利用动态规划、结合字符统计来高效计算。

主要思路:

  • 统计每个字符的位置:首先遍历字符串,记录每个字符出现的位置。

  • 计算以某字符为首尾的子序列数量

    对于每个字符 c,假设它在字符串中出现了 k 次,位置分别为 p0, p1, ..., pk-1。对于任意一对位置 (pi, pj),其中 pi <= pj,则以 c 为首尾的子序列可以通过在 pipj 之间任意选取字符来构成。

    • i == j 时,对应的子序列只有一个,即单个字符 c
    • i < j 时,pipj 之间有 pj - pi - 1 个字符,每个字符可以选择是否包含,因此有 2^(pj - pi - 1) 种子序列。
  • 优化计算

    为了避免重复计算和提高效率,可以预先计算所有可能的 2^k mod 998244353,并利用前缀和或其他技巧加速计算。

4. 实现步骤

  1. 预计算幂次

    计算并存储 pow2[i] = 2^i mod 998244353,用于快速查询 2^(pj - pi -1)

  2. 统计字符位置

    使用哈希表(如 defaultdict)记录每个字符出现的所有位置。

  3. 遍历每个字符

    对于每个字符 c,遍历其所有出现的位置,计算以 c 为首尾的子序列数量,并累加到总数中。

  4. 结果取模

    最终结果对 998244353 取模输出。

代码实现

MOD = 998244353

def solution(s: str) -> int:
    n = len(s)
    pow2 = [1] * (n + 1)
    for i in range(1, n + 1):
        pow2[i] = pow2[i-1] * 2 % MOD

    from collections import defaultdict
    char_positions = defaultdict(list)
    for idx, c in enumerate(s):
        char_positions[c].append(idx)

    total = 0
    for c, positions in char_positions.items():
        k = len(positions)
        # 对于每个字符c,计算以c为首尾的子序列数量
        # 单个字符的子序列数量为k
        # 两个及以上字符的子序列数量为 Σ(2^(pj - pi -1)),其中 pj > pi
        # 总计为k + Σ_{i < j} 2^(pj - pi -1)
        total_c = k  # 单个字符的子序列
        for i in range(k):
            for j in range(i+1, k):
                total_c = (total_c + pow2[positions[j] - positions[i] -1]) % MOD
        total = (total + total_c) % MOD
    return total

if __name__ == '__main__':
    print(solution("arcaea") == 28)
    print(solution("abcabc") == 18)
    print(solution("aaaaa") == 31)

代码解释

  1. 预计算幂次

    pow2 = [1] * (n + 1)
    for i in range(1, n + 1):
        pow2[i] = pow2[i-1] * 2 % MOD
    
    • 计算 2^i mod 998244353,存储在 pow2 数组中,方便后续快速查询。
  2. 统计字符位置

    from collections import defaultdict
    char_positions = defaultdict(list)
    for idx, c in enumerate(s):
        char_positions[c].append(idx)
    
    • 使用 defaultdict 记录每个字符在字符串中的所有出现位置。
  3. 计算每个字符的贡献

    for c, positions in char_positions.items():
        k = len(positions)
        total_c = k  # 单个字符的子序列
        for i in range(k):
            for j in range(i+1, k):
                total_c = (total_c + pow2[positions[j] - positions[i] -1]) % MOD
        total = (total + total_c) % MOD
    
    • 对于每个字符 c,首先将其作为单个字符的子序列数量 k 加入 total_c
    • 然后,对于每一对不同的位置 (i, j),计算 2^(pj - pi -1) 并累加到 total_c 中。
    • 最后,将每个字符 c 的贡献累加到总结果 total 中。

复杂度分析

  • 时间复杂度

    • 统计字符位置:O(n)O(n)
    • 对于每个字符,遍历其所有出现位置的两两组合:假设每个字符最多出现 mm 次,则总时间复杂度为 O(n+m2)O(n + \sum m^2)。在最坏情况下,当所有字符相同时,时间复杂度为 O(n2)O(n^2)
  • 空间复杂度

    • 主要用于存储 pow2 数组和字符位置,空间复杂度为 O(n)O(n)

总结

通过预先计算幂次和统计每个字符的位置,我们能够高效地计算出所有以相同字符为首尾的子序列数量。尽管在最坏情况下时间复杂度为 O(n2)O(n^2),但对于大多数实际情况而言,效率仍然可以接受。