hihoCoder #1388: Periodic Signal

Solution using NTT (number-theoretic transform) with a 64-bit (long long) prime modulus.

#include <algorithm>
#include <cstring>
#include <string.h>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>
#include <cstdio>
#include <cmath>

#define INF 0x3ffffff

using namespace std;

typedef long long LL;
const int N = 150007;              // NTT buffer size; must be >= the transform length used in main()
const LL P = 180143985094819841LL; // NTT prime: 5 * 2^55 + 1
// 3 is a quadratic non-residue mod P (P = 1 mod 4 and P = 2 mod 3), so
// G^((P-1)/2^s) has multiplicative order exactly 2^s.
const int G = 3;
LL wn[25];                         // wn[s]: root of unity for the butterfly stage of length 2^s (filled by getwn)

// Modular multiplication: returns (x * y) mod P in [0, P).
// Precondition: 0 <= x, y < P.
//
// x * y does not fit in 64 bits, so the quotient floor(x*y / P) is estimated in
// long double. The low 64 bits of x*y - q*P are then computed exactly.
// Unsigned arithmetic is used because it wraps modulo 2^64; the original signed
// product overflowed, which is undefined behaviour.
// The estimate is off by at most a few units, so the exact difference fits in a
// signed 64-bit integer. It is then reduced into [0, P).
LL mul(LL x, LL y)
{
    typedef unsigned long long ULL;
    const ULL q = (ULL)((long double)x * y / P);
    // Two's-complement reinterpretation of the wrapped difference
    // (implementation-defined before C++20, modular in practice and in C++20).
    LL r = (LL)((ULL)x * (ULL)y - q * (ULL)P);
    r %= P;
    if (r < 0) r += P;
    return r;
}
// Binary exponentiation: returns x^k, reduced modulo the global P.
// NOTE(review): the modulus argument p is never used -- every multiplication
// goes through mul(), which always reduces modulo P. All callers in this file
// pass p == P. k is assumed to be non-negative.
LL qpow(LL x, LL k, LL p) {
    (void)p;
    LL acc = 1;
    LL base = x;
    for (; k != 0; k >>= 1) {
        // Fold in the current power of x when this exponent bit is set.
        if (k & 1)
            acc = mul(acc, base);
        base = mul(base, base);
    }
    return acc;
}
void getwn() {
    for(int i = 1; i <= 18; ++i) {
        int t = 1 << i;
        wn[i] = qpow(G, (P - 1) / t, P);
    }
}
// Permutes y[0..len) into bit-reversed index order (len assumed a power of
// two). This is the input ordering required by the iterative in-place NTT.
void change(LL *y, int len) {
    const int half = len / 2;
    int rev = half; // bit-reversed value of index 1
    for (int i = 1; i + 1 < len; ++i) {
        // Swap each pair once, when i is the smaller of the two indices.
        if (i < rev)
            swap(y[i], y[rev]);
        // Advance rev to the bit-reversal of i + 1: a "+1" performed from the
        // most significant bit downward, clearing the run of set high bits first.
        int bit = half;
        while (rev >= bit) {
            rev -= bit;
            bit /= 2;
        }
        rev += bit;
    }
}
// In-place iterative radix-2 number-theoretic transform of length len over Z_P.
// len is assumed to be a power of two no larger than 2^18 (the range of wn[]),
// and the entries of y are assumed to lie in [0, P).
//   on == 1  : forward transform.
//   on == -1 : inverse transform. This is a forward transform, then reversing
//              y[1..len-1], then scaling by len^{-1} (Fermat inverse, P prime).
void NTT(LL *y, int len, int on) {
    change(y, len);
    for (int h = 2, stage = 1; h <= len; h <<= 1, ++stage) {
        const int half = h >> 1;
        const LL root = wn[stage]; // h-th root of unity for this stage
        for (int start = 0; start < len; start += h) {
            LL *lo = y + start;
            LL *hi = lo + half;
            LL w = 1;
            for (int k = 0; k < half; ++k) {
                const LL u = lo[k];
                const LL t = mul(hi[k], w);
                // Butterfly: (u, t) -> (u + t, u - t), each kept in [0, P).
                const LL sum = u + t;
                lo[k] = (sum >= P) ? sum - P : sum;
                const LL diff = u - t + P;
                hi[k] = (diff >= P) ? diff - P : diff;
                w = mul(w, root);
            }
        }
    }
    if (on == -1) {
        for (int i = 1, j = len - 1; i < len / 2; ++i, --j)
            swap(y[i], y[j]);
        const LL inv = qpow(len, P - 2, P);
        for (int i = 0; i < len; ++i)
            y[i] = mul(y[i], inv);
    }
}
// Input sequences A and B as read from stdin (capacity 60007 elements each).
LL a[60007], b[60007];
// NTT work buffers: x and y hold the two operands, num receives their product.
LL x[N], y[N], num[N];

// Cyclic convolution of length len (assumed a power of two):
//   c[i] = sum over j + k == i (mod len) of a[j] * b[k], reduced modulo P.
// Side effect: a and b are overwritten with their forward transforms.
void mul(LL a[], LL b[], LL c[], int len)
{
    NTT(a, len, 1);
    NTT(b, len, 1);
    // Pointwise product in the transformed domain.
    int i = 0;
    while (i < len) {
        c[i] = mul(a[i], b[i]);
        ++i;
    }
    NTT(c, len, -1);
}
// Zeroes the NTT work buffers x, y and num before each test case. This keeps
// every entry beyond the first n (the zero padding of the convolution) at 0.
void init(){
    memset(x, 0, sizeof x);
    memset(y, 0, sizeof y);
    memset(num, 0, sizeof num);
}
// Reads T test cases. Each case gives n and then two sequences A and B of
// length n. It prints min over cyclic shifts of sum_i (A[i] - B[shifted i])^2.
//
// The expansion is sum(A^2) + sum(B^2) - 2 * dot(A, shifted B), so the task
// reduces to the largest cyclic dot product. All such products are obtained
// from a single convolution of A with reversed B.
//
// Assumes 1 <= n <= 60000 (capacity of a[]/b[]; then len <= 2^17 <= N).
// NOTE(review): also assumes every exact dot product is < P, so no modular
// wrap-around occurs. That holds for 0 <= A[i], B[i] < 2^20 -- confirm bounds.
int main()
{
    int T;
    if (scanf("%d", &T) != 1) return 0;
    getwn(); // must run before any NTT(): fills the root-of-unity table wn[]
    LL suma, sumb;
    while (T--)
    {
        int n;
        suma = 0; // sum of squares of A
        sumb = 0; // sum of squares of B
        init();
        if (scanf("%d", &n) != 1) break;
        for (int i = 0; i < n; i++) { scanf("%lld", &a[i]); suma += a[i] * a[i]; }
        for (int i = 0; i < n; i++) { scanf("%lld", &b[i]); sumb += b[i] * b[i]; }

        // Smallest power of two that holds the full linear convolution (2n - 1 terms).
        int len = 1;
        while (len < 2 * n) len <<= 1;

        for (int i = 0; i < n; i++) x[i] = a[i];
        for (int i = 0; i < n; i++) y[i] = b[n - i - 1]; // B reversed
        mul(x, y, num, len); // num = A * reversed(B)

        // Shift 0 is num[n-1]. Shift k = n-1-i (1 <= k <= n-1) wraps around and
        // splits into num[i] + num[i+n]. So i must cover 0..n-2 inclusive; the
        // previous bound (i < n-2) skipped one shift.
        LL ret = num[n - 1];
        for (int i = 0; i + 1 < n; i++)
            ret = max(ret, num[i] + num[i + n]);

        LL ans = suma + sumb - 2 * ret;
        cout << ans << endl;
    }
    return 0;
}

 

posted @ 2016-09-24 19:14  smartweed  阅读(258)  评论(0编辑  收藏  举报