联考20200722 T1 集合划分
分析:
首先是一个\(O(n^2)\)的DP,设\(f_{i,j,0/1}\)表示做了前\(i\)个,用了\(j\)个\(A\),最后一个是\(A/B\)的方案数
然后我们不看最后一位,发现\(f_{i,j}\)两个状态可以用\(2*2\)的转移矩阵DP
发现转移矩阵与\(j\)没有关系,把\(j\)去掉,维护\(f_i=\sum_{j=0}a_jx^j\)的生成函数,\(x^j\)项系数就是\(f_{i,j}\)
如果加一位\(A\)相当于乘一个\(x\),否则乘一个\(1\)
分治维护矩阵上的多项式
复杂度\(O(nlog^2n)\),我的常数巨大2333
#include<cstdio>
#include<cmath>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#include<set>
#include<map>
#include<vector>
#include<string>
#define maxn 200005
#define INF 0x3f3f3f3f
#define MOD 998244353
#define Poly vector<int>
using namespace std;
inline long long getint()
{
long long num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=num*10+c-48,c=getchar();
return num*flag;
}
int n;
int A[maxn],B[maxn];
struct node{
Poly a[2][2];
}P[maxn];
int rev[maxn];
inline int upd(int x){return x<MOD?x:x-MOD;}
inline int ksm(int num,int k)
{
int ret=1;
for(;k;k>>=1,num=1ll*num*num%MOD)if(k&1)ret=1ll*ret*num%MOD;
return ret;
}
inline Poly add(Poly x,Poly y)
{
int mx=max(x.size(),y.size());
x.resize(mx),y.resize(mx);
for(int i=0;i<mx;i++)x[i]=upd(x[i]+y[i]);
return x;
}
inline void NTT(Poly &a,int N,int opt)
{
for(int i=0;i<N;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int i=1;i<N;i<<=1)
{
int wn=ksm(3,(MOD-1)/(i<<1));
if(!~opt)wn=ksm(wn,MOD-2);
for(int j=0;j<N;j+=i<<1)for(int k=0,w=1;k<i;k++,w=1ll*w*wn%MOD)
{
int x=a[j+k],y=1ll*a[i+j+k]*w%MOD;
a[j+k]=upd(x+y),a[i+j+k]=upd(x-y+MOD);
}
}
if(!~opt)for(int i=0,Inv=ksm(N,MOD-2);i<N;i++)a[i]=1ll*a[i]*Inv%MOD;
}
inline node mul(node y,node x)
{
int N=x.a[0][0].size(),M=y.a[0][0].size(),len=1;
while(len<N+M)len<<=1;
for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|(i&1?len>>1:0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)
{
x.a[i][j].resize(len),y.a[i][j].resize(len);
NTT(x.a[i][j],len,1),NTT(y.a[i][j],len,1);
}
node z;
for(int i=0;i<2;i++)for(int j=0;j<2;j++)
{
z.a[i][j].resize(len);
for(int k=0;k<2;k++)for(int l=0;l<len;l++)z.a[i][j][l]=(z.a[i][j][l]+1ll*x.a[i][k][l]*y.a[k][j][l])%MOD;
}
for(int i=0;i<2;i++)for(int j=0;j<2;j++)NTT(z.a[i][j],len,-1),z.a[i][j].resize(N+M-1);
return z;
}
inline node solve(int l,int r)
{
if(l==r)return P[l];
int mid=(l+r)>>1;
return mul(solve(l,mid),solve(mid+1,r));
}
int main()
{
n=getint(),getint();
for(int i=1;i<=2*n;i++)A[i]=getint();
for(int i=1;i<=2*n;i++)B[i]=getint();
for(int i=1;i<=2*n;i++)
{
for(int j=0;j<2;j++)for(int k=0;k<2;k++)P[i].a[j][k].resize(2);
if(A[i-1]<=A[i])P[i].a[0][0][1]=1;
if(B[i-1]<=A[i])P[i].a[0][1][1]=1;
if(A[i-1]<=B[i])P[i].a[1][0][0]=1;
if(B[i-1]<=B[i])P[i].a[1][1][0]=1;
}
node Ans=solve(1,2*n);
printf("%d\n",upd(Ans.a[0][0][n]+Ans.a[1][0][n]));
}