2022 CCPC 绵阳站 E题 (图上DP,根号分治)
题意
有一个由$n$城市组成的国家,城市之间由一条权值为$w$的边连接,共m条这样的边,并且保证整个国家是连通的。每个城市中有$a_i$ 个居民。
在接下来的q天,每天都会有一个城市遭受灾难$b_1,b_2,…,b_q$,你必须将该城市的所有人都转移到其它城市才能避免居民受到灾难,转移一个居民到相邻城市的代价为两个城市之间路径的权值w。
请问你最少需要多少代价才能让所有居民都安全度过q天的灾难。
思路
我们不必管每个城市中有多少个人,我们只需要求出每个城市中转移一个人的最小代价,在最终计算总代价时再乘上人数即可。很容易想到一个暴力的DP解法如下:
令$f(i,j)$ 表示在第$j$号点,第$i$天后所有的操作中最小的代价。那么有$f(i,j) = MIN{(v,w) \in edge{b_i}}{w+f(i+1,v)}$ 。
我们发现每天只会更新一个dp值,于是我们可以直接省去f的第一维,然后倒序枚举天数 $i $ 从$q$到$1$ 。 总的时间复杂度为$O(\sum_{i=1}^q deg(b_i)$ ,我们发现当$b_i$ 的度较大时,复杂度会退化到$O(qn)$ 这是不被允许的。
于是考虑根号分治,设分治边界为$SQ$ ,那么当$deg(b_i) \le SQ$ 时,枚举他的所有边来更新dp, 如果$deg(b_i) > SQ$ 那么我们为这个节点建立一个multiset ,存储所有邻边的$dp[v]+w$ 的值,则multiset的第一项即为当前最小的dp值。每当一个节点的dp值更新时,将与他相邻的$deg(v) > SQ$的点更新。
当$SQ = \sqrt{2 \times m \times logn}$ 时,复杂度为$O(q\sqrt{m \times logn})$
代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
| #include<bits/stdc++.h> using namespace std; #define IO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); #define int long long #define rep(i,l,r) for(int i = l;i<=r;i++) #define per(i,r,l) for(int i = r;i>=l;i--) const int INF = 0x3f3f3f3f3f3f3f3f; typedef pair<int,int> PII; int a[100005]; vector<PII> edge[100005]; vector<PII> edgeB[100005]; int deg[100005]; int que[100005]; int dp[100005]; multiset<int> mulst[100005]; const int mod = 998244353; void solve(){ int n,m,q; cin>>n>>m>>q; int SQ = sqrt(2LL * m * log2(n)); for(int i = 1;i<=n;i++){ cin >> a[i]; } for(int i = 1;i<=m;i++){ int u,v,w; cin>>u>>v>>w; edge[u].push_back({v,w}); edge[v].push_back({u,w}); deg[v]++;deg[u]++; } for(int u = 1;u<=n;u++){ if(deg[u] > SQ){ for(auto [v,w] : edge[u]){ if(deg[v] > SQ){ edgeB[u].push_back({v,w}); } mulst[u].insert(w); } } } for(int i = 1;i<=q;i++){ cin>>que[i]; } for(int i = q;i>=1;i--){ int u = que[i]; if(deg[u] <= SQ){ int cost = INF; for(auto [v,w] : edge[u]){ cost = min(cost,dp[v]+w); } for(auto [v,w] : edge[u]){ if(deg[v] > SQ){ mulst[v].erase(mulst[v].find(w+dp[u])); mulst[v].insert(w+cost); } } dp[u] = cost; }else{ int cost = *mulst[u].begin(); for(auto [v,w] : edgeB[u]){ if(deg[v] > SQ){ mulst[v].erase(mulst[v].find(w+dp[u])); mulst[v].insert(w+cost); } } dp[u] = cost; } } int ans = 0; for(int i = 1;i<=n;i++){ ans = (ans + dp[i] * a[i]) %mod; } cout<<ans; } signed main(){ int T = 1; while(T--){ solve(); } return 0; }
|