2021-07-07 集训题解
自动机
Description
Solution
可以想到一个 dp：设 \(f_{u,s,i}\) 表示从状态 \(u\) 出发、当前位于状态 \(s\)、已经考虑了前 \(i\) 个字符时，所有合法方案的权值和（每走一条边就乘上该边的权值 \(e\)）。可以列出 dp 转移式（对所有状态 \(l,r\) 累加）：
\[f_{u,d_{r,1},i} \mathrel{+}= e_{l,0}\times e_{r,1}\times \sum_{j=0}^{i-2} f_{u,l,j}\times f_{d_{l,0},r,i-j-2}
\]
\[f_{u,d_{r,2},i} \mathrel{+}= f_{u,r,i-1}\times e_{r,2}
\]
第一个转移是卷积形式，并且卷积的两边都是正在计算的 \(f\) 本身，因此用分治 NTT（CDQ 分治 + NTT）优化即可。复杂度 \(\Theta(n\log^2 n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define SZ(A) ((int)A.size())
#define Int register int
#define mod 998244353
#define MAXN 400005
// Fast integer reader from stdin. Skips every character before the first digit;
// each '-' seen on the way flips the sign. If input is exhausted before a digit
// is found, t is left as 0 (the old version looped forever on EOF, because EOF
// satisfies `c < '0' || c > '9'`).
template <typename T> inline void read (T &t){
	t = 0;
	int c = getchar();  // int, not char, so EOF stays distinguishable
	int f = 1;
	while (c < '0' || c > '9'){
		if (c == EOF) return;
		if (c == '-') f = -f;
		c = getchar();
	}
	while (c >= '0' && c <= '9'){
		t = (t << 3) + (t << 1) + c - '0';
		c = getchar();
	}
	t *= f;
}
// Reads several integers in order.
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Writes x in decimal with no separator. Negative values are emitted digit by
// digit without negating x itself, so the most negative value of T no longer
// overflows.
template <typename T> inline void write (T x){
	if (x < 0){
		putchar ('-');
		if (x / 10) write (-(x / 10));
		putchar ('0' - x % 10);  // x % 10 is in [-9, 0] here
		return;
	}
	if (x > 9) write (x / 10);
	putchar (x % 10 + '0');
}
// a = max(a, b).
template <typename T> void chkmax (T &a,T b){a = std::max (a,b);}
// a = min(a, b).
template <typename T> void chkmin (T &a,T b){a = std::min (a,b);}
// Global state. d[u][c] / e[u][c] are read in main() for c in {0, 1, 2};
// cdq() uses them as the target state and the weight of the edge leaving state
// u on character c.
// NOTE(review): d and e (and F, declared further down) only have room for 2
// states, so the program presumably relies on V <= 2 — confirm.
int n,q,V,S[MAXN],T[MAXN],N[MAXN],d[2][3],e[2][3];
// (a * b) mod `mod`.
int mul (int a,int b){
	return (int)((long long)a * b % mod);
}
// (a - b) mod `mod`, for a, b in [0, mod).
int dec (int a,int b){
	const int diff = a - b;
	return diff < 0 ? diff + mod : diff;
}
// (a + b) mod `mod`, for a, b in [0, mod).
int add (int a,int b){
	const int s = a + b;
	return s >= mod ? s - mod : s;
}
// a^k mod `mod` by binary exponentiation.
int qkpow (int a,int k){
	long long base = a,acc = 1;
	while (k){
		if (k & 1) acc = acc * base % mod;
		base = base * base % mod;
		k >>= 1;
	}
	return (int)acc;
}
// Modular inverse via Fermat's little theorem (998244353 is prime).
int inv (int x){
	return qkpow (x,mod - 2);
}
// a += b (mod `mod`).
void Add (int &a,int b){
	a = add (a,b);
}
using poly = vector <int>;
int w[MAXN],rev[MAXN];
// Builds the tables used by ntt() for transform lengths up to 2^18:
//  * rev[i]  - the 18-bit bit-reversal of i (ntt() shifts it for shorter sizes);
//  * w[h + k] for a power of two h and 0 <= k < h - the k-th power of a
//    primitive (2h)-th root of unity, i.e. the twiddles of a butterfly of
//    half-size h. The top half is filled directly, lower levels by sampling.
void init_ntt (){
	const int LOG = 18;
	const int lim = 1 << LOG;
	for (int i = 0;i < lim;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (LOG - 1));
	const int half = lim >> 1;
	const int root = qkpow (3,(mod - 1) / lim);
	w[half] = 1;
	for (int i = half + 1;i < lim;++ i) w[i] = mul (w[i - 1],root);
	for (int i = half - 1;i >= 1;-- i) w[i] = w[2 * i];
}
// In-place number-theoretic transform of the first `lim` entries of a.
// lim must be a power of two not exceeding 2^18 (the size of the tables built
// by init_ntt(), which must have run first), and a.size() >= lim.
// type == 1: forward transform; type == -1: inverse transform, including the
// final division by lim.
void ntt (poly &a,int lim,int type){
	static int buf[MAXN];  // working copy in bit-reversed order
	const int shift = 18 - __builtin_ctz (lim);
	for (int i = 0;i < lim;++ i) buf[rev[i] >> shift] = a[i];
	for (int h = 1;h < lim;h <<= 1){
		for (int base = 0;base < lim;base += 2 * h){
			int *lo = buf + base;
			int *hi = buf + base + h;
			for (int k = 0;k < h;++ k){
				const int t = mul (w[h + k],hi[k]);
				hi[k] = dec (lo[k],t);
				lo[k] = add (lo[k],t);
			}
		}
	}
	for (int i = 0;i < lim;++ i) a[i] = buf[i] % mod;
	if (type == -1){
		// Inverse = forward transform with indices 1..lim-1 reversed, scaled by 1/lim.
		reverse (a.begin() + 1,a.begin() + lim);
		const int scale = inv (lim);
		for (int i = 0;i < lim;++ i) a[i] = mul (a[i],scale);
	}
}
// Coefficient-wise a + b; the result has the length of the longer operand.
poly operator + (poly a,poly b){
	a.resize (max (SZ (a),SZ (b)));
	for (int i = 0;i < SZ (b);++ i) a[i] = add (a[i],b[i]);
	return a;
}
// Coefficient-wise a - b; the result has the length of the longer operand.
poly operator - (poly a,poly b){
	a.resize (max (SZ (a),SZ (b)));
	for (int i = 0;i < SZ (b);++ i) a[i] = dec (a[i],b[i]);
	return a;
}
// Multiplies every coefficient by the scalar b.
poly operator * (poly a,int b){
	for (int &coef : a) coef = mul (coef,b);
	return a;
}
// Adds the constant b to the constant term. An empty vector is treated as the
// zero polynomial (the old code indexed a[0] out of bounds in that case).
poly operator + (poly a,int b){
	if (a.empty ()) a.resize (1);
	Add (a[0],b);
	return a;
}
// Polynomial product via NTT; returns SZ(a) + SZ(b) - 1 coefficients.
// When that count is not positive, returns an empty vector. The old code
// reached a.resize(-1) when both operands were empty.
// The transform length must not exceed 2^18, the size of the init_ntt() tables.
poly operator * (poly a,poly b){
	const int d = SZ (a) + SZ (b) - 1;
	if (d <= 0) return poly ();
	int lim = 1;
	while (lim < d) lim <<= 1;
	assert (lim <= (1 << 18));  // rev/w only cover 18-bit indices
	a.resize (lim),b.resize (lim);
	ntt (a,lim,1),ntt (b,lim,1);
	for (int i = 0;i < lim;++ i) a[i] = mul (a[i],b[i]);
	ntt (a,lim,-1),a.resize (d);
	return a;
}
// Power-series inverse: returns b with a * b == 1 (mod x^n), computed by Newton
// iteration. Each round doubles the number of correct coefficients, using
// b <- b * (2 - a * b) on transforms of length l.
// Requires a[0] != 0 (mod p). Coefficients of a at indices >= SZ(a) are taken as
// 0; previously they were read out of bounds when n > SZ(a). n <= 0 yields an
// empty result without touching a.
poly inv (poly a,int n){
	if (n <= 0) return poly ();
	poly b(1,inv (a.empty () ? 0 : a[0])),c;
	for (int l = 4;(l >> 2) < n;l <<= 1){
		const int half = l >> 1;
		c.assign (l,0);  // first `half` coefficients of a, zero-padded to l
		for (int i = 0;i < half && i < n && i < SZ (a);++ i) c[i] = a[i];
		b.resize (l);
		ntt (c,l,1),ntt (b,l,1);
		for (int i = 0;i < l;++ i) b[i] = mul (b[i],dec (2,mul (b[i],c[i])));
		ntt (b,l,-1),b.resize (half);
	}
	b.resize (n);
	return b;
}
// Inverse modulo x^{SZ(a)}.
poly inv (poly a){return inv (a,SZ (a));}
// F[u][s][i] holds the dp value f_{u,s,i}: start state u, current state s,
// i characters. main() resizes each used entry to n + 1 coefficients.
// NOTE(review): fixed at 2 x 2, so this presumably assumes V <= 2 — confirm.
poly F[2][2];
// Fills F[*][*][i] for every i in [l, r] by CDQ divide and conquer over the
// length index. On entry, every contribution coming from indices < l must
// already have been added.
// Two transitions are applied (for all states u, s, L, R):
//  * one character-2 edge:
//      F[u][d[s][2]][i] += e[s][2] * F[u][s][i - 1]
//    This is done at the leaves, once index i - 1 is final.
//  * a path u -> L, a character-0 edge L -> d[L][0], a path d[L][0] -> R, then
//    a character-1 edge R -> d[R][1]:
//      F[u][d[R][1]][i] += e[L][0] * e[R][1] * sum_j F[u][L][j] * F[d[L][0]][R][i - j - 2]
//    This is accumulated by the convolutions below: once the left half
//    [l, mid] is final, its contribution to the right half [mid + 1, r] is added.
void cdq (int l,int r){
if (l > r) return ;
if (l == r){
// Leaf: all convolution contributions to index l have been added by now.
if (l == 0){
// Length 0: the empty path from each state to itself.
for (Int u = 0;u < V;++ u) F[u][u][l] = 1;
}
else{
// Extend every path of length l - 1 by one character-2 edge.
for (Int u = 0;u < V;++ u)
for (Int now = 0;now < V;++ now)
Add (F[u][d[now][2]][l],mul (e[now][2],F[u][now][l - 1]));
}
return ;
}
// Two-point segment. The convolution step adds at least 2 characters, so
// index l cannot feed index l + 1 through it (i - l - 2 would be -1); just recurse.
if (l + 1 == r){
int mid = l + r >> 1;
cdq (l,mid),cdq (mid + 1,r);
return ;
}
// len: segment length; len1: left-half length. len2 is unused.
int mid = l + r >> 1,len = r - l + 1,len1 = mid - l + 1,len2 = r - l;
cdq (l,mid);
if (l){
// l > 0: pair indices [l, mid] of one factor with indices [0, len) of the
// other, in both orders. This relies on len <= l, which holds for midpoint
// splits started from cdq(0, n). Under that condition, indices [0, len) are
// already final. Also, a pair with both indices >= l would give i >= 2l + 2 > r,
// so nothing is counted twice.
for (Int u = 0;u < V;++ u){
for (Int L = 0;L < V;++ L)
for (Int R = 0;R < V;++ R){
poly F1,G1;
// Prefix path u -> L with length in [l, mid]; inner path d[L][0] -> R with length in [0, len).
F1.resize (len1),G1.resize (len);
for (Int i = 0;i < len1;++ i) F1[i] = F[u][L][l + i];
for (Int i = 0;i < len;++ i) G1[i] = F[d[L][0]][R][i];
F1 = F1 * G1;
// Product coefficient t has combined length l + t; the two extra edges give i = l + t + 2.
for (Int i = mid + 1;i <= r;++ i) if (i - l - 2 >= 0) Add (F[u][d[R][1]][i],mul (mul (e[L][0],e[R][1]),F1[i - l - 2]));
// Swapped roles: inner path length in [l, mid]; prefix path length in [0, len).
F1.resize (len1),G1.resize (len);
for (Int i = 0;i < len1;++ i) F1[i] = F[d[L][0]][R][l + i];
for (Int i = 0;i < len;++ i) G1[i] = F[u][L][i];
F1 = F1 * G1;
for (Int i = mid + 1;i <= r;++ i) if (i - l - 2 >= 0) Add (F[u][d[R][1]][i],mul (mul (e[L][0],e[R][1]),F1[i - l - 2]));
}
}
}
else{
// l == 0: both factors range over [0, mid], so every pair with both lengths
// <= mid is added here exactly once.
for (Int u = 0;u < V;++ u){
for (Int L = 0;L < V;++ L)
for (Int R = 0;R < V;++ R){
poly F1,G1;
F1.resize (len1),G1.resize (len1);
for (Int i = 0;i < len1;++ i) F1[i] = F[u][L][i],G1[i] = F[d[L][0]][R][i];
F1 = F1 * G1;
for (Int i = mid + 1;i <= r;++ i) if (i - l - 2 >= 0) Add (F[u][d[R][1]][i],mul (mul (e[L][0],e[R][1]),F1[i - l - 2]));
}
}
}
// Right half, now that every contribution from [l, mid] is in place.
cdq (mid + 1,r);
}
// Reads V and, for each state u and character c in {0, 1, 2}, the pair
// (d[u][c], e[u][c]). Then reads q queries (S, T, N), runs the dp up to the
// largest N, and prints F[S][T][N] for each query.
signed main(){
	freopen( "dfa.in", "r", stdin );
	freopen( "dfa.out", "w", stdout );
	read (V);
	for (int u = 0;u < V;++ u)
		for (int c = 0;c < 3;++ c) read (d[u][c],e[u][c]);
	read (q);
	for (int i = 1;i <= q;++ i){
		read (S[i],T[i],N[i]);
		chkmax (n,N[i]);
	}
	for (int from = 0;from < V;++ from)
		for (int to = 0;to < V;++ to) F[from][to].resize (n + 1);
	init_ntt ();
	cdq (0,n);
	for (int i = 1;i <= q;++ i){
		write (F[S[i]][T[i]][N[i]]);
		putchar ('\n');
	}
	return 0;
}
神
Description
Solution
考虑将两个序列放在一起排序，在原序列中属于第一个序列的元素染成白色，属于第二个序列的元素染成黑色。可以发现，因为逆序对个数为 \(O(n)\) 级别，所以同色连续段的段数是 \(O(\sqrt n)\) 级别的。
那么就可以一段一段地暴力考虑了,复杂度 \(\Theta(n\sqrt n\log n)\)。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define mod 1000000007
#define MAXN 100005
// Fast integer reader from stdin. Skips every character before the first digit;
// each '-' seen on the way flips the sign. If input is exhausted before a digit
// is found, t is left as 0 (the old version looped forever on EOF, because EOF
// satisfies `c < '0' || c > '9'`).
template <typename T> inline void read (T &t){
	t = 0;
	int c = getchar();  // int, not char, so EOF stays distinguishable
	int f = 1;
	while (c < '0' || c > '9'){
		if (c == EOF) return;
		if (c == '-') f = -f;
		c = getchar();
	}
	while (c >= '0' && c <= '9'){
		t = (t << 3) + (t << 1) + c - '0';
		c = getchar();
	}
	t *= f;
}
// Reads several integers in order.
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Writes x in decimal with no separator. Negative values are emitted digit by
// digit without negating x itself, so the most negative value of T no longer
// overflows.
template <typename T> inline void write (T x){
	if (x < 0){
		putchar ('-');
		if (x / 10) write (-(x / 10));
		putchar ('0' - x % 10);  // x % 10 is in [-9, 0] here
		return;
	}
	if (x > 9) write (x / 10);
	putchar (x % 10 + '0');
}
// a = max(a, b).
template <typename T> void chkmax (T &a,T b){a = std::max (a,b);}
// a = min(a, b).
template <typename T> void chkmin (T &a,T b){a = std::min (a,b);}
// Global state. rt[i] is the root of the persistent-tree version built from
// a[1..i]. fac / ifac hold factorials and their inverses modulo `mod`; Work()
// refills them for each test case.
int n,q,a[MAXN],rt[MAXN],fac[MAXN],ifac[MAXN];
// (a * b) mod `mod`.
int mul (int a,int b){
	return (int)((long long)a * b % mod);
}
// (a - b) mod `mod`, for a, b in [0, mod).
int dec (int a,int b){
	const int diff = a - b;
	return diff < 0 ? diff + mod : diff;
}
// (a + b) mod `mod`, for a, b in [0, mod).
int add (int a,int b){
	const int s = a + b;
	return s >= mod ? s - mod : s;
}
// a^k mod `mod` by binary exponentiation.
int qkpow (int a,int k){
	long long base = a,acc = 1;
	while (k){
		if (k & 1) acc = acc * base % mod;
		base = base * base % mod;
		k >>= 1;
	}
	return (int)acc;
}
// Modular inverse via Fermat's little theorem (1000000007 is prime).
int inv (int x){
	return qkpow (x,mod - 2);
}
// a += b (mod `mod`).
void Add (int &a,int b){
	a = add (a,b);
}
// Persistent segment tree over the value range [l, r] (Work() uses [1, n]).
// Each node stores how many inserted values fall in its range. modify() creates
// a new version, so the difference of two versions (e.g. rt[r] vs rt[l - 1])
// describes the values of a subarray.
struct Segment{
#define LOGN 31
#define ls(x) son[x][0]
#define rs(x) son[x][1]
	// Node pool. Node 0 is never allocated (cnt starts from 1), so it acts as
	// the shared empty node: count 0, both children 0.
	int cnt,sum[MAXN * LOGN],son[MAXN * LOGN][2];
	// Forgets every node so the pool can be reused.
	void clear (){
		cnt = 0;
	}
	void pushup (int x){
		sum[x] = sum[son[x][0]] + sum[son[x][1]];
	}
	// Stores in x a new version equal to version y plus one occurrence of pos.
	void modify (int &x,int y,int l,int r,int pos){
		x = ++ cnt;
		sum[x] = sum[y] + 1;
		son[x][0] = son[y][0];
		son[x][1] = son[y][1];
		if (l == r) return ;
		const int mid = (l + r) >> 1;
		if (pos > mid) modify (son[x][1],son[y][1],mid + 1,r,pos);
		else modify (son[x][0],son[y][0],l,mid,pos);
	}
	// Largest value <= v whose count differs between versions x and y,
	// or 0 if there is none.
	int query (int x,int y,int l,int r,int v){
		if (l > v || sum[x] == sum[y]) return 0;
		if (l == r) return l;
		const int mid = (l + r) >> 1;
		if (v > mid){
			const int found = query (son[x][1],son[y][1],mid + 1,r,v);
			if (found) return found;
		}
		return query (son[x][0],son[y][0],l,mid,v);
	}
	// Count of values in [ql, qr] in version y minus version x.
	int getit (int x,int y,int l,int r,int ql,int qr){
		if (ql <= l && r <= qr) return sum[y] - sum[x];
		const int mid = (l + r) >> 1;
		int total = 0;
		if (qr > mid) total += getit (son[x][1],son[y][1],mid + 1,r,ql,qr);
		if (ql <= mid) total += getit (son[x][0],son[y][0],l,mid,ql,qr);
		return total;
	}
}Tree;
// Solves one test case. Reads n, q and a[1..n], whose values must lie in
// [1, n], the range of the persistent tree. Then answers q queries "l1 r1 l2 r2".
// For a query, the values of a[l1..r1] (first group) and a[l2..r2] (second
// group) are walked in decreasing order, one maximal run of same-group values
// at a time. x / y is the largest unprocessed value of the first / second group,
// or 0 once that group is exhausted.
//  * A run of second-group values only increases tot2.
//  * A run of `sum` first-group values multiplies the answer by the falling
//    factorial (tot2 - tot1)! / (tot2 - tot1 - sum)!. This is 0 when
//    sum > tot2 - tot1. Then tot1 += sum.
void Work (){
	read (n,q),Tree.clear();
	fac[0] = 1;
	for (int i = 1;i <= n;++ i) fac[i] = mul (fac[i - 1],i);
	ifac[n] = qkpow (fac[n],mod - 2);
	for (int i = n;i;-- i) ifac[i - 1] = mul (ifac[i],i);
	for (int i = 1;i <= n;++ i){
		read (a[i]);
		rt[i] = 0;
		Tree.modify (rt[i],rt[i - 1],1,n,a[i]);
	}
	while (q --> 0){
		int l1,r1,l2,r2;
		read (l1,r1,l2,r2);
		int x = Tree.query (rt[l1 - 1],rt[r1],1,n,n);
		int y = Tree.query (rt[l2 - 1],rt[r2],1,n,n);
		int ans = 1,tot1 = 0,tot2 = 0;
		while (x || y){
			if (x < y){
				// Second-group values in (x, y].
				tot2 += Tree.getit (rt[l2 - 1],rt[r2],1,n,x + 1,y);
				y = Tree.query (rt[l2 - 1],rt[r2],1,n,x - 1);
			}
			else{
				// First-group values in (y, x].
				const int sum = Tree.getit (rt[l1 - 1],rt[r1],1,n,y + 1,x);
				const int avail = tot2 - tot1;
				// The old test `tot2 >= tot1 + sum - 1` let avail == sum - 1 through
				// and then read ifac[-1] (out of bounds). The falling factorial is 0 there.
				if (avail >= sum) ans = mul (ans,mul (fac[avail],ifac[avail - sum]));
				else ans = 0;
				tot1 += sum;
				x = Tree.query (rt[l1 - 1],rt[r1],1,n,y - 1);
			}
		}
		write (ans),putchar ('\n');
	}
}
// Reads the number of test cases T and solves each one with Work().
signed main(){
	freopen( "god.in", "r", stdin );
	freopen( "god.out", "w", stdout );
	int T;
	read (T);
	for (int tc = 0;tc < T;++ tc) Work ();
	return 0;
}
我会彻查
Description
Solution
可以想到，设 \(f(k)\) 表示每个人选 \(k\) 个的时候的最大贡献和，那么这个东西是一个没有台阶的单峰函数，也就是说可以三分。实现上可以对差分 \(f(k+1)-f(k)\) 的符号进行二分来找到峰值。
考虑如何计算 \(f(k)\),可以想到一定是左边一段,右边一段,中间一部分,所以直接二分在哪里分界即可。
Code
#include <bits/stdc++.h>
using namespace std;
#define Int register int
#define int long long
#define MAXN 200005
// Fast integer reader from stdin. Skips every character before the first digit;
// each '-' seen on the way flips the sign. If input is exhausted before a digit
// is found, t is left as 0 (the old version looped forever on EOF, because EOF
// satisfies `c < '0' || c > '9'`).
template <typename T> inline void read (T &t){
	t = 0;
	int c = getchar();  // wide enough to keep EOF distinguishable
	int f = 1;
	while (c < '0' || c > '9'){
		if (c == EOF) return;
		if (c == '-') f = -f;
		c = getchar();
	}
	while (c >= '0' && c <= '9'){
		t = (t << 3) + (t << 1) + c - '0';
		c = getchar();
	}
	t *= f;
}
// Reads several integers in order.
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Writes x in decimal with no separator. Negative values are emitted digit by
// digit without negating x itself, so the most negative value of T no longer
// overflows.
template <typename T> inline void write (T x){
	if (x < 0){
		putchar ('-');
		if (x / 10) write (-(x / 10));
		putchar ('0' - x % 10);  // x % 10 is in [-9, 0] here
		return;
	}
	if (x > 9) write (x / 10);
	putchar (x % 10 + '0');
}
// a = max(a, b).
template <typename T> void chkmax (T &a,T b){a = std::max (a,b);}
// a = min(a, b).
template <typename T> void chkmin (T &a,T b){a = std::min (a,b);}
int n,m,a[MAXN];
// Contribution of person i when taking a items at the front and b items at the back:
//  front: the a-term progression i, i + n, ..., i + (a - 1) * n
//  back:  the b-term progression m - n + i, m - 2n + i, ... (step -n)
int getsum (int i,int a,int b){
	const int frontPart = a * i + n * (a - 1) * a / 2;
	const int backPart = b * (m - n + i) - n * (b - 1) * b / 2;
	return frontPart + backPart;
}
map <int,int> mp;  // memo: k -> f(k)
// f(k): per the write-up, the maximum total contribution when every person
// takes exactly k items. Results are memoised in mp. Cases:
//  * some i has getsum(i, k, 0) > a[i]: returns -k;
//  * every i has getsum(i, k, 0) <= a[i] and getsum(i, 0, k) <= a[i]: closed form;
//  * otherwise: binary-search the split point j, then sum the per-person terms.
// Every return path stores its value in mp[k]. Previously the final fallback
// returned -1 but left mp[k] == 0, so a later memo hit disagreed with the
// first call.
int f (int k){
	auto it = mp.find (k);
	if (it != mp.end ()) return it->second;
	int &res = mp[k] = 0;
	int flg = 1;
	for (int i = 1;i <= n;++ i){
		if (getsum (i,k,0) > a[i]){flg = -1;break;}
		else if (getsum (i,0,k) > a[i]) flg = 0;
	}
	if (flg == 1) return res = n * m * k - (k * n) * (k * n - 1) / 2;
	else if (flg == -1) return res = -k;
	// Largest split l in [1, k] for which some person violates its bound.
	int l = 1,r = k;
	while (l < r){
		int mid = l + r + 1 >> 1;flg = 1;
		for (int i = n;i >= 1;-- i) if (getsum (i,mid - 1,k - mid + 1) > a[i]){flg = 0;break;}
		if (flg) r = mid - 1;
		else l = mid;
	}
	int j = l;
	for (int i = n;i >= 1;-- i) if (getsum (i,j - 1,k - j + 1) > a[i]){
		for (int t = 1;t <= i;++ t) res += getsum (t,j,k - j);
		for (int t = i + 1;t <= n;++ t) res += getsum (t,j - 1,k - j + 1);
		int p = n * (j - 1) + i,q = m - n * (k - j) - (n - i);
		for (int i1 = i;i1 >= 1;-- i1){
			chkmin (q,a[i1] - getsum (i1,j - 1,k - j));
			res += q - p,-- q,-- p;
		}
		-- j;
		if (j){
			for (int i1 = n;i1 > i;-- i1){
				chkmin (q,a[i1] - getsum (i1,j - 1,k - j));
				res += q - p,-- q,-- p;
			}
		}
		return res;
	}
	// NOTE(review): not expected to be reachable (the search above ends on a
	// violating split). Store the sentinel so the memo stays consistent.
	return res = -1;
}
// Reads n, m and a[1..n], then locates the peak of f on [0, m / n] by binary
// searching on the sign of f(mid + 1) - f(mid), and prints f at that point.
signed main(){
	freopen("investigate.in","r",stdin);
	freopen("investigate.out","w",stdout);
	read (n,m);
	for (int i = 1;i <= n;++ i) read (a[i]);
	int lo = 0,hi = m / n;
	while (lo < hi){
		const int mid = (lo + hi) >> 1;
		const int det = f (mid + 1) - f (mid);
		if (det == 0){
			lo = hi = mid;  // flat step: mid is already optimal
		}
		else if (det < 0){
			hi = mid;       // descending: peak is at or left of mid
		}
		else{
			lo = mid + 1;   // ascending: peak is right of mid
		}
	}
	write (f(lo));
	putchar ('\n');
	return 0;
}