树状数组解决点对计数问题

114 阅读2分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第25天,点击查看活动详情

树状数组可以解决很多实际的问题,并且在一些场景下用于做优化时间复杂度处理,有一类题,一般给定一限制条件,然后让你去求给定条件下,符合要求的点对有多少个。比如下面这道题:


题目

给定 n 个两两不同的正整数 a1,a2,,an给定 n 个两两不同的正整数 a_1,a_2,…,a_n。

请你计算共有多少个三元组 (i,j,k) 能够同时满足: 请你计算共有多少个三元组 (i,j,k) 能够同时满足:

  • i<j<ki<j<k
  • ai>aj>aka_i>a_j>a_k

解法

要求找出符合条件的三元组有多少个,若是在 数据 nn不是很大的时候,可以暴力做,但是数据量大时则不行。

朴素做法

三种循环暴力for,时间复杂度O(N3)O(N^3)

树状数组优化

在枚举的时候,其实枚举 j的位置即可,不需要枚举i和k的部分

假设现在枚举到了 aja_j,左边比aja_j大的有xx个,右边比aja_j小的有yy

那么,当jj固定时 总数就是 x×yx \times y 个,

于是,如何快速查找比aja_j大的和比aja_j小的就成为了关注的点

我们可以使用一个for循环去扫描,这样的时间复杂度是O(n)O(n)的,但是j要枚举n次,总体时间复杂度为O(n2)O(n^2)

不失为一种办法,但不是最优

如何快速查找呢?这个时候利用树状数组就可以很好的解决这一件事情,把值映射到下标,求前缀和,就知道了 比 aja_j小的和大的值了

对于动态更新的话,也只需要lognlogn的时间来更新,于是,总体时间复杂度为O(nlogn)O(nlogn)

代码

#include <bits/stdc++.h>
typedef int i32;
typedef long long i64;
#define int i64
template <typename... T> void read(T&... arg) {((std::cin >> arg), ...);}
template <typename... T> void out(const T&... arg) {((std::cout << arg), ...);}
template <char a = ' ', char b = '\n', typename... T> void bug(const T&... arg) {
    int i = sizeof...(arg);
    char s[2];s[0] = a, s[1] = b;
    ((std::cout << arg << s[!--i]), ...);
}

using namespace std;
const int MOD = 1e9 + 7;
i32 main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    int n;
    cin >> n;
    vector<int> ve(n), s(n);
    for (int i = 0; i < n; i ++) 
      cin >> ve[i], s[i] = ve[i];
    sort(s.begin(), s.end());
    s.erase(unique(s.begin(), s.end()), s.end());
    for (int i = 0; i < n; i ++)
      ve[i] = lower_bound(s.begin(), s.end(), ve[i]) - s.begin() + 1;
    int m = s.size();
    vector<int> l(m + 1, 0), r(m + 1, 0);
    #define low(x) ((x)&(-x))
    auto get = [&](vector<int>& g, int x) {
      int res = 0;
      for (int i = x; i; i -= low(i))
        res += g[i];
      return res;
    };
    auto add = [&](vector<int>& g, int x, int v) {
      for (int i = x; i <= m; i += low(i)) 
        g[i] += v;
    };
    int ans = 0;
    for (int i = 0; i < n; i ++) 
      add(r, ve[i], 1);
    for (int i = 0; i < n; i ++) {
      add(r, ve[i], -1);
      int a = get(l, m) - get(l, ve[i]);
      int b = get(r, ve[i] - 1);
      ans += a * b;
      add(l, ve[i], 1);
    }
    cout << ans << endl;
    return 0;
}