[AH2017/HNOI2017]礼物 解题报告
标签: NTT
题意
有两个手链, 这两个手链上分别有 \(n\) 个装饰品, ( \(1 \le n \le 5 \times 10^4\) ).
每个装饰品都有一个亮度, 两个手链上装饰品的亮度分别为 \(a_i,\ b_i\), 且 \(1 \le a_i,\ b_i \le m\), (\(1 \le m \le 100\)).
现可以将一个手链上的所有装饰品的亮度值增加一个非负正数 \(c\), 并可以将手链进行旋转, 使得
最小, 并求出这个最小值
思路
为了描述方便, 我们把 $ \sum_{i=1}^{n} (a_i-b_i)^2 $ 这个式子称作 "亮度差".
现考虑把一个手链的亮度值增加 \(c\) 后, 亮度差会如何变化,
前后亮度差的差值为
惊奇地发现, 无论装饰品的对应关系如何, 即不管手链如何旋转, 亮度差的变化值是一定的, 并且是关于亮度增加值 \(c\) 的一个二次函数, 可以直接求得最小值.
这样的话, 我们就只需要考虑如何旋转手链能使得初始亮度差最小.
初始亮度差为,
要是上述式子最小, 就需要求到 \(\sum_{i=1}^{n} a_ib_i\) 的最大值.
我们把序列 \(a\) 翻转, 使得 \(a_i = a_{n - i + 1}\), 那么上述式子就变为 \(\sum_{i = 1}^{n} a_{n - i + 1} b_i\), 是一个加法卷积的形式.
我们把 \(a\) 倍长, 并与 \(b\) 进行一次卷积, 那么卷积结果 \(c\) 的第 \(n + i\) 位就是 \(a\) 从第 \(i\) 位开始与 \(b\) 相乘的结果, 就是题目中的将 \(a\) 旋转了 \(i - 1\) 次. 所以我们 NTT 后取 \(n + 1 \sim 2n\) 的最大值即可.
代码
#include<bits/stdc++.h>
#define ll long long
#define db double
using namespace std;
const int _=3e5+7;
const int p=998244353;
const int rt=3;
bool be;
int n,m,a[_],b[_],f[_],g[_],t,invt,invrt,num[_];
bool en;
int q_pow(int a,int k){
int res=1;
while(k){
if(k&1) res=(ll)res*a%p;
a=(ll)a*a%p; k>>=1;
}
return res;
}
void NTT(int *f,int id){
for(int i=0;i<t;i++)
if(i<num[i]) swap(f[i],f[num[i]]);
for(int len=2;len<=t;len<<=1){
int gap=len>>1;
int w1=q_pow(id==1 ?rt :invrt,(p-1)/len);
for(int k=0;k<t;k+=len){
int w=1;
for(int i=k;i<k+gap;i++,w=(ll)w*w1%p){
int tmp=(ll)w*f[i+gap]%p;
f[i+gap]=(f[i]-tmp+p)%p;
f[i]=(f[i]+tmp)%p;
}
}
}
}
int main(){
//freopen("gift.in","r",stdin);
//freopen("gift.out","w",stdout);
cin>>n>>m;
for(int i=1;i<=n;i++){ scanf("%d",&a[i]); a[0]+=a[i]; }
for(int i=1;i<=n;i++){ scanf("%d",&b[i]); b[0]+=b[i]; }
int A=-n,B=2*(a[0]-b[0]);
db tmp=(db)-B/(2*A);
db t1=tmp-floor(tmp),t2=ceil(tmp)-tmp;
int c= t1<t2 ?floor(tmp) :ceil(tmp);
//printf("c: %d\n",c);
c=A*c*c+B*c;
for(int i=0;i<n;i++) f[i]=a[n-i];
for(int i=1;i<=n;i++) g[i]=b[i];
for(int i=n+1;i<=2*n;i++) g[i]=b[i-n];
t=1; while(t<=2*n) t<<=1;
invt=q_pow(t,p-2); invrt=q_pow(rt,p-2);
for(int i=0;i<t;i++)
num[i]=(num[i>>1]>>1)|((i&1) ?t>>1 :0);
NTT(f,1);
NTT(g,1);
for(int i=0;i<t;i++) f[i]=(ll)f[i]*g[i]%p;
NTT(f,-1);
ll ans=-0x3f3f3f3f,res=0;
for(int i=n+1;i<=2*n;i++) ans=max(ans,(ll)f[i]*invt%p);
for(int i=1;i<=n;i++) res+=a[i]*a[i]+b[i]*b[i];
ans=res-2*ans-c;
printf("%lld\n",ans);
// printf("\nused space: %.2lfMB\n",(&en-&be)/1048576.0);
return 0;
}