字符串首尾相同子序列计数 | 豆包MarsCode AI刷题

98 阅读3分钟

问题描述

小M拿到了一个仅由小写字母组成的字符串,她想知道在这个字符串中,有多少个子序列的首尾字符相同。子序列的定义是:从原字符串中按原顺序取出若干字符(可以不连续)组成的新字符串。

例如,对于字符串 "arcaea",其子序列包括 "aca", "ara", "aaa" 等,这些子序列的首尾字符都是相同的。

你需要计算满足这一条件的子序列数量,并输出对 998244353998244353 取模的结果。

测试样例

样例1:

输入:s = "arcaea" 输出:28

样例2:

输入:s = "abcabc" 输出:18

样例3:

输入:s = "aaaaa" 输出:31

解题思路

预计算幂次:

pow2 数组存储了 2^i mod 998244353。

inv2_pow 数组存储了 2^(−i ) mod 998244353,其中 2^(−i ) 是 2^i 的模逆元。

字符位置映射:

使用 HashMap 来记录每个字符在字符串中的所有出现位置。

计算总美丽值:

对于每一个不同字符,根据它在字符串中的位置分布计算贡献到总美丽值的部分。

使用前缀和的思想来减少重复计算,并保持操作的高效性。

代码实现

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Main {
    private static final int MOD = 998244353;

    public static int solution(String s) {
        int n = s.length();

        long[] pow2 = new long[n + 1];
        pow2[0] = 1;
        for (int i = 1; i <= n; i++) {
            pow2[i] = (pow2[i - 1] * 2) % MOD;
        }

        long inv2 = modPow(2, MOD - 2, MOD);
        long[] inv2_pow = new long[n + 1];
        inv2_pow[0] = 1;
        for (int i = 1; i <= n; i++) {
            inv2_pow[i] = (inv2_pow[i - 1] * inv2) % MOD;
        }

        Map<Character, List<Integer>> charPos = new HashMap<>();
        for (int idx = 0; idx < n; idx++) {
            char c = s.charAt(idx);
            charPos.computeIfAbsent(c, k -> new ArrayList<>()).add(idx);
        }

        long total = 0;
        for (Map.Entry<Character, List<Integer>> entry : charPos.entrySet()) {
            List<Integer> positions = entry.getValue();
            int k = positions.size();
            long sumC = 0;
            long prefixInvSum = 0;
            for (int j = 0; j < k; j++) {
                int pj = positions.get(j);
                if (j >= 1) {
                    long term = (pow2[pj - 1] * prefixInvSum) % MOD;
                    sumC = (sumC + term) % MOD;
                }
                prefixInvSum = (prefixInvSum + inv2_pow[pj]) % MOD;
            }
            sumC = (sumC + k) % MOD;
            total = (total + sumC) % MOD;
        }
        return (int) total;
    }

    private static long modPow(long base, long exp, long mod) {
        long result = 1;
        while (exp > 0) {
            if (exp % 2 == 1) {
                result = (result * base) % mod;
            }
            base = (base * base) % mod;
            exp /= 2;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(solution("arcaea") == 28);
        System.out.println(solution("abcabc") == 18);
        System.out.println(solution("aaaaa") == 31);
    }
}

示例

  • 对于字符串 "arcaea":
  • 字符 'a' 的位置为 [0, 3, 5]。
  • 计算这些位置的有效子序列数量,并累加到 total 中。
  • 最终结果为 28。
  • 对于字符串 "abcabc":
  • 字符 'a' 的位置为 [0, 3]。
  • 字符 'b' 的位置为 [1, 4]。
  • 字符 'c' 的位置为 [2, 5]。
  • 计算这些位置的有效子序列数量,并累加到 total 中。
  • 最终结果为 18。
  • 对于字符串 "aaaaa":
  • 字符 'a' 的位置为 [0, 1, 2, 3, 4]。
  • 计算这些位置的有效子序列数量,并累加到 total 中。
  • 最终结果为 31。

总结

时间复杂度是 O(n)。首先,预计算 pow2 和 inv2_pow 数组各需要 O(n) 的时间。接着,遍历字符串并记录每个字符的位置也需要 O(n) 的时间。最后,在计算总美丽值时,虽然有嵌套循环,但由于每个字符的位置总数加起来仍然是 n ,因此这一部分的总时间复杂度也是 O(n) 。综合来看,整个算法的时间复杂度是线性的,即 O(n)。

空间复杂度也是 O(n)。预计算的 pow2 和 inv2_pow 数组各自占用 O(n) 的空间。存储每个字符位置的 HashMap 在最坏情况下(例如所有字符都不同)将包含 n 个元素,每个元素是一个 List,总共占用 O(n) 的空间。其他变量如 total、sumC、prefixInvSum 等都是常数级别的空间。因此,整个算法的空间复杂度是 O(n)。