33.卡牌翻面求和问题(题解)

41 阅读2分钟

题目描述

给定两个数组 a 和 b,每个数组包含 n 个整数。我们需要找到一种方法,使得从这两个数组中选择元素的组合,其和对 3 取余为 0 的方案数。每个位置可以选择 a 或 b 中的元素。

解法一:递归 + 记忆化搜索

我们可以使用递归的方法来解决这个问题,通过记忆化搜索来优化递归的效率。

from functools import cache

mod = int(1e9 + 7)

def solution(n: int, a: list, b: list) -> int:
    @cache
    def dfs(i, j):
        if i == n:
            return j == 0
        return (dfs(i + 1, (j + a[i]) % 3) + dfs(i + 1, (j + b[i]) % 3)) % mod

    return dfs(0, 0)

if __name__ == '__main__':
    print(solution(n = 3, a = [1, 2, 3], b = [2, 3, 2]) == 3)
    print(solution(n = 4, a = [3, 1, 2, 4], b = [1, 2, 3, 1]) == 6)
    print(solution(n = 5, a = [1, 2, 3, 4, 5], b = [1, 2, 3, 4, 5]) == 32)
  • 使用递归函数 dfs(i, j) 表示从第 i 个位置开始,当前余数为 j 的方案数。
  • 递归终止条件是当 i == n 时,检查余数是否为 0
  • 通过记忆化搜索(@cache)来避免重复计算。

解法二:二维动态规划

我们可以使用二维动态规划来解决这个问题。定义 f[i][j] 表示前 i 个元素,余数为 j 的方案数。

  • 初始化 f[0][0] = 1,表示没有元素时,余数为 0 的方案数为 1
  • 遍历每个元素 i 和每个余数 j,更新 [f[i + 1][(j + a[i]) % 3]]和 [f[i + 1][(j + b[i]) % 3]]。
mod = int(1e9 + 7)

def solution(n: int, a: list, b: list) -> int:
    f = [[0] * 3 for _ in range(n + 1)]
    f[0][0] = 1
    for i in range(n):
        for j in range(3):
            f[i + 1][(j + a[i]) % 3] += f[i][j]
            f[i + 1][(j + b[i]) % 3] += f[i][j]
            f[i + 1][(j + a[i]) % 3] %= mod
            f[i + 1][(j + b[i]) % 3] %= mod
    return f[n][0] % mod

if __name__ == '__main__':
    print(solution(n = 3, a = [1, 2, 3], b = [2, 3, 2]) == 3)
    print(solution(n = 4, a = [3, 1, 2, 4], b = [1, 2, 3, 1]) == 6)
    print(solution(n = 5, a = [1, 2, 3, 4, 5], b = [1, 2, 3, 4, 5]) == 32)

解法一:一维动态规划

注意到当前状态f[i+1]只和f[i]种的状态有关,因此可以优化空间

mod = int(1e9 + 7)

def solution(n: int, a: list, b: list) -> int:
    f = [0] * 3
    f[0] = 1
    for i in range(n):
        cur = [0] * 3
        for j in range(3):
            cur[(j + a[i]) % 3] += f[j]
            cur[(j + b[i]) % 3] += f[j]
        f = cur
    return f[0] % mod

if __name__ == '__main__':
    print(solution(n = 3, a = [1, 2, 3], b = [2, 3, 2]) == 3)
    print(solution(n = 4, a = [3, 1, 2, 4], b = [1, 2, 3, 1]) == 6)
    print(solution(n = 5, a = [1, 2, 3, 4, 5], b = [1, 2, 3, 4, 5]) == 32)