2022 CCPC 绵阳站 E题 (图上DP,根号分治)

题意

有一个由$n$城市组成的国家,城市之间由一条权值为$w$的边连接,共m条这样的边,并且保证整个国家是连通的。每个城市中有$a_i$ 个居民。

在接下来的q天,每天都会有一个城市遭受灾难$b_1,b_2,…,b_q$,你必须将该城市的所有人都转移到其它城市才能避免居民受到灾难,转移一个居民到相邻城市的代价为两个城市之间路径的权值w。

请问你最少需要多少代价才能让所有居民都安全度过q天的灾难。

思路

我们不必管每个城市中有多少个人,我们只需要求出每个城市中转移一个人的最小代价,在最终计算总代价时再乘上人数即可。很容易想到一个暴力的DP解法如下:

令$f(i,j)$ 表示在第$j$号点,第$i$天后所有的操作中最小的代价。那么有$f(i,j) = MIN{(v,w) \in edge{b_i}}{w+f(i+1,v)}$ 。

我们发现每天只会更新一个dp值,于是我们可以直接省去f的第一维,然后倒序枚举天数 $i $ 从$q$到$1$ 。 总的时间复杂度为$O(\sum_{i=1}^q deg(b_i)$ ,我们发现当$b_i$ 的度较大时,复杂度会退化到$O(qn)$ 这是不被允许的。

于是考虑根号分治,设分治边界为$SQ$ ,那么当$deg(b_i) \le SQ$ 时,枚举他的所有边来更新dp, 如果$deg(b_i) > SQ$ 那么我们为这个节点建立一个multiset ,存储所有邻边的$dp[v]+w$ 的值,则multiset的第一项即为当前最小的dp值。每当一个节点的dp值更新时,将与他相邻的$deg(v) > SQ$的点更新。

当$SQ = \sqrt{2 \times m \times logn}$ 时,复杂度为$O(q\sqrt{m \times logn})$

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include<bits/stdc++.h>
using namespace std;
#define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define int long long
#define rep(i,l,r) for(int i = l;i<=r;i++)
#define per(i,r,l) for(int i = r;i>=l;i--)
const int INF = 0x3f3f3f3f3f3f3f3f;
typedef pair<int,int> PII;
int a[100005];
vector<PII> edge[100005];
vector<PII> edgeB[100005];
int deg[100005];
int que[100005];
int dp[100005];
multiset<int> mulst[100005];
const int mod = 998244353;
void solve(){
int n,m,q;
cin>>n>>m>>q;
int SQ = sqrt(2LL * m * log2(n));
for(int i = 1;i<=n;i++){
cin >> a[i];
}
for(int i = 1;i<=m;i++){
int u,v,w;
cin>>u>>v>>w;
edge[u].push_back({v,w});
edge[v].push_back({u,w});
deg[v]++;deg[u]++;
}
for(int u = 1;u<=n;u++){
if(deg[u] > SQ){
for(auto [v,w] : edge[u]){
if(deg[v] > SQ){
edgeB[u].push_back({v,w});
}
mulst[u].insert(w);
}
}
}
for(int i = 1;i<=q;i++){
cin>>que[i];
}
for(int i = q;i>=1;i--){
int u = que[i];
if(deg[u] <= SQ){
int cost = INF;
for(auto [v,w] : edge[u]){
cost = min(cost,dp[v]+w);
}
for(auto [v,w] : edge[u]){
if(deg[v] > SQ){
mulst[v].erase(mulst[v].find(w+dp[u]));
mulst[v].insert(w+cost);
}
}
dp[u] = cost;
}else{
int cost = *mulst[u].begin();
for(auto [v,w] : edgeB[u]){
if(deg[v] > SQ){
mulst[v].erase(mulst[v].find(w+dp[u]));
mulst[v].insert(w+cost);
}
}
dp[u] = cost;
}
}
int ans = 0;
for(int i = 1;i<=n;i++){
ans = (ans + dp[i] * a[i]) %mod;
}
cout<<ans;
}
signed main(){
int T = 1;
// cin>>T;
while(T--){
solve();
}
return 0;
}