CF 1400G.Mercenaries 题解【SOSDP 组合数学】
题意:
有\(n\)个佣兵,问雇佣至少一名雇佣兵且满足下述条件的方案数
-
如果雇佣第\(i\)个佣兵必须要求最终雇佣的总人数\(x\)满足\(l_i\le x\le r_i\)
-
有\(m\)对佣兵不能同时选
\(1\le n\le 3\times 10^5,0 \le m \le \min(20, \dfrac{n(n-1)}{2})\)
题解:
首先对于第一个限制,我们考虑枚举最终雇佣的总人数来做
对于第二个限制,可以把这些不能同时选的点连边,不考虑单独一个点的块,那么对于每个连通块,能选的方案必然是其中的一个独立集,单独考虑每一个连通块,由于\(m\)很小,所以最大的连通块里的点数不会超过\(m+1\)个
我们定义\(h[i][x][k]\)表示考虑第\(i\)个连通块,最终雇佣人数为\(x\)的情况下,雇佣连通块\(i\)中的\(k\)个人的合法方案数
定义\(f[i][msk][k]\)表示考第\(i\)个联通块中,在\(msk\)这个集合中选雇佣兵,选\(k\)个的合法方案
令\(S_{i,x}\)为第\(i\)个连通块中,最终选取人数为\(x\)且只考虑第一个限制的情况下可选的雇佣兵的集合
那么可以发现\(h[i][x][k]=f[i][S_{i,x}][k]\)
考虑如何计算\(f[i][msk][k]\),我们可以先\(2^m\)枚举这个集合中的所有子集\(sub\),如果某个子集\(sub\)合法(即这个子集是一个独立集),那么我们令\(f[i][sub][\mid sub\mid] = 1\),然后我们通过高维前缀和(\(SOS\ DP\))就可以在\(O(m^2\cdot 2^m)\)的时间内处理出所有的\(f[i][msk][k]\)了
最终我们枚举雇佣总人数,然后用分组背包处理出这些大于\(1\)的连通块的选择方案,然后剩下的那些大小为\(1\)的连通块可以直接用组合数来算
view code
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast,no-stack-protector")
#include<bits/stdc++.h>
using namespace std;
#define INF 0x3f3f3f3f
#define endl "\n"
#define LL long long int
#define vi vector<int>
#define vl vector<LL>
#define all(V) V.begin(),V.end()
#define sci(x) scanf("%d",&x)
#define scl(x) scanf("%I64d",&x)
#define scs(x) scanf("%s",s)
#define pii pair<int,int>
#define pll pair<LL,LL>
#ifndef ONLINE_JUDGE
#define cout cerr
#endif
#define cmax(a,b) ((a) = (a) > (b) ? (a) : (b))
#define cmin(a,b) ((a) = (a) < (b) ? (a) : (b))
#define debug(x) cerr << #x << " = " << x << endl
function<void(void)> ____ = [](){ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);};
template <typename T> vector<T>& operator << (vector<T> &__container, T x){ __container.push_back(x); return __container; }
template <typename T> ostream& operator << (ostream &out, vector<T> &__container){ for(T _ : __container) out << _ << ' '; return out; }
const int MAXN = 3e5+7;
const int MOD = 998244353;
int n, m, root[MAXN], sz[MAXN], l[MAXN], r[MAXN], cnt[MAXN];
int fac[MAXN], inv[MAXN], rfac[MAXN];
vector<vi> comp;
vector<vector<vi> > f;
int compID;
map<int,int> msk;
set<pair<int,int> > S;
int findx(int x){ return x == root[x] ? x : root[x] = findx(root[x]); }
int C(int n, int m){ return n < m ? 0 : 1ll * fac[n] * rfac[m] % MOD * rfac[n-m] % MOD; }
// preprocess with sosdp
void preprocess(int id){
f[id] = vector<vi>(comp[id].size()+1,vi(1<<comp[id].size(),0));
auto &dp = f[id];
for(int i = 0; i < (1<<comp[id].size()); i++){
bool ok = true;
for(int bit1 = 0; bit1 < comp[id].size(); bit1++){
if(!(i>>bit1&1)) continue;
for(int bit2 = bit1 + 1; bit2 < comp[id].size(); bit2++){
if(!(i>>bit2&1)) continue;
int x = comp[id][bit1], y = comp[id][bit2];
if(S.count(make_pair(min(x,y),max(x,y)))) ok = false;
}
}
if(ok) dp[__builtin_popcount(i)][i] = 1;
}
for(int k = 0; k <= (int)comp[id].size(); k++)
for(int i = 0; i < (int)comp[id].size(); i++) for(int msk = 0; msk < (1<<comp[id].size()); msk++)
if(msk>>i&1) dp[k][msk] += dp[k][msk^(1<<i)];
}
void solve(){
sci(n); sci(m);
for(int i = 1; i <= n; i++) sci(l[i]), sci(r[i]);
for(int i = 1; i <= n; i++) root[i] = i, sz[i] = 1;
for(int i = 1; i <= m; i++){
int u, v; sci(u); sci(v);
S.insert(make_pair(min(u,v),max(u,v)));
if(findx(u)==findx(v)) continue;
int fu = findx(u), fv = findx(v);
root[fu] = fv; sz[fv] += sz[fu];
}
for(int i = 1; i <= n; i++){
if(sz[findx(i)]==1) continue;
int fx = findx(i);
if(!msk.count(fx)) msk[fx] = compID++, comp << vi();
comp[msk[fx]] << i;
}
f.resize(compID);
for(int i = 0; i < compID; i++) preprocess(i);
for(int i = 1; i <= n; i++) if(sz[findx(i)]==1) cnt[l[i]]++, cnt[r[i]+1]--;
for(int i = 1; i <= n; i++) cnt[i] += cnt[i-1];
int ret = 0, tot = 0;
for(int i = 1; i <= n; i++) if(findx(i)==i and sz[i]!=1) tot += sz[i];
for(int i = 1; i <= n; i++){ // 枚举总数
int top = min(tot,i);
vector<vi> g(2,vi(top+1,0));
int tag = 0;
g[0][0] = 1;
for(int j = 0; j < compID; j++){
tag ^= 1;
fill(all(g[tag]),0);
int msk = 0;
for(int k = 0; k < (int)comp[j].size(); k++) if(l[comp[j][k]]<=i and i<=r[comp[j][k]]) msk |= (1 << k); // 找出当前集合中符合条件的
for(int k = 0; k <= min(top,__builtin_popcount(msk)); k++) // 枚举当前集合中选的数量
for(int p = k; p <= top; p++) g[tag][p] = (g[tag][p] + 1ll * g[tag^1][p-k] * f[j][k][msk]) % MOD;
}
for(int j = 0; j <= top; j++) ret = (ret + 1ll * g[tag][j] * C(cnt[i],i-j)) % MOD;
}
cout << ret << endl;
}
int main(){
#ifndef ONLINE_JUDGE
freopen("Local.in","r",stdin);
freopen("ans.out","w",stdout);
#endif
fac[0] = rfac[0] = inv[1] = 1;
for(int i = 1; i < MAXN; i++) fac[i] = 1ll * fac[i-1] * i % MOD;
for(int i = 2; i < MAXN; i++) inv[i] = 1ll * (MOD - MOD/i) * inv[MOD%i] % MOD;
for(int i = 1; i < MAXN; i++) rfac[i] = 1ll * rfac[i-1] * inv[i] % MOD;
solve();
return 0;
}