HDU 6865 杭电多校8 Kidnapper's Matching Problem (线性基 + KMP)
题意
给你三个数组\(a\),\(b\),\(S\)。\(a\)长度为\(n\),\(b\)长度为\(m\),\(S\)长度为\(k\), 且\(a\)数组长度大于\(b\)数组。
之后取\(a\)数组所有长度为\(m\)的连续区间,取出来的子数组与\(b\)对应位置的数两两异或,如果所有异或出来的值都能由\(S\)某个子集的数异或出来,那么答案加上\(2^{i - 1}\),\(i\)为子数组开头在\(a\)中的下标。
具体可以由下式表示:
\(\sum_{i = 1}^{n - m + 1} [(a_i, a_{i + 1}, \cdots, a_{i + m - 1}) \text{ matches } b] \cdot 2^{i - 1} \bmod (10^9 + 7)\)
题解
先看题解原文,对于两数异或能否由\(S\)的线性基表示,题解里说是去掉\(x\),\(y\)中线性基有的位,也就是每次用\(S\)线性基的数消去\(x\),\(y\)当前剩下的最高位。如果剩下的\(x'\)和\(y'\)相同,那就表示\(x^y\)能由\(S\)的线性基表示。
所以开始前我们用\(S\)的线性基直接消去\(a\),\(b\)数组中的每个数的二进制位。那么操作完之后我们只需要判断\(a\)中长度为\(m\)的子数组是否与b完全相同就行。这部分可以直接用KMP来解决。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll N = 2e5 + 7;
const int MAXL = 30;
const ll mod = 1e9 + 7;
int T, n, m, k;
int a[N], b[N];
int kk[205];
int nxt[N], vis[N];
struct LinearBasis{
int aa[MAXL+1];
LinearBasis(){
std::fill(aa, aa + MAXL + 1, 0);
}
void insert(int t){
for (int j = MAXL; j >= 0; j--){
if (!(t & (1ll << j))) continue;
if (aa[j]) t ^= aa[j];
else
{
for (int k = 0; k < j; k++) if (t & (1ll << k)) t ^= aa[k];
for (int k = j + 1; k <= MAXL; k++) if (aa[k] & (1ll << j)) aa[k] ^= t;
aa[j] = t;
return;
}
}
}
void build(int *x, int len){
std::fill(aa, aa + MAXL + 1, 0);
for (int i = 1; i <= len; i++){
insert(x[i]);
}
}
}ji;//线性基模板
int main()
{
scanf("%d", &T);
while(T--)
{
scanf("%d %d %d", &n, &m, &k);
for (int i = 1; i <= n;i++) scanf("%d", &a[i]);
for (int i = 1; i <= m;i++) scanf("%d", &b[i]);
for (int i = 1; i <= k;i++) scanf("%d", &kk[i]);
ji.build(kk, k);
for (int i = 1; i <= n;i++)
{
for (int j = MAXL; j >= 0;j--)
{
if((a[i] >> j) & 1) a[i] ^= ji.aa[j];
}
}
for (int i = 1; i <= m;i++)
{
for (int j = MAXL; j >= 0;j--)
{
if((b[i] >> j) & 1) b[i] ^= ji.aa[j];
}
}
for (int i = n - m; i >= 0;i--) vis[i] = 0;
nxt[1] = 0;
for (int i = 2, j = 0; i <= m;i++)
{
while(j && b[j + 1] != b[i]) j = nxt[j];
if(b[j + 1] == b[i]) j++;
nxt[i] = j;
}
for (int i = 1, j = 0; i <= n;i++)
{
while(j && b[j + 1] != a[i]) j = nxt[j];
if(b[j + 1] == a[i]) j++;
if(j == m)
{
vis[i - m] = 1;
j = nxt[j];
}
}
ll ans = 0;
for (int i = n - m; i >= 0;i--)
{
ans = ans * 2 + vis[i];
ans %= mod;
}
printf("%lld\n", ans);
}
return 0;
}