7123. 【2021.6.15NOI模拟】尼特
给出一个序列\(a_i\),长度\(n\)。
现在对于每个长度为\(n-1\)的序列\(b_i\),值域\(m\),将\(a_i\)删掉一个位置之后最大化\(\sum [a_i=b_i]\)。
对于每个\(b_i\)求和。
\(n\le 10^6\)
没前途的DP做法:
首先考虑\(b_i\)固定时怎么搞。设\(f_i\)表示\(a_{1..i}\)和\(b_{1..i-1}\)搞的时候的答案,\(c_i=\sum_{j\le i} [a_i=b_i]\)。于是有\(f_i=\max(f_{i-1}+[a_i=b_{i-1}],c_{i-1})\),\(c_i=c_{i-1}+[a_i=b_i]\).
既然要计数就DP套DP:设\(g_{i,f,c,lst}\)表示前\(i\)位,\(f\)是什么,\(c\)是什么,最后一个字符是什么的方案数。容易发现\(lst\)可以去掉,于是就得到了个\(O(n^3)\)的做法。
注意到我们要求\(\sum f*g_{n,f,c}\)。于是考虑转移时\(f\)每次新增就加入答案。于是设\(g_{n,c},s_{n,c}\),\(c\)表示原来的\(f-c\),\(g\)和\(s\)转移大体相同,只是多了个从\(g\)到\(s\)的转移。然后得到\(O(n^2)\)做法。
然后题解不知道在云什么……不知道它是怎么从这个方法上扩展的。
然后gmh114514拯救世界:
先是一个模型转化:对于\(a_i\neq a_{i+1}\)的位置,如果\(b_i=a_i\)则标个左箭头,如果\(b_i=a_{i+1}\)标个右箭头,否则不标。对于\(a_i=a_{i+1}\)的位置肯定有贡献所以可以在最后算。
现在问题是:找到个分界点,最大化左边左箭头+右边右箭头,对其计数。
好啦发现这个东西完全符合上面的DP。然而gmh114514直接选择计数!
首先要求的东西相当于:把左箭头看做+1,右箭头看做-1,前缀和最大值加右箭头的个数就是贡献。
右箭头个数的贡献可以先算,于是只有前缀最大值的贡献。为了方便直接将长度记作\(n\),\(m\leftarrow m-2\)。
按照套路,枚举\(j\ge 1\),计算\(前缀最大值\ge j\)的方案,加起来。把它画在坐标系上,按照套路,如果终点不超过\(j\)就对称过去。于是贡献为\(\sum_{j\ge 1}\sum_i calc(n,\max(i,2j-i))\),其中\(calc(x,y)\)表示从原点到\((x,y)\),每次可以向右上、正右、右下走的方案数。
那个东西等于\(\sum_{i\ge 1} calc(n,i)(2i-1)\)。写成生成函数推一下:
发现其实只需要算\(O(1)\)次\(calc\),每次计算时间\(O(n)\)。
using namespace std;
#include <bits/stdc++.h>
const int Mxdt=100000;
inline char gc() {
static char buf[Mxdt],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,Mxdt,stdin),p1==p2)?EOF:*p1++;
}
inline int read() {
int s=0,f=0;char ch=gc();
while(ch<'0'||ch>'9')f|=(ch=='-'),ch=gc();
while(ch>='0'&&ch<='9')s=(s<<3)+(s<<1)+(ch^48),ch=gc();
return f?-s:s;
}
const int N=1000005,mo=998244353;
typedef long long ll;
ll qpow(ll x,ll y=mo-2){
ll r=1;
for (;y;y>>=1,x=x*x%mo)
if (y&1)
r=r*x%mo;
return r;
}
ll fac[N],ifac[N];
void initC(int n){
fac[0]=1;
for (int i=1;i<=n;++i)
fac[i]=fac[i-1]*i%mo;
ifac[n]=qpow(fac[n]);
for (int i=n-1;i>=0;--i)
ifac[i]=ifac[i+1]*(i+1)%mo;
}
ll C(int m,int n){
return fac[m]*ifac[n]%mo*ifac[m-n]%mo;
}
int n,m;
int a[N];
void add(int &x,ll y){x=(x+y)%mo;}
ll pw[N*2];
ll calc(int x,int y){
ll ans=0;
for (int i=0;i<=x;++i)
if (i>=x+y-i)
(ans+=C(i,x+y-i)*pw[i*2-x-y]%mo*C(x,i))%=mo;
return ans;
}
int main(){
freopen("nit.in","r",stdin);
freopen("nit.out","w",stdout);
n=read(),m=read();
for (int i=1;i<=n;++i)
a[i]=read();
if (m==1){
printf("%d\n",n-1);
return 0;
}
int cnt=0;
for (int i=1;i<n;++i)
cnt+=(a[i]!=a[i+1]);
initC(n);
ll ans=0;
for (int i=0;i<=cnt;++i)
(ans+=qpow(m-1,cnt-i)*C(cnt,i)%mo*i)%=mo;
pw[0]=1;
for (int i=1;i<=cnt*2;++i)
pw[i]=pw[i-1]*(m-2)%mo;
(ans+=cnt*2*(calc(cnt-1,0)+calc(cnt-1,1)))%=mo;
(ans+=-(qpow(m,cnt)-calc(cnt,0))%mo*qpow(2))%=mo;
ans=ans*m%mo;
ans=ans*qpow(m,n-1-cnt)%mo;
(ans+=(ll)qpow(m,n-1)%mo*(n-1-cnt)%mo)%=mo;
ans=ans*qpow(m,mo-1-n)%mo;
ans=(ans+mo)%mo;
printf("%lld\n",ans);
return 0;
}