算法介绍
给两棵线段树,把它们合并成一棵。
递归地合并这两棵线段树。用两个指针从两棵线段树的根开始同步遍历。对于某个节点,如果或,也就是说这两棵线段树的其中一棵不存在这个点,就直接将非的或作为合并后线段树的新节点。如果对于某个点,两棵树上均存在,就递归地去合并它们的子树,然后更新节点信息即可。此外,对于叶结点,直接合并即可。
例题
P3605 [USACO17JAN]Promotion Counting
思路
虽然用树状数组特别好写,但还是用来练练手。
先把牛牛的能力值离散化,对每个结点都维护一棵权值线段树,从叶节点开始往上合并,把每个子节点的线段树都合并起来以后,直接查询权值线段树上比自己能力值大的个数,然后把自己维护进答案。
代码
#include<bits/stdc++.h>
#define rep(i,st,ed) for(int i=st;i<=ed;++i)
#define bl(u,i) for(int i=head[u];i;i=e[i].nxt)
#define en puts("")
#define LLM LONG_LONG_MAX
#define LLm LONG_LONG_MIN
#define pii pair<ll,ll>
typedef long long ll;
typedef double db;
using namespace std;
const ll INF=0x3f3f3f3f;
void read() {}
void OP() {}
void op() {}
template <typename T, typename... T2>
inline void read(T &_, T2 &... oth)
{
int __=0;
_=0;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')
__=1;
ch=getchar();
}
while(isdigit(ch))
{
_=_*10+ch-48;
ch=getchar();
}
_=__?-_:_;
read(oth...);
}
template <typename T>
void Out(T _)
{
if(_<0)
{
putchar('-');
_=-_;
}
if(_>=10)
Out(_/10);
putchar(_%10+'0');
}
template <typename T, typename... T2>
inline void OP(T _, T2... oth)
{
Out(_);
putchar('\n');
OP(oth...);
}
template <typename T, typename... T2>
inline void op(T _, T2... oth)
{
Out(_);
putchar(' ');
op(oth...);
}
/*#################################*/
const ll N=1E5+10;
ll n,fa,tot,cnt;
ll head[N],ans[N],a[N],root[N];
struct Node{
ll l,r,val;
}t[20*N];
struct Edge{
ll nxt,to,w;
}e[N];
vector<ll> ve;
void add_edge(ll u,ll v,ll w,int flag)
{
e[++tot]=(Edge){head[u],v,w};
head[u]=tot;
if(flag)
add_edge(v,u,w,0);
}
void update(ll p)
{
t[p].val=t[t[p].l].val+t[t[p].r].val;
}
void merge(ll &p,ll q,ll l,ll r)
{
if(!p || !q)
{
p=p|q;
return;
}
if(l==r)
{
t[p].val+=t[q].val;
return;
}
ll mid=(l+r)>>1;
merge(t[p].l,t[q].l,l,mid);
merge(t[p].r,t[q].r,mid+1,r);
update(p);
}
ll query(ll p,ll l,ll r,ll al,ll ar)
{
if(al<=l && ar>=r)
return t[p].val;
ll mid=(l+r)>>1;
ll ret=0;
if(al<=mid)
ret+=query(t[p].l,l,mid,al,ar);
if(ar>mid)
ret+=query(t[p].r,mid+1,r,al,ar);
return ret;
}
void add(ll &p,ll l,ll r,ll pos,ll val)
{
if(!p)
p=++cnt;
if(l==r)
{
t[p].val+=val;
return;
}
ll mid=(l+r)>>1;
if(pos<=mid)
add(t[p].l,l,mid,pos,val);
else
add(t[p].r,mid+1,r,pos,val);
update(p);
}
void dfs(ll u)
{
root[u]=++cnt;
bl(u,i)
{
ll v=e[i].to;
dfs(v);
merge(root[u],root[v],1,n);
}
ans[u]=query(root[u],1,n,a[u]+1,n);
add(root[u],1,n,a[u],1);
}
int main()
{
read(n);
rep(i,1,n)
{
read(a[i]);
ve.emplace_back(a[i]);
}
rep(i,2,n)
{
read(fa);
add_edge(fa,i,0,0);
}
sort(ve.begin(),ve.end());
ve.erase(unique(ve.begin(),ve.end()),ve.end());
rep(i,1,n)
a[i]=lower_bound(ve.begin(),ve.end(),a[i])-ve.begin()+1;
dfs(1);
rep(i,1,n)
OP(ans[i]);
}
P3521 [POI2011]ROT-Tree Rotations
思路
首先可以发现交换左右子树仅影响左右子树之间的逆序对,而不影响子树内部和祖先的子树之间的逆序对数量,因此对于每个个节点,可以贪心地取交换其左右子树和不交换中逆序对最小的值。
如何快速地求出左右子树之间逆序对的个数呢?对每个结点建一棵权值线段树,在合并节点左右子树对应的线段树的同时维护该节点贡献的答案。假设指向左子树对应的线段树,指向右,因为树中的数字在原树当中的下标必然小于中的,因此交换前的逆序对个数就是右子树的值和左子树的值的乘积,交换后的就是左乘右。
代码
#include<bits/stdc++.h>
#define rep(i,st,ed) for(int i=st;i<=ed;++i)
#define bl(u,i) for(int i=head[u];i;i=e[i].nxt)
#define en puts("")
#define LLM LONG_LONG_MAX
#define LLm LONG_LONG_MIN
#define pii pair<ll,ll>
typedef long long ll;
typedef double db;
using namespace std;
const ll INF=0x3f3f3f3f;
void read() {}
void OP() {}
void op() {}
template <typename T, typename... T2>
inline void read(T &_, T2 &... oth)
{
int __=0;
_=0;
char ch=getchar();
while(!isdigit(ch))
{
if(ch=='-')
__=1;
ch=getchar();
}
while(isdigit(ch))
{
_=_*10+ch-48;
ch=getchar();
}
_=__?-_:_;
read(oth...);
}
template <typename T>
void Out(T _)
{
if(_<0)
{
putchar('-');
_=-_;
}
if(_>=10)
Out(_/10);
putchar(_%10+'0');
}
template <typename T, typename... T2>
inline void OP(T _, T2... oth)
{
Out(_);
putchar('\n');
OP(oth...);
}
template <typename T, typename... T2>
inline void op(T _, T2... oth)
{
Out(_);
putchar(' ');
op(oth...);
}
/*#################################*/
const ll N=2E5+10;
ll n,ans,ans1,ans2,cnt;
struct Node
{
ll l,r,val;
}t[20*N];
void add(ll &p,ll l,ll r,ll val)
{
if(!p)
p=++cnt;
++t[p].val;
if(l==r)
return;
ll mid=(l+r)>>1;
if(val<=mid)
add(t[p].l,l,mid,val);
else
add(t[p].r,mid+1,r,val);
}
void merge(ll &p,ll q)
{
if(!p || !q)
{
p=p|q;
return;
}
t[p].val+=t[q].val;
ans1+=t[t[p].r].val*t[t[q].l].val;
ans2+=t[t[p].l].val*t[t[q].r].val;
merge(t[p].l,t[q].l);
merge(t[p].r,t[q].r);
}
void dfs(ll &u)
{
ll lc,rc,x;
u=0;
read(x);
if(!x)
{
dfs(lc);
dfs(rc);
u=lc;
ans1=ans2=0;
merge(u,rc);
ans+=min(ans1,ans2);
}
else
add(u,1,n,x);
}
int main()
{
read(n);
ll tmp=0;
dfs(tmp);
OP(ans);
}