2022HBCPC:7-9 优美的字符串

122 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第28天,点击查看活动详情

2022HBCPC——7-9 优美的字符串

7-9 优美的字符串 (pintia.cn)

我们称一个字符串是优美的,当且仅当这个字符串中不存在长度严格大于 2 的回文串。

现在有 m 种不同的字符,那么在可以组成的长度恰好为 nm^n 个不同的字符串中,请求出一共有多少个字符串是优美的,输出答案对 1000000007 取模后的结果即可。

Input

输入仅一行两个整数,分别表示 nm(1≤n≤106,1≤m≤109)。

Output

输出仅一行一个整数,表示答案对 1000000007 取模后的结果。

Sample Input 1

3 3

Sample Output 1

18

问题解析

线性dp。

因为我们要求字符串中不存在长度大于2的回文串,那么字符串的状态就分为以下三种:

  1. abc(任意三个连续的字符都不相同);
  2. aab(两个连续字符相同后有一个不相同的);
  3. abb(一个字符后有两个相同的字符)。

那么我们可以设立二维状态转移数组f,f[i] [j]表示:长度为i,且末尾字符串状态为j的字符串有f[i] [j]个。

如何进行状态转移呢?

我们可以思考一下,对于长度为i的字符串,如何得到状态1:

  • 可以由长度为i-1且末尾状态为1的字符串得到。此时只要在尾部接上一个与前面两个字符不相同的字符就行,而因为前面两个字符也不相同,所以我们这个新增的字符就有m-2种选择;
  • 可以由长度为i-1且末尾状态为2的字符串得到。此时只要在尾部接上一个与前面两个字符不相同的字符就行,而因为前面两个字符也不相同,所以我们这个新增的字符就有m-2种选择;

所以,对于状态1,我们的状态转移方程为:**f[i] [1] = f[i-1] [1](m-2) + f[i-1] [2] (m-2);

对于长度为i的字符串,如何得到状态2:

  • 可以由长度为i-1且末尾状态为3的字符串得到。首先这个新增的字符不能和前两个字符一样,所以有m-1种选择。但是!他还不能和倒数第三个字符一样,如果一样了,就会变成abba的情况,这是一个长度为4的回文串。所以实际上,我们只有m-2种选择。

所以,对于状态2,我们的状态转移方程为:f[i] [2] = f[i-1] [3]*(m-2);

对于长度为i的字符串,如何得到状态3:

  • 可以由长度为i-1且末尾状态为1的字符串得到。这个新增的字符要和最后一个字符一样,所以只有1种选择;
  • 可以由长度为i-1且末尾状态为2的字符串得到。这个新增的字符要和最后一个字符一样,所以只有1种选择;

所以,对于状态2,我们的状态转移方程为:f[i] [3] = f[i-1] [2] + f[i-1] [1];

AC代码

#include<iostream>
using namespace std;
#include<vector>
#include<algorithm>
#include<math.h>
#include<set>
#include <random>
#include<numeric>
#include<string>
#include<string.h>
#include<iterator>
#include<fstream>
#include<map>
#include<unordered_map>
#include<stack>
#include<list>
#include<queue>
#include<iomanip>
#include<bitset>#pragma GCC optimize(2)
#pragma GCC optimize(3)#define endl '\n'
#define int ll
#define PI acos(-1)
#define INF 0x3f3f3f3f
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll, ll>PII;
const int N = 1e6 + 50, MOD = 1e9 + 7;
​
int f[N][4];
void solve()
{
    int n, m;
    cin >> n >> m;
    f[1][1] = m;
    f[2][1] = (m * m) % MOD;
    f[3][1] = (((m * (m - 1)) % MOD) * (m - 2)) % MOD;
    f[3][2] = (m * (m - 1)) % MOD;
    f[3][3] = f[3][2];
    for (int i = 4; i <= n; i++)
    {
        f[i][1] = ((m - 2) * f[i - 1][1] + (m - 2) * f[i - 1][2]) % MOD;
        f[i][2] = ((m - 2) * f[i - 1][3]) % MOD;
        f[i][3] = (f[i - 1][1] + f[i - 1][2]) % MOD;
    }
    cout << (f[n][1] + f[n][2] + f[n][3]) % MOD;
}
​
signed main()
{
​
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int t = 1;
    //cin >> t;
​
    while (t--)
    {
        solve();
    }
    return 0;
}