【CF1743F】Intersection and Union
题目
题目链接:https://codeforces.com/contest/1743/problem/F
有 \(n\) 个集合,其中第 \(i\) 个集合 \(S_i\) 包含 \([l_i,r_i]\) 中的所有整数。
考虑一个长为 \(n-1\) 的序列 \([op_1,op_2,\cdots,op_{n-1}]\),其中每一个元素都可能是 \(\cup,\cap,\oplus\) 的任意一个,求这 \(3^{n-1}\) 个可能的序列
\[|(((S_1\ op_1\ S_2)\ op_2\ S_3)\ op_3\ S_4)\cdots\ op_{n-1}\ S_n|
\]
之和。答案对 \(998244353\) 取模。
\(n\leq 3\times 10^5;0\leq l_i\leq r_i\leq 3\times 10^5\)。
思路
退役人复健(
考虑每一个整数在最终那 \(3^{n-1}\) 个式子中的贡献。记 \(01\) 序列 \(a_i\) 表示目前这个整数是否被包含于第 \(i\) 个区间。
那么很容易想到 dp。设 \(f[i][0/1]\) 表示前 \(i\) 个区间经过运算后,这个整数不在集合内 \(/\) 在集合内的方案数。
考虑两个 \(0,1\) 之间进行位运算的结果,那么
\[f[i][0]=\left\{\begin{matrix}f[i-1][0]\times 3+f[i-1][1]\times 1\ (a_i=0) \\f[i-1][0]\times 1+f[i-1][1]\times 1\ (a_i=1)\end{matrix}\right.
\]
\[f[i][1]=\left\{\begin{matrix}f[i-1][0]\times 0+f[i-1][1]\times 2\ (a_i=0) \\f[i-1][0]\times 2+f[i-1][1]\times 2\ (a_i=1)\end{matrix}\right.
\]
把每个数字都跑一遍 dp,时间复杂度 \(O(nV)\)。
发现由于每一个集合都是一个区间,也就意味着,从计算数字 \(i\) 的贡献转移到数字 \(i+1\) 的贡献时,数组 \(a\) 均摊只有 \(O(1)\) 个位置需要改变,所以只需要搞一个动态 dp 就可以了。
时间复杂度 \(O(V\log n)\)。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=300010,MOD=998244353;
int n,ans,L,R;
vector<int> v1[N],v2[N];
struct Matrix
{
int a[2][2];
Matrix() { memset(a,0,sizeof(a)); }
friend Matrix operator *(Matrix a,Matrix b)
{
Matrix c;
for (int i=0;i<=1;i++)
for (int j=0;j<=1;j++)
for (int k=0;k<=1;k++)
c.a[i][j]=(c.a[i][j]+1LL*a.a[i][k]*b.a[k][j])%MOD;
return c;
}
}mat;
struct SegTree
{
Matrix f[N*4];
void build(int x,int l,int r)
{
if (l==r)
{
f[x].a[0][0]=3,f[x].a[1][0]=1,f[x].a[0][1]=0,f[x].a[1][1]=2;
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
f[x]=f[x*2]*f[x*2+1];
}
void update(int x,int l,int r,int k,int opt)
{
if (l==r)
{
if (opt==0) f[x].a[0][0]=3,f[x].a[1][0]=1,f[x].a[0][1]=0,f[x].a[1][1]=2;
if (opt==1) f[x].a[0][0]=1,f[x].a[1][0]=1,f[x].a[0][1]=2,f[x].a[1][1]=2;
return;
}
int mid=(l+r)>>1;
if (k<=mid) update(x*2,l,mid,k,opt);
if (k>mid) update(x*2+1,mid+1,r,k,opt);
f[x]=f[x*2]*f[x*2+1];
}
}seg;
int main()
{
scanf("%d%d%d",&n,&L,&R);
for (int i=2,x,y;i<=n;i++)
{
scanf("%d%d",&x,&y);
v1[x].push_back(i); v2[y+1].push_back(i);
}
seg.build(1,2,n);
for (int i=0;i<=300000;i++)
{
for (int j=0;j<(int)v1[i].size();j++)
seg.update(1,2,n,v1[i][j],1);
for (int j=0;j<(int)v2[i].size();j++)
seg.update(1,2,n,v2[i][j],0);
if (L<=i && R>=i) mat.a[0][0]=0,mat.a[0][1]=1;
else mat.a[0][0]=1,mat.a[0][1]=0;
mat.a[1][0]=mat.a[1][1]=0;
ans=(ans+(mat*seg.f[1]).a[0][1])%MOD;
}
printf("%d",ans);
return 0;
}