LOJ 6485. LJJ 学二项式定理
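题意:给定 \(n,s\) 和 \(a_0,a_1,a_2,a_3\),求 \(\sum_{i=0}^{n}{n\choose i}s^{i}a_{i\bmod 4}\bmod 998244353\)。由于\(a\)的长度很短,考虑枚举\(a_i\),然后计算它的贡献。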
令\(k=|a|=4\)
\[Answer=\sum_{i=0}^{k-1}a_i\sum_{j=0}^{n}[k|j-i]{n\choose j}s^{j}
\]
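很自然想到单位根反演:\([k\mid m]=\frac{1}{k}\sum_{z=0}^{k-1}w_k^{mz}\),其中 \(w_k\) 为 \(k\) 次本原单位根(模 \(998244353\) 下可取 \(w_4=3^{(998244353-1)/4}\))。代入得:
\[Answer=\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{j=0}^{n}(\sum_{z=0}^{k-1} w_{k}^{(j-i)\times z}){n\choose j}s^{j}
\]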
交换求和符号:
\[Answer=\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}\sum_{j=0}^{n}(w_{k}^{z})^j{n\choose j}s^{j}
\]
最后那个和式很像二项式定理。
由\(x^j=x^nx^{j-n}=x^n(\frac{1}{x})^{n-j}\)(要求 \(x\neq0\)),得到:
\[Answer=s^n\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}\sum_{j=0}^{n}(w_{k}^{z})^j{n\choose j}(\frac{1}{s})^{n-j}\\
=s^n\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}(w_{k}^z+\frac{1}{s})^n
\]
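然后就做完了。注意上式需要 \(s\not\equiv0\pmod{998244353}\);更稳妥的做法是直接用 \(\sum_{j=0}^{n}{n\choose j}(w_k^zs)^j=(1+w_k^zs)^n\),得到 \(Answer=\frac{1}{k}\sum_{i=0}^{k-1}a_i\sum_{z=0}^{k-1}w_k^{-zi}(1+w_k^zs)^n\),无需求 \(s\) 的逆元。另外,用费马小定理把指数 \(n\) 对 \(998244353-1\) 取模时,若底数为 \(0\) 需要特判(\(n>0\) 时 \(0^n=0\))。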
时间复杂度\(O(Tk^2\log n)\);若先对每个 \(z\) 预处理好 \(n\) 次幂,可降为 \(O(T(k\log n+k^2))\)。
#include<bits/stdc++.h>
// Competitive-programming shorthand macros.
// Every macro argument that is used as an expression is parenthesized so that
// arguments containing low-precedence operators (e.g. `random(x+1)` or
// `rep(i, c ? 3 : 4)`) expand the way the caller intended.
#define rb(a,b,c) for(int a=(b);a<=(c);++a)   // a = b, b+1, ..., c
#define rl(a,b,c) for(int a=(b);a>=(c);--a)   // a = b, b-1, ..., c
#define LL long long
#define IT iterator
#define PB push_back
#define II(a,b) make_pair(a,b)
#define FIR first
#define SEC second
#define FREO freopen("check.out","w",stdout)
#define rep(a,b) for(int a=0;a<(b);++a)        // a = 0, 1, ..., b-1
#define SRAND mt19937 rng(chrono::steady_clock::now().time_since_epoch().count())
#define random(a) (rng()%(a))                  // requires an `rng` (see SRAND) in scope
#define ALL(a) (a).begin(),(a).end()
#define POB pop_back
#define ff fflush(stdout)
#define fastio ios::sync_with_stdio(false)
#define check_min(a,b) ((a)=min((a),(b)))
#define check_max(a,b) ((a)=max((a),(b)))
using namespace std;
//inline int read(){
// int x=0;
// char ch=getchar();
// while(ch<'0'||ch>'9'){
// ch=getchar();
// }
// while(ch>='0'&&ch<='9'){
// x=(x<<1)+(x<<3)+(ch^48);
// ch=getchar();
// }
// return x;
//}
// Sentinel "infinity" for int min/max bookkeeping; 0x3f3f3f3f == 1061109567,
// so INF + INF still fits in a signed 32-bit int.
constexpr int INF = 0x3f3f3f3f;
// Shorthand for a pair of ints.
using mp = pair<int, int>;
// Prime modulus 998244353 == 119 * 2^23 + 1.
constexpr int MOD = 998244353;
// Base used by main() to build a 4th root of unity as G^((MOD-1)/4).
constexpr int G = 3;
// Modular exponentiation: returns A^B mod MOD using square-and-multiply over
// the bits of B, least significant bit first. Products are widened to
// long long before reduction so they cannot overflow.
// Expects B >= 0.
int quick(int A,int B){
	int acc=1;
	int base=A;
	for(int e=B;e!=0;e>>=1){
		if(e&1) acc=1ll*acc*base%MOD;
		base=1ll*base*base%MOD;
	}
	return acc;
}
// Modular inverse of A modulo the prime MOD via Fermat's little theorem:
// A^(MOD-2) * A == 1 (mod MOD) whenever A is not divisible by MOD.
// For A == 0 (mod MOD) no inverse exists and the result is 0.
int inv(int A){
	const int fermat_exponent=MOD-2;
	return quick(A,fermat_exponent);
}
// w[k] = (G^((MOD-1)/4))^k mod MOD for k = 0..3, i.e. the powers of a
// 4th root of unity; filled in by main() before any test case is solved.
int w[4];
// In-place modular addition: A = (A + B) mod MOD.
// Assumes A and B are already in [0, MOD), so one conditional subtraction
// suffices (A + B < 2*MOD still fits in an int for this MOD).
void add(int & A,int B){
	const int sum=A+B;
	A=(sum>=MOD)?sum-MOD:sum;
}
void solve(){
LL n;
int s,a[4];
scanf("%lld%d",&n,&s);
n%=MOD-1;
rep(i,4) scanf("%d",&a[i]);
int ans=0;
rep(j,4){
int tmp=0;
rep(k,4){
add(tmp,1ll*quick((w[k]+inv(s))%MOD,n)*inv(w[j*k%4])%MOD);
}
tmp=1ll*tmp*quick(s,n)%MOD;
add(ans,1ll*tmp*a[j]%MOD);
}
ans=1ll*ans*inv(4)%MOD;
printf("%d\n",ans);
}
// Entry point. Precomputes w[k] = root^k for the 4th root of unity
// root = G^((MOD-1)/4), then reads the number of test cases T and
// solves each one.
int main(){
	const int root=quick(G,(MOD-1)/4);
	w[0]=1;
	for(int k=1;k<4;++k) w[k]=1ll*w[k-1]*root%MOD;
	int T;
	scanf("%d",&T);
	for(;T;--T) solve();
	return 0;
}