洛谷:P6033、合并果子加强版

133 阅读4分钟

本文已参与[新人创作礼]活动,一起开启掘金创作之路。

logo.png

题目描述

P6033 [NOIP2004 提高组] 合并果子 加强版

题目背景

本题除【数据范围与约定】外与 P1090 完 全 一 致

题目描述

在一个果园里,多多已经将所有的果子打了下来,而且按果子的不同种类分成了不同的堆。多多决定把所有的果子合成一堆。

每一次合并,多多可以把两堆果子合并到一起,消耗的体力等于两堆果子的重量之和。可以看出,所有的果子经过 (n−1) 次合并之后, 就只剩下一堆了。多多在合并果子时总共消耗的体力等于每次合并所耗体力之和。

因为还要花大力气把这些果子搬回家,所以多多在合并果子时要尽可能地节省体力。假定每个果子重量都为 11,并且已知果子的种类数和每种果子的数目,你的任务是设计出合并的次序方案,使多多耗费的体力最少,并输出这个最小的体力耗费值。

例如有 33 堆果子,数目依次为 1, 2, 9。可以先将 1、2 堆合并,新堆数目为 3,耗费体力为 3。接着,将新堆与原先的第三堆合并,又得到新的堆,数目为 12,耗费体力为 12。所以多多总共耗费体力为 3+12=15。可以证明 15 为最小的体力耗费值。

输入格式

输入的第一行是一个整数 n,代表果子的堆数。 输入的第二行有 n 个用空格隔开的整数,第 i 个整数代表第 i 堆果子的个数 ai。

输出格式

输出一行一个整数,表示最小耗费的体力值。

样例输入

3 
1 2 9 

样例输出

15

数据范围

本题采用多测试点捆绑测试,共有四个子任务

  • Subtask 1(10 points):1≤n≤8。
  • Subtask 2(20 points):1≤n≤10^3。
  • Subtask 3(30 points):1≤n≤10^5。
  • Subtask 4(40 points):1≤n≤10^7。

对于全部的测试点,保证 1≤ai≤10^5。

问题解析

这题是P1090 [NOIP2004 提高组] 合并果子 的进阶版,做着题之前建议去写写好了解基本的做法,也可以看我之前的题解代码源:844、切割 - 掘金 (juejin.cn) 此时数据量达到了1e7,说明我们的复杂度只能是n或更小(如果有),那么此时显然不能用之前的做法了,毕竟优先队列插入就是logn的复杂度,总复杂度是nlogn,就算是想手动找最小的,但光是排序复杂度就是nlogn了,显然也不行。

要是连排序都排序不了那我们该怎么做呢?

但实际上我们还是可以排序的,题目这里只修改了n的数据大小,没有修改ai的,而ai最大才1e5,这就给了我们机会,我们可以准备一个长度为1e5+1的数组v,每次以元素的值为下标,修改数组v的值,比如元素是2,那就是下标为2的位置,数值+1。举个例子,长度为5的数组v,初始是0 0 0 0 0(下标1开始) ,要排序的元素有1 4 5 2 4 3,那么经过我们操作后就变成了:1 1 1 2 1,即值为1的元素有1个,2的元素有1个……5的元素有1个。这样就排好序了,这样时间复杂度就是On了,我们只用遍历长度为n的元素,然后O1的复杂度修改数组v的值。排序好了后,我们可以准备两个队列,一个a存下我们排好序的v数组(小的先进),一个b初始为空,用来后面存下我们合并后的果子。

我们每次从两个队列中总共取两个果子出来,那边的队头最小就取哪边,合并之后放在准备好的队列b里,这样可以使得两边的队列始终递增有序,一直合并到两个队列总元素为1时结束。

AC代码

#include<iostream>
using namespace std;
#include<vector>
#include<algorithm>
#include<math.h>
#include<set>
#include<numeric>
#include<string>
#include<string.h>
#include<iterator>
#include<map>
#include<unordered_map>
#include<stack>
#include<list>
#include<queue>
#include<iomanip>

#define endl '\n';
typedef long long ll;
typedef pair<ll, ll>PII;
const int N = 1e6;

inline int read() {
    int x = 0; char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    while (ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - '0', ch = getchar();
    return x;
}

inline void write(long long x) {
    if (x > 9) write(x / 10);
    putchar(x % 10 | '0');
}

int main()
{
    ll n, x, ans = 0;
    n = read();
    queue<ll>que,mx;
    vector<ll>v(1e5 + 50);
    for (int i = 0; i < n; i++)
    {
        x = read();
        ans = max(ans, x);
        v[x]++;
    }
    for (int i = 1; i <= ans; i++)
    {
        for (int j = 0; j < v[i]; j++)
        {
            que.push(i);
        }
    }
    ll res = 0;
    while (que.size() + mx.size() != 1)
    {
        ans = 0;
        if (!mx.empty())
        {
            if (que.empty())
            {
                ans += mx.front();
                mx.pop();
            }
            else if (mx.front() < que.front())
            {
                ans += mx.front();
                mx.pop();
            }
            else
            {
                ans += que.front();
                que.pop();
            }
        }
        else
        {
            ans += que.front();
            que.pop();
        }
        if (!mx.empty())
        {
            if (que.empty())
            {
                ans += mx.front();
                mx.pop();
            }
            else if (mx.front() < que.front())
            {
                ans += mx.front();
                mx.pop();
            }
            else
            {
                ans += que.front();
                que.pop();
            }
        }
        else
        {
            ans += que.front();
            que.pop();
        }
        res += ans;
        mx.push(ans);
    }
    write(res);
    
    return 0;
}