atcoder.jp/contests/ty…
线段树存最大值
根节点 (max(整个区间))
/ \
左半区间max 右半区间max
/ \ / \
max max max max
/ \ / \ / \ / \
1 3 5 2 4 6 8 7
#include <bits/stdc++.h>
using namespace std;
using ll=long long;
//线段数类:维护区间最大值
class RM{
public :
int size_=1;
vector<ll>dat;
void init(int sz){
while(size_<=sz)size_*=2;
dat.resize(size_*2,-(1ll<<60));
}void update(int pos,ll x){//一次只改一个下标
pos+=size_;
dat[pos]=x;
while(pos>=2){
pos>>=1;
dat[pos]=max(dat[pos*2],dat[pos*2+1]);
}
}ll qu(int l,int r,int a,int b,int u){
if(l<=a&&b<=r)return dat[u];
if(r<=a||b<=l)return -(1ll<<60);
ll v1=qu(l,r,a,(a+b)>>1,u*2);
ll v2=qu(l,r,(a+b)>>1,b,u*2+1);
return max(v1,v2);
}ll qu1(int l,int r){
return qu(l,r,0,size_,1);
}
};
ll w,n;
ll l[1<<18],r[1<<18],v[1<<18];
ll dp[509][10009];//i,j用i道菜,j香料的最大值
RM Z[509];//z[i]=第i层dp的线段树
void so(){
}int main(){
ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
cin>>w>>n;
for(int i=1;i<=n;i++)cin>>l[i]>>r[i]>>v[i];
for(int i=0;i<=n;i++){
for(int j=0;j<=w;j++){
dp[i][j]=-(1ll<<60);
}Z[i].init(w+2);//给每层建
}dp[0][0]=0;
Z[0].update(0,0);
for(int i=1;i<=n;i++){//每道菜
for(int j=0;j<=w;j++){
dp[i][j]=dp[i-1][j];
}//要凑j上层j-使用量
//使用量:[l[i],r[i]]
//上层[j-r[i],j-l[i]]
for(int j=0;j<=w;j++){
int cl=max(0,j-(int)r[i]);
int cr=max(0,j-(int)l[i]+1);
if(cl==cr)continue;
ll val=Z[i-1].qu1(cl,cr);
if(val!=-(1ll<<60)){
dp[i][j]=max(dp[i][j],val+v[i]);
}
}for(int j=0;j<=w;j++)Z[i].update(j,dp[i][j]);//当前层存入线段树
}if(dp[n][w]==-(1ll<<60)){cout<<-1<<'\n';}
else cout<<dp[n][w]<<'\n';
}
原理
相当于背包放的每一样东西都有个区间,找转移过来的区间的最大值
原来是j找j-z[i],现在找[j-r[i],j-l[i]]的最大值
原来的 DP 每一步都要在一段区间里找最大值,暴力找要花很多时间;线段树可以快速查区间最大,所以整体变快了。
就像这样:
- 暴力:从头到尾扫一遍 → O (W)
- 线段树:树里一跳就找到 → O (log W)
竞赛里 W 可以到 1e4,W² 就是 1e8,会 TLE;但 W log W 只有 1e4 * 14 ≈ 1.4e5,稳稳 AC。
时间复杂度
O(N × W × log W)
for(int i=1;i<=n;i++){ // N 次
// 不选第 i 道菜
for(int j=0;j<=w;j++) // W 次
dp[i][j] = dp[i-1][j]; // O(1)
// 选第 i 道菜
for(int j=0;j<=w;j++){ // W 次
cl = ...
cr = ...
val = Z[i-1].qu1(cl, cr); // 🔥 O(log W)
if(...) dp[i][j] = ... // O(1)
}
// 把当前 dp 写入线段树
for(int j=0;j<=w;j++) // W 次
Z[i].update(j, dp[i][j]); // 🔥 O(log W)
}