问题分析
目标: 找出有多少种正反面组合,使得卡牌向上的数字之和可以被 3 整除。
关键点: 每张卡牌有两种选择(正面或反面),总共有 2n2^n 种组合,暴力求解不可行,需采用动态规划。
动态规划解决方案
-
状态定义:
- 定义
dp[i][j]表示前 ii 张卡牌中,使得数字之和的模 3 等于 jj 的组合数。
- 定义
-
状态转移:
-
对于第 ii 张卡牌,有两种选择:
- 正面朝上: 累加 a[i]a[i]。
- 反面朝上: 累加 b[i]b[i]。
-
转移方程: dp[i][(j+a[i])%3]+=dp[i−1][j]dp[i][(j + a[i]) % 3] += dp[i-1][j] dp[i][(j+b[i])%3]+=dp[i−1][j]dp[i][(j + b[i]) % 3] += dp[i-1][j]
-
-
初始化:
- dp[0][a[0]%3]+=1dp[0][a[0] % 3] += 1 表示第一张卡牌正面朝上。
- dp[0][b[0]%3]+=1dp[0][b[0] % 3] += 1 表示第一张卡牌反面朝上。
-
目标结果:
- 求 dp[n−1][0]dp[n-1][0],即前 nn 张卡牌模 3 为 0 的组合数。
代码实现与讲解
def solution(n: int, a: list, b: list) -> int:
MOD = 10**9 + 7 # 大数取模
# dp[i][j] 表示前 i 张卡牌中,余数为 j 的组合数
dp = [[0] * 3 for _ in range(n)]
# 初始化第 0 张卡牌
dp[0][a[0] % 3] += 1
dp[0][b[0] % 3] += 1
# 填充 dp 数组
for i in range(1, n): # 遍历卡牌
for j in range(3): # 遍历余数
if dp[i - 1][j] > 0: # 如果有组合数
# 正面朝上
new_mod = (j + a[i]) % 3
dp[i][new_mod] = (dp[i][new_mod] + dp[i - 1][j]) % MOD
# 反面朝上
new_mod = (j + b[i]) % 3
dp[i][new_mod] = (dp[i][new_mod] + dp[i - 1][j]) % MOD
# 返回余数为 0 的方案数
return dp[n - 1][0]
# 测试用例
if __name__ == '__main__':
print(solution(3, [1, 2, 3], [2, 3, 2])) # 输出 3
print(solution(4, [3, 1, 2, 4], [1, 2, 3, 1])) # 输出 6
print(solution(5, [1, 2, 3, 4, 5], [1, 2, 3, 4, 5])) # 输出 32
核心逻辑拆解
示例1分析
- 输入: n=3,a=[1,2,3],b=[2,3,2]n = 3, a = [1, 2, 3], b = [2, 3, 2]
- 目标: 数字和模 3 为 0 的组合数。
初始化:
- dp[0][1]=1dp[0][1] = 1: 第一张卡牌正面朝上。
- dp[0][2]=1dp[0][2] = 1: 第一张卡牌反面朝上。
迭代过程:
-
第 2 张卡牌 (a[1]=2,b[1]=3a[1] = 2, b[1] = 3) :
- dp[1][0]dp[1][0]: 通过 dp[0][1]+2mod 3=0dp[0][1] + 2 \mod 3 = 0,或 dp[0][2]+3mod 3=0dp[0][2] + 3 \mod 3 = 0。
- dp[1][1]dp[1][1]: 通过 dp[0][1]+3mod 3=1dp[0][1] + 3 \mod 3 = 1,或 dp[0][2]+2mod 3=1dp[0][2] + 2 \mod 3 = 1。
-
第 3 张卡牌 (a[2]=3,b[2]=2a[2] = 3, b[2] = 2) :
- dp[2][0]dp[2][0]: 模 3 为 0 的组合数累加。
最终结果:
- dp[2][0]=3dp[2][0] = 3。
时间与空间复杂度
-
时间复杂度:
O(n×3)=O(n)O(n \times 3) = O(n)
每张卡牌需要遍历模 3 的 3 个状态。
-
空间复杂度:
O(n×3)=O(n)O(n \times 3) = O(n)
使用 dpdp 数组存储状态。
优化: 滚动数组
可以用滚动数组优化空间复杂度,将 dp 的空间复杂度从 O(n×3)O(n \times 3) 降至 O(3)O(3)。
def solution_optimized(n: int, a: list, b: list) -> int:
MOD = 10**9 + 7
dp = [0] * 3
# 初始化
dp[a[0] % 3] += 1
dp[b[0] % 3] += 1
# 更新 dp
for i in range(1, n):
next_dp = [0] * 3
for j in range(3):
if dp[j] > 0:
next_dp[(j + a[i]) % 3] = (next_dp[(j + a[i]) % 3] + dp[j]) % MOD
next_dp[(j + b[i]) % 3] = (next_dp[(j + b[i]) % 3] + dp[j]) % MOD
dp = next_dp
return dp[0]
# 测试用例
if __name__ == '__main__':
print(solution_optimized(3, [1, 2, 3], [2, 3, 2])) # 输出 3
print(solution_optimized(4, [3, 1, 2, 4], [1, 2, 3, 1])) # 输出 6
print(solution_optimized(5, [1, 2, 3, 4, 5], [1, 2, 3, 4, 5])) # 输出 32
优化后的复杂度
- 时间复杂度: O(n)O(n)
- 空间复杂度: O(3)=O(1)O(3) = O(1)。