【算法】【线段树】【线段树优化】

1 阅读2分钟

atcoder.jp/contests/ty…

线段树存最大值

          根节点 (max(整个区间))
         /         \
左半区间max    右半区间max
      /   \        /   \
   max    max    max    max
   / \    / \    / \    / \
  1  3   5  2   4  6   8  7

#include <bits/stdc++.h>  
using namespace std;  
using ll=long long;  
//线段数类:维护区间最大值  
class RM{  
public :  
    int size_=1;  
    vector<ll>dat;  
    void init(int sz){  
        while(size_<=sz)size_*=2;  
        dat.resize(size_*2,-(1ll<<60));  
    }void update(int pos,ll x){//一次只改一个下标  
        pos+=size_;  
        dat[pos]=x;  
        while(pos>=2){  
            pos>>=1;  
            dat[pos]=max(dat[pos*2],dat[pos*2+1]);  
        }  
    }ll qu(int l,int r,int a,int b,int u){  
        if(l<=a&&b<=r)return dat[u];  
        if(r<=a||b<=l)return -(1ll<<60);  
        ll v1=qu(l,r,a,(a+b)>>1,u*2);  
        ll v2=qu(l,r,(a+b)>>1,b,u*2+1);  
        return max(v1,v2);  
    }ll qu1(int l,int r){  
        return qu(l,r,0,size_,1);  
    }  
};  
ll w,n;  
ll l[1<<18],r[1<<18],v[1<<18];  
ll dp[509][10009];//i,j用i道菜,j香料的最大值  
RM Z[509];//z[i]=第i层dp的线段树  
void so(){  
      
}int main(){  
    ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);  
    cin>>w>>n;  
    for(int i=1;i<=n;i++)cin>>l[i]>>r[i]>>v[i];  
    for(int i=0;i<=n;i++){  
        for(int j=0;j<=w;j++){  
            dp[i][j]=-(1ll<<60);  
        }Z[i].init(w+2);//给每层建  
    }dp[0][0]=0;  
    Z[0].update(0,0);  
    for(int i=1;i<=n;i++){//每道菜  
        for(int j=0;j<=w;j++){  
            dp[i][j]=dp[i-1][j];  
        }//要凑j上层j-使用量  
        //使用量:[l[i],r[i]]  
        //上层[j-r[i],j-l[i]]  
        for(int j=0;j<=w;j++){  
            int cl=max(0,j-(int)r[i]);  
            int cr=max(0,j-(int)l[i]+1);  
            if(cl==cr)continue;  
            ll val=Z[i-1].qu1(cl,cr);  
            if(val!=-(1ll<<60)){  
                dp[i][j]=max(dp[i][j],val+v[i]);  
            }  
        }for(int j=0;j<=w;j++)Z[i].update(j,dp[i][j]);//当前层存入线段树  
    }if(dp[n][w]==-(1ll<<60)){cout<<-1<<'\n';}  
    else cout<<dp[n][w]<<'\n';  
}

原理

相当于背包放的每一样东西都有个区间,找转移过来的区间的最大值

原来是j找j-z[i],现在找[j-r[i],j-l[i]]的最大值

原来的 DP 每一步都要在一段区间里找最大值,暴力找要花很多时间;线段树可以快速查区间最大,所以整体变快了。

就像这样:

  • 暴力:从头到尾扫一遍 → O (W)
  • 线段树:树里一跳就找到 → O (log W)

竞赛里 W 可以到 1e4,W² 就是 1e8,会 TLE;但 W log W 只有 1e4 * 14 ≈ 1.4e5,稳稳 AC。

时间复杂度

O(N × W × log W)

for(int i=1;i<=n;i++){           // N 次

  // 不选第 i 道菜
  for(int j=0;j<=w;j++)          // W 次
    dp[i][j] = dp[i-1][j];       // O(1)

  // 选第 i 道菜
  for(int j=0;j<=w;j++){         // W 次
    cl = ...
    cr = ...
    val = Z[i-1].qu1(cl, cr);    // 🔥 O(log W)
    if(...) dp[i][j] = ...       // O(1)
  }

  // 把当前 dp 写入线段树
  for(int j=0;j<=w;j++)          // W 次
    Z[i].update(j, dp[i][j]);    // 🔥 O(log W)
}