洛谷P3953 [NOIP 2017 提高组] 逛公园

33 阅读11分钟

原题:P3953 [NOIP 2017 提高组] 逛公园

题面:

P3953 [NOIP 2017 提高组] 逛公园

题目背景

NOIP2017 D1T3

题目描述

策策同学特别喜欢逛公园。公园可以看成一张 NN 个点 MM 条边构成的有向图,且没有 自环和重边。其中 11 号点是公园的入口,NN 号点是公园的出口,每条边有一个非负权值, 代表策策经过这条边所要花的时间。

策策每天都会去逛公园,他总是从 11 号点进去,从 NN 号点出来。

策策喜欢新鲜的事物,他不希望有两天逛公园的路线完全一样,同时策策还是一个特别热爱学习的好孩子,他不希望每天在逛公园这件事上花费太多的时间。如果 11 号点 到 NN 号点的最短路长为 dd,那么策策只会喜欢长度不超过 d+Kd + K 的路线。

策策同学想知道总共有多少条满足条件的路线,你能帮帮他吗?

为避免输出过大,答案对 PP 取模。

如果有无穷多条合法的路线,请输出 1-1

输入格式

第一行包含一个整数 TT, 代表数据组数。

接下来 TT 组数据,对于每组数据: 第一行包含四个整数 N,M,K,PN,M,K,P,每两个整数之间用一个空格隔开。

接下来 MM 行,每行三个整数 ai,bi,cia_i,b_i,c_i,代表编号为 ai,bia_i,b_i 的点之间有一条权值为 cic_i 的有向边,每两个整数之间用一个空格隔开。

输出格式

输出文件包含 TT 行,每行一个整数代表答案。

输入输出样例 #1

输入 #1

2
5 7 2 10
1 2 1
2 4 0
4 5 2
2 3 2
3 4 1
3 5 2
1 5 3
2 2 0 10
1 2 0
2 1 0

输出 #1

3
-1

说明/提示

【样例解释1】

对于第一组数据,最短路为 3315,1245,12351\to 5, 1\to 2\to 4\to 5, 1\to 2\to 3\to 533 条合法路径。

【测试数据与约定】

对于不同的测试点,我们约定各种参数的规模不会超过如下

::cute-table{tuack}

测试点编号  TT   NN   MM   KK   是否有 00
115555101000
225510310^32×1032\times 10^300
335510310^32×1032\times 10^35050
445510310^32×1032\times 10^35050
555510310^32×1032\times 10^35050
665510310^32×1032\times 10^35050
775510510^52×1052\times 10^500
883310510^52×1052\times 10^55050
993310510^52×1052\times 10^55050
10103310510^52×1052\times 10^55050

对于 100%100\% 的数据,1P1091 \le P \le 10^91ai,biN1 \le a_i,b_i \le N0ci10000 \le c_i \le 1000

数据保证:至少存在一条合法的路线。


  • 2019.8.30 增加了一组 hack 数据 by @skicean
  • 2022.7.21 增加了一组 hack 数据 by @djwj233

SolutionSolution

好题啊好题。

我们先不考虑存在无穷多条路径的情况,只考虑如何对路径条数进行计数。

先看看如果直接暴力 dfsdfs 应该怎么写。我们记录当前节点 uu 和当前的路径长度 lenlen ,然后进行搜索。当我们搜索到 u=nu=nlend+Klen \le d+K 时,就将答案计数加一。同时加上一个判断,当目前已经走过的路径长度大于限制 d+Kd+K 时,就直接返回,否则会一直走下去。

如果这样写的话,显然会有很多已经不可能的状态在持续递归,所以我们需要进行剪枝。回顾这个模型,和P1535 [USACO08MAR] Cow Travelling S很像。同样都需要记录方案数,也同样需要优化剪枝。

我们考虑当前已经达到了节点 uu ,路径长度为 curcur ,通过预处理 uunn 的最短路径长度,我们可以提前判断当前这条路径是否可行。怎么判断呢?设剩余可以走的路径长度为 rest=d+Kcurrest=d+K-cur ,且 disudis_u 代表从 uunn 的最短路径长度,若有 rest<disurest<dis_u ,则说明接下来不论怎么走都不可能在规定的长度范围内走到终点,直接剪枝剪去。

如何预处理最短路径长度?我们直接跑一遍最短路就可以了。你可能会说这不是多源最短路吗,会 TLETLE 吧。我们只需要逆向思维处理,在原有图的基础上再建一个反向图,在这个反向图上以 nn 为起点跑一遍单源最短路即可。与之相同的思想也在P1629 邮递员送信中有所体现。

现在我们已经知道怎么剪枝了,但如果做过上面那道1535的就知道,这还不够,我们还需要记忆化搜索,因为有很多相同的状态被反复递归计算。

怎么记忆化呢?我们设一个二维的 memmem 数组,其中第一维代表当前还剩余多少余量可走,第二维代表当前到了哪个节点。当然这里的余量不只是直接用限制的最大长度 d+Kd+K 减去当前的路径长度 curcur ,而是还要再减去当前节点到 nn 的最短距离 disudis_u 。为什么要这么计算?我们可以注意到,题目给出的 KK 很小,而如果不减去 disudis_u 的话,所得的剩余距离长度会很大,而且还需要乘上一个节点数 nn ,这样开出来的数组显然会直接爆掉空间,所以我们需要考虑压缩第一维的大小。

已知最短路径长度为 dd ,则对于其他的任何路径,均有长度 lendlen \ge d ,而我们要求超出的最大量不超过 KK ,则对于每个节点 uu ,我们按照上面的方式计算余量,这代表当前我最多还能够再分配给后面的路径 restrest 大小的容许你多走的路径的长度。所以后面的所有可能产生的答案都受到当前这个 restrest 的限制,此时我们再将后面的答案递推上来,并记录到 memrest,umem_{rest,u} 中。我们常说记忆化搜索其实就是 dpdp ,这里也是一样的,运用到了 dpdp 中递推状态的思想,即相当于较大的 restrest 的方案数由所有产生在其后的路径并具有更小的 restrest 的方案数累加而成。

通过上面的处理,我们就可以计算出全部的方案数,下面我们来考虑如何处理具有无穷种路径的情况。

思考一下,当图呈什么情况的时候,我们会有无穷种路径?由于节点和边的数量均为有限的,所以我们必然有环的存在,使得我们可以在这个环上一直跑。同时又由于我们具有 d+Kd+K 的长度限制,所以如果这个环的边权之和为正,则每次绕环走一圈路径长度必然递增,最终会存在一个时刻使得路径长度超过 d+Kd+K 的限制,不可能有无穷种情况。所以综上,如果我们要有无穷种情况的存在,图中必须要存在一个环,它的边权之和为 00 ,且在环中存在一个节点 uu ,使得 dis1,u+disu,nd+Kdis_{1,u}+dis_{u,n} \le d+K ,否则这个环无法抵达,就无从贡献无穷种情况。

关于环的判定,我们使用 tarjantarjan 来判断强连通分量,对每个强连通分量进行边权的累加,并且这个强连通分量的大小应当大于一,因为单独的点也会被判断为强连通分量。

CodingCoding

#include <iostream>
#include <cstring>
#include <iomanip>
#include <cmath>
#include <vector>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;

#define ll long long
#define ull unsigned long long
#define debug(x) cout << #x << "=" << x << "\n";

int t;
int n, m, k, p;
int limit;
const int maxn = 1e5 + 10, maxm = 2e5 + 10;
const int maxk = 55;
const int INF = 0x3f3f3f3f;
bool is_reachable[maxn];
vector<pair<int, int>> graph[maxn];
vector<pair<int, int>> op_graph[maxn];
ll mem[maxk][maxn];
int dis[maxn];
int scc_cnt, time_stamp;
int dfn[maxn], low[maxn], scc_id[maxn], scc_weight[maxn], scc_size[maxn];
bool in_stack[maxn];
stack<int> st;
struct state
{
    int node, len;

    bool operator>(const state &other) const
    {
        return len > other.len;
    }
};
struct Edge
{
    int u, v, w;
};

void add_edge(int u, int v, int w)
{
    graph[u].push_back({v, w});
    op_graph[v].push_back({u, w});
}

void pre_deal()
{
    priority_queue<state, vector<state>, greater<state>> q;
    q.push({n, 0});
    memset(dis, 0x3f, sizeof(dis));
    dis[n] = 0;

    while (!q.empty())
    {
        state cur = q.top();
        q.pop();

        int u = cur.node;
        int cur_len = cur.len;

        if (cur_len > dis[u])
            continue;

        for (auto [v, w] : op_graph[u])
        {
            int new_len = cur_len + w;
            if (new_len < dis[v])
            {
                dis[v] = new_len;
                q.push({v, new_len});
            }
        }
    }
}

void check_reach()//用于判断这个节点是否从1可达
{
    queue<int> q;
    q.push(1);
    memset(is_reachable, 0, sizeof(is_reachable));

    while (!q.empty())
    {
        int u = q.front();
        q.pop();

        is_reachable[u] = true;

        for (auto [v, w] : graph[u])
        {
            if (!is_reachable[v])
                q.push(v);
        }
    }
}

void tarjan(int u)//判断强连通分量
{
    dfn[u] = low[u] = ++time_stamp;
    st.push(u);
    in_stack[u] = true;

    for (auto [v, w] : graph[u])
    {
        if (!dfn[v])
        {
            tarjan(v);
            low[u] = min(low[u], low[v]);
        }
        else if (in_stack[v])
            low[u] = min(low[u], dfn[v]);
    }

    if (low[u] == dfn[u])
    {
        scc_cnt++;
        int y;
        do
        {
            y = st.top();
            st.pop();
            scc_id[y] = scc_cnt;
            in_stack[y] = false;
            scc_size[scc_cnt]++;
        } while (y != u);
    }
}

ll dfs(int u, int cur)//记忆化搜索得到答案
{
    int cur_rest = limit - cur - dis[u];

    if (cur_rest < 0)
        return 0;

    if (mem[cur_rest][u] != -1)
        return mem[cur_rest][u];

    mem[cur_rest][u] = 0;

    if (u == n)
        mem[cur_rest][u] = 1;

    for (auto [v, w] : graph[u])
    {
        if (dis[v] == INF)
            continue;

        int rest = limit - dis[v] - cur - w;

        if (rest < 0)
            continue;

        if (mem[rest][v] != -1)
            mem[cur_rest][u] = (mem[cur_rest][u] + mem[rest][v]) % p;
        else
            mem[cur_rest][u] = (mem[cur_rest][u] + dfs(v, cur + w)) % p;
    }

    return mem[cur_rest][u] % p;
}

void solve()
{
    cin >> n >> m >> k >> p;

    scc_cnt = time_stamp = 0;
    for (int i = 1; i <= n; i++)
    {
        graph[i].clear();
        op_graph[i].clear();
    }

    while (!st.empty())
        st.pop();
    for (int i = 1; i <= n; i++)
        dfn[i] = low[i] = in_stack[i] = scc_id[i] = scc_weight[i] = scc_size[i] = 0;
    for (int i = 0; i <= k; i++)
    {
        for (int j = 1; j <= n; j++)
            mem[i][j] = -1;
    }

    Edge edge[maxm];

    for (int i = 1; i <= m; i++)
    {
        int u, v, w;
        cin >> u >> v >> w;
        add_edge(u, v, w);
        edge[i].u = u, edge[i].v = v, edge[i].w = w;
    }

    for (int i = 1; i <= n; i++)
    {
        if (!dfn[i])
            tarjan(i);
    }

    for (int i = 1; i <= m; i++)
    {
        if (scc_id[edge[i].u] == scc_id[edge[i].v])
            scc_weight[scc_id[edge[i].u]] += edge[i].w;
    }

    pre_deal();
    check_reach();

    if (dis[1] == INF)
        return void(cout << "0\n");

    limit = dis[1] + k;

    bool is_inf = false;

    for (int i = 1; i <= n; i++)
    {
        if (dis[i] < INF && is_reachable[i])
        {
            int id = scc_id[i];
            if (scc_size[id] >= 2 && scc_weight[id] == 0)
            {
                is_inf = true;
                break;
            }
        }
    }

    if (is_inf)
        return void(cout << "-1\n");

    cout << dfs(1, 0) % p << "\n";
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> t;
    while (t--)
        solve();

    return 0;
}

然而这份代码只有 70pts70pts ,为什么?

我们来看一个特殊情况:大环包小环。什么意思?即一个边权之和大于 00 的大环,包含了一个边权之和为 00 的小环,按理来说这应当也是无穷多条路径的情况,但 tarjantarjan 将这两个环合并到一起了,从而边权之和大于 00 ,错判了情况。

怎么解决?我们直接在 tarjantarjan 时根据边权为 00 的边来构建子图,在这个子图上跑 tarjantarjan ,如果最终存在一个大小大于一的强连通分量,且满足上面所说的距离限制,则我们有无穷多解。

上面的代码忘记写距离限制了(

100pts:100pts:

#include <iostream>
#include <cstring>
#include <iomanip>
#include <cmath>
#include <vector>
#include <algorithm>
#include <queue>
#include <stack>
using namespace std;

#define ll long long
#define ull unsigned long long
#define debug(x) cout << #x << "=" << x << "\n";

int t;
int n, m, k, p;
int limit;
const int maxn = 1e5 + 10, maxm = 2e5 + 10;
const int maxk = 55;
const int INF = 0x3f3f3f3f;
vector<pair<int, int>> graph[maxn];
vector<pair<int, int>> op_graph[maxn];
ll mem[maxk][maxn];
int dis[maxn], begin_dis[maxn];
int scc_cnt, time_stamp;
int dfn[maxn], low[maxn], scc_id[maxn], scc_weight[maxn], scc_size[maxn];
bool in_stack[maxn];
stack<int> st;
struct state
{
    int node, len;

    bool operator>(const state &other) const
    {
        return len > other.len;
    }
};
struct Edge
{
    int u, v, w;
};

void add_edge(int u, int v, int w)
{
    graph[u].push_back({v, w});
    op_graph[v].push_back({u, w});
}

void pre_deal()
{
    priority_queue<state, vector<state>, greater<state>> q;
    q.push({n, 0});
    memset(dis, 0x3f, sizeof(dis));
    dis[n] = 0;

    while (!q.empty())
    {
        state cur = q.top();
        q.pop();

        int u = cur.node;
        int cur_len = cur.len;

        if (cur_len > dis[u])
            continue;

        for (auto [v, w] : op_graph[u])
        {
            int new_len = cur_len + w;
            if (new_len < dis[v])
            {
                dis[v] = new_len;
                q.push({v, new_len});
            }
        }
    }
}

void check_reach()
{
    priority_queue<state, vector<state>, greater<state>> q;
    q.push({1, 0});
    memset(begin_dis, 0x3f, sizeof(begin_dis));
    begin_dis[1] = 0;

    while (!q.empty())
    {
        state cur = q.top();
        q.pop();

        int u = cur.node;
        int cur_len = cur.len;

        if (cur_len > begin_dis[u])
            continue;

        for (auto [v, w] : graph[u])
        {
            int new_len = cur_len + w;
            if (new_len < begin_dis[v])
            {
                begin_dis[v] = new_len;
                q.push({v, new_len});
            }
        }
    }
}

void tarjan(int u)
{
    dfn[u] = low[u] = ++time_stamp;
    st.push(u);
    in_stack[u] = true;

    for (auto [v, w] : graph[u])
    {
        if (w)
            continue;

        if (!dfn[v])
        {
            tarjan(v);
            low[u] = min(low[u], low[v]);
        }
        else if (in_stack[v])
            low[u] = min(low[u], dfn[v]);
    }

    if (low[u] == dfn[u])
    {
        scc_cnt++;
        int y;
        do
        {
            y = st.top();
            st.pop();
            scc_id[y] = scc_cnt;
            in_stack[y] = false;
            scc_size[scc_cnt]++;
        } while (y != u);
    }
}

ll dfs(int u, int cur)
{
    int cur_rest = limit - cur - dis[u];

    if (cur_rest < 0)
        return 0;

    if (mem[cur_rest][u] != -1)
        return mem[cur_rest][u];

    mem[cur_rest][u] = 0;

    if (u == n)
        mem[cur_rest][u] = 1;

    for (auto [v, w] : graph[u])
    {
        if (dis[v] == INF)
            continue;

        int rest = limit - dis[v] - cur - w;

        if (rest < 0)
            continue;

        if (mem[rest][v] != -1)
            mem[cur_rest][u] = (mem[cur_rest][u] + mem[rest][v]) % p;
        else
            mem[cur_rest][u] = (mem[cur_rest][u] + dfs(v, cur + w)) % p;
    }

    return mem[cur_rest][u] % p;
}

void solve()
{
    cin >> n >> m >> k >> p;

    scc_cnt = time_stamp = 0;
    for (int i = 1; i <= n; i++)
    {
        graph[i].clear();
        op_graph[i].clear();
    }

    while (!st.empty())
        st.pop();
    for (int i = 1; i <= n; i++)
        dfn[i] = low[i] = in_stack[i] = scc_id[i] = scc_weight[i] = scc_size[i] = 0;
    for (int i = 0; i <= k; i++)
    {
        for (int j = 1; j <= n; j++)
            mem[i][j] = -1;
    }

    Edge edge[maxm];

    for (int i = 1; i <= m; i++)
    {
        int u, v, w;
        cin >> u >> v >> w;
        add_edge(u, v, w);
        edge[i].u = u, edge[i].v = v, edge[i].w = w;
    }

    for (int i = 1; i <= n; i++)
    {
        if (!dfn[i])
            tarjan(i);
    }

    pre_deal();
    check_reach();

    if (dis[1] == INF)
        return void(cout << "0\n");

    limit = dis[1] + k;

    bool is_inf = false;

    for (int i = 1; i <= n; i++)
    {
        int total_dis = begin_dis[i] + dis[i];
        if (total_dis > limit)
            continue;

        if (scc_id[i] != 0 && scc_size[scc_id[i]] > 1)
        {
            is_inf = true;
            break;
        }
    }

    if (is_inf)
        return void(cout << "-1\n");

    cout << dfs(1, 0) % p << "\n";
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> t;
    while (t--)
        solve();

    return 0;
}