Educational Codeforces Round 155 (Rated for Div. 2)

D. Sum of XOR Functions

题意

给出一个数列: $a1,a_2,…,a_n$ ,求$\sum {l=1}^{n} \sum {r=l}^{n}f(l,r) \cdot (r - l + 1)$,其中 $f(l,r) = a_l \oplus a{l+1} \oplus … \oplus a_r$

答案对$998244353$取模. ($n \le 3\cdot 10 ^ 5$)

思路

我们如果暴力枚举,我们的复杂度就太高了(纯暴力$O(n^3)$ ,即使使用前缀来优化区间异或的操作,也需要 $O(n^2)$)。

因此我们需要考虑异或操作的特殊性。
我们可以将a按每个二进制位差分,这样就得到了一个二维数组(设为v),数组的长度是数列中的数的个数,宽度是一个数的二进制位数。

例如样例 $1,3,2$ 对应的v数组为:

V数组 1 3 2
$(a_i >> 1) \& 1$ 0 1 1
$(a_i >> 0) \& 1$ 1 1 0

然后我们就可以按每次一行的顺序来计算处理,最终将答案按权合并即可。

这样我们每次处理的数据就只有0 和 1,可以用如下方法来在$O(n)$ 的时间内解决:

假设我们要处理的01数组为数组b,

设有如下变量:cnt0,cnt1,sum0,sum1,res, 从左向右依次检查每一个数(设下标为1~n):

如果检查到第$i$个数 ,这时候的 cnt0 ,cnt1 分别表示以 $b[i]$ 为结尾的,区间异或为0 / 为1 的区间 的数量; sum0,sum1 表示以$b[i]$ 为结尾的区间异或 的 区间长度的和 , res表示 此时的 前i个数组成的所有异或1区间的 长度和(不止是以i为右边界的区间)
翻译为数学公式即:
$cnt0 = \sum _{k = 1}^i 1,(f(k,i) = 0)$

​ $cnt1 = \sum _{k = 1}^i 1,(f(k,i) = 1)$

​ $sum0 = \sum _{k = 1}^i (i-k+1),(f(k,i) = 0)$

​ $sum1 = \sum _{k = 1}^i (i-k+1),(f(k,i) = 0)$

​ $res = \sum {l=1}^{i} \sum {r=l}^{i}f(l,r) \cdot (r - l + 1)$

然后遍历i 从1 到n:进行状态转移:
如果 $b[i] = 1$ :

​ $swap(cnt0,cnt1);swap(sum0,sum1);$ (原先异或为0的区间再异或1变为1,异或为1的区间再异或为0,效果就是交换cnt0和cnt1)

​ $cnt0+=0;cnt1+=1;$ ((区间异或为1的段的数量加上1)

​ $sum0+=cnt0;sum1+=cnt1;$长度都增加了1,相当于总的长度增加了cnt0和cnt1

​ $res += sum1$ (最终答案加上 $sum1_i$)

如果b[i] = 0:

(原先的区间异或的值不会变)

​ $cnt0+=1;cnt1+=0;$ (区间异或为0的段的数量加上1)

​ $sum0+=cnt0;sum1+=cnt1;$ 度都增加了1,相当于总的长度增加了cnt0和cnt1

​ $res += sum1$ (最终答案加上 $sum1_i$)

最终将每行得到的res进行按权(第$i$行的权为$ (1<<(i-1)) $)相加即可。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<iostream>
using namespace std;
#define int long long
const int N = 3e5 + 5;
bool v[33][N];
const int MOD = 998244353;
signed main(){
int n;
cin>>n;
int num;
for(int j = 1;j<=n;j++){
scanf("%d",&num);
for(int i = 0;i<32;i++){
if((num>> i) & 1) v[i][j] = 1;
}
}
int ans = 0;
for(int i = 0;i<32;i++){
int res ,cnt0,cnt1,sum0,sum1;
res = cnt0 = cnt1 = sum0 = sum1 = 0;
for(int j = 1;j<=n;j++){
if(v[i][j] == 1){
swap(cnt0,cnt1);
swap(sum0,sum1);
cnt1++;
sum0 = (sum0 + cnt0)%MOD;
sum1= (sum1 + cnt1)%MOD;
res = (res + sum1)%MOD;
}
else{
cnt0 += 1;
sum0 = (sum0 + cnt0)%MOD;
sum1= (sum1 + cnt1)%MOD;
res = (res + sum1)%MOD;
}
}
ans = (ans + (res<<i)%MOD )%MOD;
}
cout<<ans;
return 0;
}