246小U走排列问题 nlog(n) 做法!!

101 阅读3分钟

题解

问题描述

在数轴上有 n 个点 a[1], a[2], ..., a[n],小 U 初始位于原点。她希望按照一定的顺序访问这些点。需要计算在所有不同的访问顺序中,走过的路径的总和。每种顺序对应的路径长度等于她从原点出发依次访问这些点所走的距离之和。最终结果对 (10^9 + 7) 取模。

这题计算贡献的思想和 236 小 U 的数组权值计算问题有一定的相似之处。感兴趣的可以去做一下

解题思路

为了计算所有排列顺序下路径长度的总和,可以分解路径长度的贡献:

  1. 每个点作为起点的贡献

    • 每个点 a[i] 被访问为第一个点时,路径长度增加 |a[i] - 0| = |a[i]|
    • 由于有 (n-1)! 种排列方式使得 a[i] 是第一个点,因此总贡献为 a[i] * (n-1)!
  2. 每对相邻点之间的贡献

    • 对于每一对不同相邻的点 (a[i], a[j])|a[i] - a[j]| 会在不同的排列中多次出现。
    • 排列中一共有n - 1个可以这对相邻点的位置,并且放好这对相邻点之后,剩下的 n - 2 个点的排列方式有 (n-2)! 种。
    • 因此,这对相邻点之间的贡献为 |a[i] - a[j]| * (n-2)! * (n-1)。我们可以枚举所有的点对,计算它们的贡献。

代码实现

mod = int(1e9 + 7)

def solution(n: int, a: list) -> int:
    if n == 1:
        return a[0]
    res = 0
    f = 1
    # 计算 (n-2)!
    for i in range(1, n - 1): f = f * i % mod
    for i in range(n):
        res += a[i] * f * (n - 1)    # a[i] * (n-1)!
        for j in range(n):
            if i != j:
                # |a[i] - a[j]| * (n-2)! * (n-1)
                res += abs(a[i] - a[j]) * f * (n - 1) % mod
    return res % mod

if __name__ == '__main__':
    print(solution(3, [1, 3, 5]) == 50)
    print(solution(4, [1, 2, 4, 7]) == 324)
    print(solution(2, [2, 6]) == 16)

复杂度分析: 时间复杂度为 O(n2)O(n^2),空间复杂度为 O(1)O(1)

  1. 优化
    • 注意到点a[i] 作为点对中的第一个点的贡献为f * (n - 1) * (a[i] 到其他所有点对的距离和),因此可以先计算出a[i]到其他所有点的距离和,然后再计算a[i]的贡献。
    • 如何求a[i]到其他所有点的距离和?可以先对所有点排序,计算前缀和。 那么a[i]到其他所有点的距离和为a[i] * 前面的点个数 - 前面的区间和 + 后面的区间和 - a[i] * 后面的点个数。

代码实现如下:

from itertools import accumulate
mod = int(1e9 + 7)
def getSumAbsoluteDifferences(nums):
    n = len(nums)
    s = list(accumulate(nums))
    res = [0] * n          # res[i] = 前面的个数 * cur - 前面总和 + 后面的和 - 后面个数 * cur
    for i, x in enumerate(nums):
        res[i] = (i + 1) * x - s[i] + (s[n - 1] - s[i]) - (n - i - 1) * x
        res[i] %= mod
    return res

def solution(n: int, a: list) -> int:
    if n == 1:
        return a[0]
    res = 0
    a.sort()
    dis = getSumAbsoluteDifferences(a)
    f = 1
    # 计算 (n-2)!
    for i in range(1, n - 1): f = f * i % mod
    for i in range(n):
        res = (res + a[i] * f * (n - 1) + dis[i] * (n - 1) * f) % mod
    return res

if __name__ == '__main__':
    print(solution(3, [1, 3, 5]) == 50)
    print(solution(4, [1, 2, 4, 7]) == 324)
    print(solution(2, [2, 6]) == 16)

复杂度分析: 时间复杂度为 O(nlogn)O(n \log n), 瓶颈在排序上 空间复杂度为 O(n)O(n)