线性DP - 最长公共子序列

154 阅读2分钟

最长公共子序列

  • 问题背景
    • 给定两个长度分别为N和M的字符串A和B,求既是A的子序列又是B的子序列的字符串长度最长是多少。(子序列不一定连续)
  • Dp分析

Snipaste_2023-03-23_19-40-27.png

  • 状态表示

    • 二维的状态表示f(i,j)
    • f(i, j)表示的是哪一个集合:所有满足如下条件的集合
      • 所有在第一个序列中的前i个字母中出现,且在第二个序列中的前j个字母中出现的公共子序列
    • f(i, j)存的是什么属性:MaxMin数量;在这里f(i,j)存的应该是最大值,即所有满足这种条件的最长的公共子序列的长度
  • 状态计算:f(i,j)可以怎么算出来?

    • 这里可以将集合分为四个子集
      • 子序列中不包含A中的第 i 个字符和B中的第 j 个字符
        • f(i, j) = f(i - 1, j - 1)
      • 子序列中不包含A中的第 i 个字符但包含B中的第 j 个字符
        • f(i, j) = f(i - 1, j)
        • 注意,这里的 f(i, j) = f(i - 1, j) 的含义是所有在第一个序列中的前i-1个字母中出现,且在第二个序列中的前j个字母中出现的最长公共子序列,这样的 f(i ,j) 并不一定满足子序列中不包含A中的第 i 个字符但包含B中的第 j 个字符,但是这种情况一定在这样的 f(i ,j) 中, f(i ,j) 这样子表示会出现重复,但由于题目求的是最大值,因此重复也无所谓
      • 子序列中包含A中的第 i 个字符但不包含B中的第 j 个字符
        • f(i, j) = f(i, j - 1)
        • 和第二个子集类似
      • 子序列中包含A中的第 i 个字符和B中的第 j 个字符
        • f(i, j) = f(i - 1, j - 1) + 1
    • 注意,第一个子集 ∈ ( 第二个子集 ∩ 三个子集 ),因此 f(i, j) = max( f(i - 1, j), f(i, j - 1), f(i - 1, j - 1) + 1 )
  • 代码

    • for (int i = 1; i <= n; i++) {
          for (int j = 1; j <= m; j++) {
              f[i][j] = Math.max(f[i - 1][j], f[i][j - 1]);
              if (a[i] == b[j]) {
                  f[i][j] = Math.max(f[i][j], f[i - 1][j - 1] + 1);
              }
          }
      }
      

练习

01 最长公共子序列

  • 题目

Snipaste_2023-03-23_21-06-09.png

  • 题解
import java.io.*;
import java.util.*;

public class Main {
    public static final int N = 1010;
    public static char[] a = new char[N];
    public static char[] b = new char[N];
    public static int[][] f = new int[N][N];
    public static int n, m;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out));
        String[] str1 = br.readLine().split(" ");
        n = Integer.parseInt(str1[0]);
        m = Integer.parseInt(str1[1]);
        String str2 = br.readLine();
        for (int i = 0; i < n; i++) {
            a[i + 1] = str2.charAt(i);
        }
        String str3 = br.readLine();
        for (int i = 0; i < m; i++) {
            b[i + 1] = str3.charAt(i);
        }

        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                f[i][j] = Math.max(f[i - 1][j], f[i][j - 1]);
                if (a[i] == b[j]) {
                    f[i][j] = Math.max(f[i][j], f[i - 1][j - 1] + 1);
                }
            }
        }
        pw.println(f[n][m]);
        br.close();
        pw.close();
    }
}