题解 I. Three Body "蔚来杯"2022牛客暑期多校训练营4
【大意】
给定 \(K\) 维数组 \(S,T\) ,其中每个元素都是不超过 \(K\) 的正整数。求有多少个位置,使得 \(T\) 的 \(T_{0, 0}\) 元素对齐该位置后,整个 \(T\) 数组的值都不超过 \(S\) 数组对应位置的值
【分析】
我们令 \(g_{v, x_1, x_2, \cdots, x_K}=[T_{x_1, x_2, \cdots, x_K}=v]\) ,表示 \(T\) 数组的对应位置是否为 \(v\)
令 \(f_{v, x_1, x_2, \cdots, x_K}=[S_{x_1, x_2, \cdots, x_K}<v]\) ,表示 \(S\) 数组的对应位置是否严格小于 \(v\)
那么,对于 \(\displaystyle h_{v, x_1, x_2, \cdots, x_K}=\sum_{i_1-j_1=x_1}\sum_{i_2-j_2=x_2}\cdots\sum_{i_K-j_K=x_K}f_{v, i_1, i_2, \cdots, i_K}\cdot g_{v, j_1, j_2, \cdots, j_K}\) ,表示位置 \((x_1, x_2, \cdots, x_K)\) 处放置 \(T\) 数组的 \(T_{0, 0}\) 元素后,因为 \(T\) 数组中大小为 \(v\) 的元素,产生的冲突次数
我们对 \(h_{v, x_1, x_2, \cdots, x_K}\) 从 \(v=1\) 到 \(v=K\) 进行求和,则 \(\displaystyle h_{x_1, x_2, \cdots, x_K}=\sum_{v=1}^Kh_{v, x_1, x_2, \cdots, x_K}\) ,表示位置 \((x_1, x_2, \cdots, x_K)\) 处放置 \(T\) 数组的 \(T_{0, 0}\) 元素后产生的冲突次数
显然只有冲突为 \(0\) 的位置是可以摆放的,故答案为 \(h\) 中,\(0\) 出现的次数
现在的问题化为如何求解 \(h_{v, x_1, x_2, \cdots, x_K}\)
考虑到求解式子类似减法卷积的形式,我们直接定义 \(g[v][x_1][x_2]\cdots[x_K]\to g'[x_1n_2n_3\cdots n_K+x_2n_3\cdots n_K+\cdots x_{K-1}n_K+x_K]\)
同理定义 \(f'\) 与 \(h'\) ,并对之前未定义的位置置 \(0\)
于是,\(\displaystyle h'_x=\sum_{i-j=x}f'_ig'_j\) ,其中 \(x=x_1n_2n_3\cdots n_K+x_2n_3\cdots n_K+\cdots x_{K-1}n_K+x_K, i=i_1n_2n_3\cdots n_K+i_2n_3\cdots n_K+\cdots i_{K-1}n_K+i_K, j=j_1n_2n_3\cdots n_K+j_2n_3\cdots n_K+\cdots j_{K-1}n_K+j_K\)
而对应过来,有 \(i-j=(i_1-j_1)n_2n_3\cdots n_K+(i_2-j_2)n_3\cdots n_K+\cdots+(i_{K-1}-j_{K-1})n_K+(i_K-j_K)=x\) ,故 \(i_t-j_t=x_t\) 在原对应位置均合法
然而,唯一的问题出在边界上,需要特判一下,在该位置上的值,加上 \(T\) 数组的大小后,是否会超过这一维度的边界
由于超过边界的值也是被置 \(0\) 的,故可能计算出的答案也是 \(0\) ,但该位置因为过于靠近边界,答案是不能被统计的
【代码】
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define sz(a) (int)a.size()
#define de(a) cout << #a << " = " << a << endl
#define dd(a) cout << #a << " = " << a << " "
#define all(a) a.begin(), a.end()
#define pw(x) (1ll<<(x))
#define lc(x) ((x)<<1)
#define rc(x) ((x)<<1|1)
#define rsz(a, x) (a.resize(x))
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
typedef vector<int> vi;
typedef double db;
const int P=998244353;
inline int kpow(int a, int x, int p=P) { int ans=1; for(;x;x>>=1, a=(ll)a*a%p) if(x&1) ans=(ll)ans*a%p; return ans; }
inline int exgcd(int a, int b, int &x, int &y) {
static int g;
return b?(exgcd(b, a%b, y, x), y-=a/b*x, g):(x=1, y=0, g=a);
}
inline int inv(int a, int p=P) {
static int x, y;
return exgcd(a, p, x, y)==1?(x<0?x+p:x):(-1);
}
const int LimBit=19;
const int M=1<<LimBit<<1;
namespace Poly{
const int G=3;
struct vir {
int v;
vir(int v_=0):v(v_>=P?v_-P:v_) {}
inline vir operator + (const vir &x) const { return vir(v+x.v); }
inline vir operator - (const vir &x) const { return vir(v+P-x.v); }
inline vir operator * (const vir &x) const { return vir((ll)v*x.v%P); }
inline vir operator - () const { return vir(P-v); }
inline vir operator ! () const { return vir(inv(v)); }
inline operator int() const { return v; }
};
struct poly : public vector<vir> {
inline friend ostream& operator << (ostream& out, const poly &p) {
if(!p.empty()) out<<(int)p[0];
for(int i=1; i<sz(p); ++i) out<<" "<<(int)p[i];
return out;
}
};
int N, N_, Stk[M], curStk, rev[M];
vir invN, Inv[M], w[2][M];
inline void init() {
N_=-1;
curStk=0;
Inv[1]=1;
for(int i=2; i<M; ++i)
Inv[i]=-vir(P/i)*Inv[P%i];
}
void work(){
if(N_==N) return ;
N_=N;
int d = __builtin_ctz(N);
vir x(kpow(G, (P-1)/N)), y=!x;
w[0][0] = w[1][0] = 1;
for (int i = 1; i < N; ++i) {
rev[i] = (rev[i>>1] >> 1) | ((i&1) << (d-1));
w[0][i]=x*w[0][i-1], w[1][i]=y*w[1][i-1];
}
invN=!vir(N);
}
inline void FFT(vir a[M],int f){
static auto make = [=](vir w, vir &a, vir &b) { w=w*a; a=b-w; b=b+w; };
for(int i=0;i<N;++i) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<N;i<<=1)
for(int j=0,t=N/(i<<1);j<N;j+=i<<1)
for(int k=0,l=0;k<i;k++,l+=t)
make(w[f][l], a[j+k+i], a[j+k]);
if(f) for(int i=0;i<N;++i) a[i]=a[i]*invN;
}
vir p1[M], p0[M];
inline void get_mul(poly &a, poly &b, int na, int nb) {//3*FFT
for(N=1;N<na+nb-1;N<<=1);
for(int i=0; i<na; ++i) p1[i]=(int)a[i]; for(int i=na;i<N;++i) p1[i]=0;
for(int i=0; i<nb; ++i) p0[i]=(int)b[i]; for(int i=nb;i<N;++i) p0[i]=0;
work(); FFT(p1,0); FFT(p0,0);
for(int i=0;i<N;++i) p1[i]=p1[i]*p0[i];
FFT(p1,1);
rsz(a, na+nb-1); for(int i=0; i<sz(a); ++i) a[i]=p1[i];
}
inline void get_mulT(poly &a, poly &b, poly&c, int na, int nb, int n) {//c=a*r(b)
c=b;
reverse(all(c));
get_mul(c, a, nb, na);
for(int i=0, j=nb-1; i<n; ++i, ++j)
c[i]=c[j];
rsz(c, n);
}
}
using Poly::poly;
const int MAXN=3e5+10;
int k, n[7], val[MAXN], ship[MAXN], m[6];
poly f, g, h;
inline int alw(int pos) {
if(h[pos].v)
return 0;
for(int i=1; i<=k; ++i) {
if(pos+(m[i]-1)*n[i+1]>=n[i])
return 0;
pos%=n[i+1];
}
return 1;
}
inline void work() {
rsz(h, n[1]);
for(int t=1; t<=k; ++t) {
rsz(f, n[1]); rsz(g, n[1]);
for(int i=0; i<n[1]; ++i) {
f[i]=(val[i]<t);
g[i]=(ship[i]==t);
}
Poly::get_mulT(f, g, g, n[1], n[1], n[1]);
for(int i=0; i<n[1]; ++i)
h[i]=h[i]+g[i];
}
int res=0;
for(int i=0; i<n[1]; ++i)
if(alw(i)) ++res;
cout<<res;
}
void draw(int m[], int k, int pos, int x) {
if(pos>k) {
ship[x]=k;
return ;
}
for(int i=0; i<m[pos]; ++i)
draw(m, k, pos+1, x+i*n[pos+1]);
}
inline void init() {
Poly::init();
cin>>k;
for(int i=1; i<=k; ++i) cin>>n[i];
n[k+1]=1;
for(int i=k; ~i; --i) n[i]*=n[i+1];
for(int i=0; i<n[1]; ++i) val[i]=k;
int c; cin>>c;
for(int i=1; i<=c; ++i) {
int x=0, y;
for(int j=1; j<=k; ++j)
cin>>y, x+=y*n[j+1];
cin>>y;
val[x]=y;
}
for(int i=1; i<=k; ++i) cin>>m[i];
draw(m, k, 1, 0);
cin>>c;
for(int i=1; i<=c; ++i) {
int x=0, y;
for(int j=1; j<=k; ++j)
cin>>y, x+=y*n[j+1];
cin>>y;
ship[x]=y;
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
init();
work();
cout.flush();
return 0;
}