[Codeforces] Codeforces Round #848 (Div. 2) D. Flexible String Revisit | 期望dp,数学

228 阅读1分钟

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 2 天,点击查看活动详情

[Codeforces] Codeforces Round #848 (Div. 2) D. Flexible String Revisit | 期望dp,数学

题目链接

codeforces.com/contest/177…

题目

temp.png

题目大意

给两个01字符串a和b,每次等概率随机对a[i] (1 <= i <= n)做翻转操作,即a[i] = 1 - a[i],求a第一次变成b时的期望操作次数

数据范围:

测试样例组数t (1 <= t <= 1e5)

每组输入: n (1 <= n <= 1e6)

限制:所有样例的n之和 <= 1e6

思路

不难发现,对于a[i]和b[i],相同的情况下操作一次变不同,不同的情况下操作一次变相同

我们只需要关心a和b中有多少个不同的位置

记有x个不同位置时a第一次变成b所需要的翻转操作为f[x]

xn的概率选到不同位置,不同位置变为x1;nxn的概率选到相同位置,不同位置变为x+1个。1<=x<n,  f[x]=xnf[x1]+nxnf[x+1]+1f[x+1]=nnx(f[x]1xnf[x1])有\tfrac{x}{n}的概率选到不同位置,不同位置变为x-1个; \\ 有\tfrac{n-x}{n}的概率选到相同位置,不同位置变为x+1个。 \\ 对1 <= x < n, \ \ f[x] = \tfrac{x}{n} * f[x-1] + \tfrac{n-x}{n} * f[x+1] + 1 \\ 即 f[x+1] = \tfrac{n}{n-x}(f[x] - 1 - \tfrac{x}{n} * f[x-1]) \\

相信上面的式子不太难推,但对于dp的初值,我们自然知道f[0] = 0,可f[1]呢?f[1]对应的式子用来推导f[2]了!最后我们会多出一个f[n]的转移式子,即

f[n]=1+f[n1]f[n] = 1 + f[n-1]

该怎么处理呢?我的想法是,待定系数

令f[i] = c[i] * f[1] + d[i],f[1]待定,我们可以由f[x]的转移式得到c[x]和d[x]的转移式如下:

c[x+1]=nnx(c[x]xnc[x1])d[x+1]=nnx(d[x]1xnd[x1])c[x+1] = \tfrac{n}{n-x} * (c[x] - \tfrac{x}{n}c[x-1]) \\ d[x+1] = \tfrac{n}{n-x} * (d[x] - 1 - \tfrac{x}{n}d[x-1]) \\

这样我们知道c[0] = d[0] = 0,c[1] = 1,d[1] = 0,从而可以dp出c和d数组

再根据f[n] = 1 + f[n-1], 有

c[n]f[1]+d[n]=c[n1]f[1]+d[n1]+1f[1]=d[n1]d[n]1c[n]c[n1]c[n] * f[1] + d[n] = c[n-1] * f[1] + d[n-1] + 1 \\ f[1] = \frac{d[n-1] - d[n] - 1}{c[n] - c[n-1]} \\

求出f[1]以后,我们只需要统计a和b最初的不同位置数diff,就能得到答案

ans=c[diff]f[1]+d[diff]ans = c[diff] * f[1] + d[diff]

完结撒花!

代码

#include<bits/stdc++.h>
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define int long long
using namespace std;
​
const int maxn = 2e5 + 100;
const int mod = 998244353;
const int inf = 1LL << 60;#define db 0
#define test if (db) 
​
int qpow(int a, int n) {
    int r = 1;
    a %= mod;
    for (; n; n >>= 1, a = a * a % mod) if (n & 1) r = r * a % mod;
    return r;
}
​
int inv(int x) {
    return qpow(x, mod-2);
}
​
void solve() {
    int n; 
    cin >> n;
    string a, b;
    cin >> a >> b;
    int diff = 0;
    for (int i = 0; i < n; ++i) diff += (a[i] != b[i]);
    vector<int> c(n+1, 0), d(n+1, 0);
    c[1] = 1; d[1] = 0; // f[1] = 1 * f[1] + 0
    // f[i] = 1 + i / n * f[i-1] + (n-i) / n * f[i+1]
    // f[i+1] = (f[i] - 1 - i / n * f[i-1]) * n / (n-i)
    //        = (c[i] * f[1] + d[i] - 1 - i / n * (c[i-1] * f[1] + d[i-1])) * n / (n-i)
    for (int i = 1; i+1 <= n; ++i) {
        c[i+1] = (c[i] - i * inv(n) % mod * c[i-1] % mod + mod) * n % mod * inv(n-i) % mod;
        d[i+1] = (d[i] - 1 - i * inv(n) % mod * d[i-1] % mod + mod + mod) * n % mod * inv(n-i) % mod;
    }
​
    // f[n] = 1 + f[n-1]
    // c[n] * f[1] + d[n] = 1 + c[n-1] * f[1] + d[n-1]
    // (c[n] - c[n-1]) * f[1] = d[n-1] + 1 - d[n]
​
    int f1 = (d[n-1] + 1 - d[n] + mod) * inv(c[n] - c[n-1] + mod) % mod;
    cout << (c[diff] * f1 + d[diff]) % mod << '\n';
}
​
​
void refresh() {
    
}
    
​
signed main() {
    IOS
    int t;
    t = 1;
    cin >> t;
    while (t--) {
        solve();
        refresh();
    }
    return 0;
}