2026-04-15:使数组元素相等的最小操作次数。用go语言,给定一个整数数组 nums 和一个整数 k。在一次操作中,你可以选择数组中某个元素,把它增加或减少 k(每次恰好改动 k 一次)。
另外给定若干个查询 queries,其中每个查询是一个区间 [li, ri],表示取 nums 中从 li 到 ri 的连续子数组。
对每个查询,你需要计算:把该子数组中的所有元素都变成同一个数所需的最少操作次数。如果无法通过上述操作让它们全部相等,则该查询结果为 -1。
最终返回一个数组 ans,其第 i 个元素表示第 i 个查询的最小操作次数(或 -1)。
1 <= n == nums.length <= 40000。
1 <= nums[i] <= 1000000000。
1 <= k <= 1000000000。
1 <= queries.length <= 40000。
queries[i] = [li, ri]。
0 <= li <= ri <= n - 1。
输入: nums = [1,4,7], k = 3, queries = [[0,1],[0,2]]。
输出: [1,2]。
解释:
一种最优操作方式:
| i | [li, ri] | nums[li..ri] | 可行性 | 操作 | 最终 nums[li..ri] | ans[i] |
|---|---|---|---|---|---|---|
| 0 | [0, 1] | [1, 4] | 是(nums[0] + k = 1 + 3 = 4 = nums[1]) | [4, 4] | [4, 4] | 1 |
| 1 | [0, 2] | [1, 4, 7] | 是(nums[0] + k = 4 = nums[1],nums[2] - k = 7 - 3 = 4 = nums[1]) | [4, 4, 4] | [4, 4, 4] | 2 |
因此,ans = [1, 2]。
题目来自力扣3762。
大体过程如下:
- 关闭GC优化:通过
debug.SetGCPercent(-1)关闭垃圾回收,因为代码使用了大量指针结构(可持久化线段树),关闭GC能提升运行效率。 - 预处理同余合法性数组:遍历数组
nums,生成left数组,left[i]记录以i为右端点时,左侧第一个与nums[i]对k取余结果不同的位置;用于快速判断查询区间内所有元素是否同余(能否全部相等)。 - 数值离散化处理:克隆原数组并排序、去重得到
sorted数组,将原数组的大数值映射为连续的小下标,适配线段树的存储和查询范围。 - 构建可持久化线段树:初始化空线段树,遍历原数组每个元素,将元素映射为离散化下标后,逐次插入线段树,生成
n+1个版本的线段树(t[0]到t[n]),每个版本对应前i个元素的状态。 - 处理每个查询:
- 第一步:通过
left数组判断查询区间[l,r]内元素是否全部同余,若不同余直接记结果为-1。 - 第二步:将区间转换为左闭右开格式,适配线段树的区间查询规则。
- 第三步:利用可持久化线段树查询区间内的中位数(使操作次数最少的目标值),同时统计中位数左侧元素数量和元素和。
- 第四步:计算区间所有元素到中位数的总距离,总距离除以
k得到最小操作次数,存入结果数组。
- 第一步:通过
- 返回最终结果:所有查询处理完成后,输出结果数组。
时间复杂度与额外空间复杂度
- 总时间复杂度:,其中
n是数组长度,q是查询数量,M是离散化后数值的个数;离散化、构建线段树的时间为,每个查询的线段树查询时间为,总查询时间为。 - 总额外空间复杂度:,主要用于存储可持久化线段树的所有节点,空间规模与数组长度和线段树深度成正比。
Go完整代码如下:
package main
import (
"fmt"
"runtime/debug"
"slices"
"sort"
)
// 有大量指针的题目,关闭 GC 更快
func init() { debug.SetGCPercent(-1) }
type node struct {
lo, ro *node
l, r int
cnt, sum int
}
func (o *node) maintain() {
o.cnt = o.lo.cnt + o.ro.cnt
o.sum = o.lo.sum + o.ro.sum
}
func build(l, r int) *node {
o := &node{l: l, r: r}
if l == r {
return o
}
mid := (l + r) / 2
o.lo = build(l, mid)
o.ro = build(mid+1, r)
return o
}
// 在线段树的位置 i 添加 val
// 注意这里传的不是指针,会把 node 复制一份,而这正好是我们需要的
func (o node) add(i, val int) *node {
if o.l == o.r {
o.cnt++
o.sum += val
return &o
}
mid := (o.l + o.r) / 2
if i <= mid {
o.lo = o.lo.add(i, val)
} else {
o.ro = o.ro.add(i, val)
}
o.maintain()
return &o
}
// 查询 old 和 o 对应子数组的第 k 小,有多少个数小于第 k 小,这些数的元素和是多少
func (o *node) query(old *node, k int) (int, int, int) {
if o.l == o.r {
return o.l, 0, 0
}
cntL := o.lo.cnt - old.lo.cnt
if k <= cntL {
return o.lo.query(old.lo, k)
}
i, c, s := o.ro.query(old.ro, k-cntL)
sumL := o.lo.sum - old.lo.sum
return i, cntL + c, sumL + s
}
func minOperations(nums []int, k int, queries [][]int) []int64 {
n := len(nums)
left := make([]int, n)
for i := 1; i < n; i++ {
if nums[i]%k != nums[i-1]%k {
left[i] = i
} else {
left[i] = left[i-1]
}
}
// 准备离散化
sorted := slices.Clone(nums)
slices.Sort(sorted)
sorted = slices.Compact(sorted)
t := make([]*node, n+1)
t[0] = build(0, len(sorted)-1)
for i, x := range nums {
j := sort.SearchInts(sorted, x) // 离散化
t[i+1] = t[i].add(j, x) // 构建可持久化线段树
}
ans := make([]int64, len(queries))
for qi, q := range queries {
l, r := q[0], q[1]
if left[r] > l { // 无解
ans[qi] = -1
continue
}
r++ // 改成左闭右开,方便计算
// 计算区间中位数
sz := r - l
i, cntLeft, sumLeft := t[r].query(t[l], sz/2+1)
median := sorted[i] // 离散化后的值 -> 原始值
// 计算区间所有元素到中位数的距离和
total := t[r].sum - t[l].sum // 区间元素和
sum := median*cntLeft - sumLeft // 蓝色面积
sum += total - sumLeft - median*(sz-cntLeft) // 绿色面积
ans[qi] = int64(sum / k) // 操作次数 = 距离和 / k
}
return ans
}
func main() {
nums := []int{1, 4, 7}
k := 3
queries := [][]int{{0, 1}, {0, 2}}
result := minOperations(nums, k, queries)
fmt.Println(result)
}
Python完整代码如下:
# -*-coding:utf-8-*-
import sys
import bisect
from functools import lru_cache
# 禁用GC(Python中不需要显式设置,但可通过gc.disable()实现类似效果)
# import gc
# gc.disable()
class Node:
__slots__ = ('lo', 'ro', 'l', 'r', 'cnt', 'sum')
def __init__(self, l, r):
self.lo = None
self.ro = None
self.l = l
self.r = r
self.cnt = 0
self.sum = 0
def build(l, r):
"""构建线段树"""
o = Node(l, r)
if l == r:
return o
mid = (l + r) // 2
o.lo = build(l, mid)
o.ro = build(mid + 1, r)
return o
def add(node, i, val, sorted_arr):
"""在线段树的位置i添加val(可持久化)"""
# 创建新节点
new_node = Node(node.l, node.r)
new_node.cnt = node.cnt
new_node.sum = node.sum
if node.l == node.r:
new_node.cnt += 1
new_node.sum += val
return new_node
mid = (node.l + node.r) // 2
if i <= mid:
# 复制右子节点,更新左子节点
new_node.ro = node.ro
new_node.lo = add(node.lo, i, val, sorted_arr)
else:
# 复制左子节点,更新右子节点
new_node.lo = node.lo
new_node.ro = add(node.ro, i, val, sorted_arr)
# 维护节点信息
new_node.cnt = new_node.lo.cnt + new_node.ro.cnt
new_node.sum = new_node.lo.sum + new_node.ro.sum
return new_node
def query(o, old, k):
"""查询第k小,返回(索引, 左侧元素个数, 左侧元素和)"""
if o.l == o.r:
return o.l, 0, 0
cntL = o.lo.cnt - old.lo.cnt
if k <= cntL:
return query(o.lo, old.lo, k)
idx, cnt, s = query(o.ro, old.ro, k - cntL)
sumL = o.lo.sum - old.lo.sum
return idx, cntL + cnt, sumL + s
def min_operations(nums, k, queries):
"""
nums: 输入数组
k: 每次操作可以增减的值
queries: 查询区间列表,每个查询为[l, r]
"""
n = len(nums)
# 计算left数组:记录每个位置左侧最近的不满足模k同余的位置
left = [0] * n
for i in range(1, n):
if nums[i] % k != nums[i-1] % k:
left[i] = i
else:
left[i] = left[i-1]
# 准备离散化
sorted_arr = sorted(set(nums))
val_to_idx = {v: i for i, v in enumerate(sorted_arr)}
# 构建可持久化线段树
roots = [None] * (n + 1)
roots[0] = build(0, len(sorted_arr) - 1)
for i, x in enumerate(nums):
idx = val_to_idx[x]
roots[i + 1] = add(roots[i], idx, x, sorted_arr)
ans = []
for l, r in queries:
# 检查区间内所有元素模k是否同余
if left[r] > l:
ans.append(-1)
continue
sz = r - l + 1
# 查询中位数(第sz//2+1小的元素)
idx, cnt_left, sum_left = query(roots[r + 1], roots[l], sz // 2 + 1)
median = sorted_arr[idx]
# 计算区间所有元素到中位数的距离和
total_sum = roots[r + 1].sum - roots[l].sum
# 蓝色面积:median * cnt_left - sum_left
blue = median * cnt_left - sum_left
# 绿色面积:total_sum - sum_left - median * (sz - cnt_left)
green = total_sum - sum_left - median * (sz - cnt_left)
total_distance = blue + green
ans.append(total_distance // k)
return ans
def main():
nums = [1, 4, 7]
k = 3
queries = [[0, 1], [0, 2]]
result = min_operations(nums, k, queries)
print(result)
if __name__ == "__main__":
main()
C++完整代码如下:
#include <iostream>
#include <vector>
#include <algorithm>
#include <memory>
#include <climits>
using namespace std;
// 节点结构体
struct Node {
Node* lo;
Node* ro;
int l, r;
int cnt, sum;
Node(int left, int right) : lo(nullptr), ro(nullptr), l(left), r(right), cnt(0), sum(0) {}
void maintain() {
cnt = lo->cnt + ro->cnt;
sum = lo->sum + ro->sum;
}
};
// 构建线段树
Node* build(int l, int r) {
Node* o = new Node(l, r);
if (l == r) {
return o;
}
int mid = (l + r) / 2;
o->lo = build(l, mid);
o->ro = build(mid + 1, r);
return o;
}
// 在线段树的位置i添加val(可持久化)
Node* add(Node* node, int i, int val) {
Node* o = new Node(node->l, node->r);
if (o->l == o->r) {
o->cnt = node->cnt + 1;
o->sum = node->sum + val;
return o;
}
int mid = (o->l + o->r) / 2;
if (i <= mid) {
o->ro = node->ro;
o->lo = add(node->lo, i, val);
} else {
o->lo = node->lo;
o->ro = add(node->ro, i, val);
}
o->maintain();
return o;
}
// 查询第k小
tuple<int, int, int> query(Node* o, Node* old, int k) {
if (o->l == o->r) {
return {o->l, 0, 0};
}
int cntL = o->lo->cnt - old->lo->cnt;
if (k <= cntL) {
return query(o->lo, old->lo, k);
}
auto [idx, cnt, s] = query(o->ro, old->ro, k - cntL);
int sumL = o->lo->sum - old->lo->sum;
return {idx, cntL + cnt, sumL + s};
}
vector<long long> minOperations(vector<int>& nums, int k, vector<vector<int>>& queries) {
int n = nums.size();
// 计算left数组
vector<int> left(n);
for (int i = 1; i < n; i++) {
if (nums[i] % k != nums[i-1] % k) {
left[i] = i;
} else {
left[i] = left[i-1];
}
}
// 准备离散化
vector<int> sorted = nums;
sort(sorted.begin(), sorted.end());
sorted.erase(unique(sorted.begin(), sorted.end()), sorted.end());
// 构建可持久化线段树
vector<Node*> t(n + 1);
t[0] = build(0, sorted.size() - 1);
for (int i = 0; i < n; i++) {
int j = lower_bound(sorted.begin(), sorted.end(), nums[i]) - sorted.begin();
t[i+1] = add(t[i], j, nums[i]);
}
vector<long long> ans(queries.size());
for (int qi = 0; qi < queries.size(); qi++) {
int l = queries[qi][0];
int r = queries[qi][1];
if (left[r] > l) {
ans[qi] = -1;
continue;
}
r++; // 改成左闭右开
// 计算区间中位数
int sz = r - l;
auto [idx, cntLeft, sumLeft] = query(t[r], t[l], sz/2 + 1);
int median = sorted[idx];
// 计算区间所有元素到中位数的距离和
int total = t[r]->sum - t[l]->sum;
long long sum = 1LL * median * cntLeft - sumLeft;
sum += total - sumLeft - 1LL * median * (sz - cntLeft);
ans[qi] = sum / k;
}
// 清理内存(可选)
for (auto node : t) {
// 这里可以添加递归删除节点的代码,但为了简洁省略
// 在实际项目中应该正确释放内存
}
return ans;
}
int main() {
vector<int> nums = {1, 4, 7};
int k = 3;
vector<vector<int>> queries = {{0, 1}, {0, 2}};
vector<long long> result = minOperations(nums, k, queries);
cout << "[";
for (int i = 0; i < result.size(); i++) {
cout << result[i];
if (i < result.size() - 1) cout << " ";
}
cout << "]" << endl;
return 0;
}