洛谷 P3287 [SCOI2014]方伯伯的玉米田【二维树状数组优化DP】

131 阅读2分钟

Offer 驾到,掘友接招!我正在参与2022春招打卡活动,点击查看活动详情

题目链接:洛谷 P3287

题目大意:给定一长度为 nn 的序列,最多可进行最多 kk 次区间整体 +1+1 操作。求 kk 次操作后的最长不下降子序列长度。

题目分析:

我们首先考虑每次拔高的区间。由于要求最长的单调不下降序列,那么每次拔高的右区间应该都是 nn,因为这样才不会使右边的单调关系不受影响。

让我们用以下数据来举个例子:

4 34\ 3

4 2 1 34\ 2\ 1\ 3

如果我们拔高区间 [2,2][2,2] 两次,高度关系就变为了:

4 4 1 34\ 4\ 1\ 3

我们发现虽然第二株玉米已经不小于第一株玉米,但是此时第二株玉米已经高余了第四株玉米,最后的答案是没有变的,而很显然我们可以在拔高区间 [2,2][2,2] 的同时拔高 [3,4][3,4],也就是拔高 [2,4][2,4],因为此时无论拔高 [2,4][2,4] 多少次,第二、三、四株的高度关系都是不变的,所以我们拔高的区间只能是如下:

4 2 1 34\ 2\ 1\ 3

   2 1 3\ \ \ 2\ 1\ 3

      1 3\ \ \ \ \ \ 1\ 3

         3\ \ \ \ \ \ \ \ \ 3

有了这一个结论之后,我们不难定出一个很普通的状态:

f[i][j]f[i][j] 表示前 ii 棵一共拔了 jj 次,答案显然为 f[n][k]f[n][k]

那么如下的暴力状态转移就很容易得到了:

f[i][j]=f[p][q]max+1,p<i,qj,Height[p]+qHeight[i]+jf[i][j] = f[p][q]_{max} + 1,p < i,q\leq j,Height[p]+q \leq Height[i]+j

我们接下来考虑一下优化:

f[i][j]f[i][j] 最终要取到 f[p][q]f[p][q] 中的最大值,所以我们考虑用某种东西记录转移部分的最大值,但是同时我们还需要处理拔高的次数,所以我们在处理一个数组:

g[i][j]g[i][j] 表示一共拔高了 ii 次,其中最高的一株高度为 jj

所以我们变换一下状态转移方程:

f[i][j]=g[p][q]max+1,1pj,1qHeight[i]+jf[i][j]=g[p][q]_{max}+1,1 \leq p\leq j,1\leq q\leq Height[i]+j

但是这样看来,状态转移方程似乎并没有优化,但是真正的优化其实在g[][]g[][]这个数组,因为我们要求的是这个转移区间的最大值,所以我们考虑将g[][]g[][]这个数组当做二维树状数组来进行优化处理最大值。

所有我们在寻找最大值答案的时候不妨直接处理 p,qp,q 的极值(更新树状数组同此),即:

f[i][j]=GetAns(j,Height[i]+j)+1f[i][j]=GetAns(j,Height[i]+j)+1

由于树状数组的 LowBitLowBit 操作,GetAns(j,Height[i]+j)GetAns(j,Height[i]+j) 已经处理完了当前的所有区间

讲到这里,代码其实就很好写了,由于树状数组的下标是不为 00 的,但是这道题当中处理的时候是需要的,所以我们不妨将 jj 强行加一位。

#include <cstdio>
#include <algorithm>

#define ll long long
using namespace std;

const ll M = 505;
const ll N = 10005;

ll bit[M][N];
ll n, m, k, ans, h[N], f[N][M];

int lowbit(int x) {
    return x & (-x);
}

void update(ll x, ll y, ll z) {
    for (ll i = x; i <= k; i += lowbit(i)) {
        for (ll j = y; j <= m; j += lowbit(j)) {
            bit[i][j] = max(bit[i][j], z);
        }
    }
}

ll getans(ll x, ll y) {
    ll ret = 0;
    for (ll i = x; i; i -= lowbit(i)) {
        for (ll j = y; j; j -= lowbit(j)) {
            ret = max(ret, bit[i][j]);
        }
    }
    return ret;
}

int main() {
    scanf("%lld%lld", &n, &k);

    for (int i = 1; i <= n; ++i) {
        scanf("%lld", &h[i]);
        m = max(m, h[i]);
    }

    m += k;
    k++;

    for (int i = 1; i <= n; ++i) {
        for (int j = k; j >= 1; --j) {
            f[i][j] = getans(j, h[i] + j) + 1;
            update(j, h[i] + j, f[i][j]);
            ans = max(ans, f[i][j]);
        }
    }

    printf("%lld\n", ans);

    return 0;
}