区间DP - 石子合并

111 阅读2分钟
  • 问题背景

Snipaste_2023-03-24_11-01-36.png

  • Dp分析

Snipaste_2023-03-24_11-03-06.png

Snipaste_2023-03-24_11-04-05.png

  • 状态表示

    • 二维状态表示f(i,j)
    • f(i,j)表示的是哪一个集合:所有满足如下条件的集合
      • 所有将第i堆石子和第j堆石子合并成一堆石子的合并方式
    • f(i,j)存的是什么属性:MaxMin数量;在这里f(i,j)存的应该是最小值,将第i堆石子和第j堆石子合并成一堆石子的最小代价
  • 状态计算:f(i,j)可以怎么算出来?

    • 最后一次一定是两堆石子合并,所有这里可以将集合分为n类,其中n = j - i - 1
      • 最后一次合并的是i ~ k这一堆和k + 1 ~ j这一堆,k从i取到j - 1
        • f(i, j) = f(i, k) + f(k + 1, j) + s( j ) - s( i - 1)
    • 因此 f(i, j) = min( f(i, k) + f(k + 1, j) + s( j ) - s( i - 1) )
  • 代码

    • //按照区间长度来枚举 区间长度为1就不用合并 代价为0
      for (int len = 2; len <= n; len++) {
          //i表示区间左端点
          for (int i = 1; i + len - 1 <= n; i++) {
              //由区间左端点和区间长度算出区间右端点
              int l = i, r = i + len - 1;
              //初始化
              f[l][r] = 0x3f3f3f3f;
              for (int k = l; k < r; k++) {
                  f[l][r] = Math.min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);
              }
          }
      }
      

练习

01 石子合并

  • 题目

Snipaste_2023-03-24_11-42-23.png

  • 题解
import java.io.*;
import java.util.*;

public class Main {
    public static final int N = 310;
    public static int[] s = new int[N];
    public static int[][] f = new int[N][N];
    public static int n;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out));
        n = Integer.parseInt(br.readLine());
        String[] str1 = br.readLine().split(" ");
        for (int i = 1; i <= n; i++) {
            s[i] = Integer.parseInt(str1[i - 1]);
        }
        for (int i = 1; i <= n; i++) {
            s[i] += s[i - 1];
        }

        //按照区间长度来枚举 区间长度为1就不用合并 代价为0
        for (int len = 2; len <= n; len++) {
            //i表示区间左端点
            for (int i = 1; i + len - 1 <= n; i++) {
                //由区间左端点和区间长度算出区间右端点
                int l = i, r = i + len - 1;
                //初始化
                f[l][r] = 0x3f3f3f3f;
                for (int k = l; k < r; k++) {
                    f[l][r] = Math.min(f[l][r], f[l][k] + f[k + 1][r] + s[r] - s[l - 1]);
                }
            }
        }
        pw.println(f[1][n]);
        br.close();
        pw.close();
    }
}