多项式三角函数
\(\text{Problem}:\)多项式三角函数
\(\text{Solution}:\)
引理 \(1\)(欧拉公式):
\[e^{ix}=\cos x+i\sin x
\]
将 \(x\) 用 \(-x\) 代入,解方程后得到三角函数的另一个表达式:
\[\begin{aligned}
\sin x&=\cfrac{e^{ix}-e^{-ix}}{2i}\\
\cos x&=\cfrac{e^{ix}+e^{-ix}}{2}
\end{aligned}
\]
在模意义下,找到 \(p\) 的原根 \(a\),有 \(i\equiv a^{\frac{p-1}{4}}\pmod p\)。利用多项式 \(exp\) 求解,总时间复杂度为 \(O(n\log n)\)。
如果要求 \(\tan x\),根据定义 \(\tan x=\dfrac{\sin x}{\cos x}\) 计算即可。
\(\text{Code}:\)
#include <bits/stdc++.h>
#pragma GCC optimize(3)
//#define int long long
#define ri register
#define mk make_pair
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define is insert
#define es erase
#define vi vector<int>
#define vpi vector<pair<int,int>>
using namespace std; const int N=265010, Mod=998244353;
inline int read()
{
int s=0, w=1; ri char ch=getchar();
while(ch<'0'||ch>'9') { if(ch=='-') w=-1; ch=getchar(); }
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+(ch^48), ch=getchar();
return s*w;
}
int n,type,I;
vector<int> a,b,Ans,F;
int rev[N],r[24][2],iiv[N+5];
inline int ksc(int x,int p) { int res=1; for(;p;p>>=1, x=1ll*x*x%Mod) if(p&1) res=1ll*res*x%Mod; return res; }
inline void Get_Rev(int T) { for(ri int i=0;i<T;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(T>>1):0); }
inline void DFT(int T,vector<int> &s,int type)
{
for(ri int i=0;i<T;i++) if(rev[i]<i) swap(s[i],s[rev[i]]);
for(ri int i=2,cnt=1;i<=T;i<<=1,cnt++)
{
int wn=r[cnt][type];
for(ri int j=0,mid=(i>>1);j<T;j+=i)
{
for(ri int k=0,w=1;k<mid;k++,w=1ll*w*wn%Mod)
{
int x=s[j+k], y=1ll*w*s[j+mid+k]%Mod;
s[j+k]=(x+y)%Mod;
s[j+mid+k]=x-y;
if(s[j+mid+k]<0) s[j+mid+k]+=Mod;
}
}
}
if(!type) for(ri int i=0,inv=ksc(T,Mod-2);i<T;i++) s[i]=1ll*s[i]*inv%Mod;
}
inline void NTT(int n,int m,vector<int> &A,vector<int> &B)
{
int len=n+m;
int T=1;
while(T<=len) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
for(ri int i=n+1;i<T;i++) A[i]=0;
for(ri int i=m+1;i<T;i++) B[i]=0;
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
DFT(T,A,0);
}
void GetInv(int n,vector<int> &F,vector<int> &G)
{
if(n==1) { F[0]=ksc(G[0],Mod-2); return; }
GetInv((n+1)/2,F,G);
vector<int> A,B;
int T=1;
while(T<=n+n) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
for(ri int i=0;i<n;i++) A[i]=F[i], B[i]=G[i];
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=(2ll*A[i]%Mod-1ll*B[i]*A[i]%Mod*A[i]%Mod+Mod)%Mod;
DFT(T,A,0);
for(ri int i=0;i<n;i++) F[i]=A[i];
}
inline void GetDao(int n,vector<int> &A,vector<int> &B)
{
for(ri int i=0;i<n-1;i++) A[i]=1ll*(i+1)*B[i+1]%Mod;
A[n-1]=0;
}
inline void GetJi(int n,vector<int> &A,vector<int> &B)
{
for(ri int i=1;i<n;i++) A[i]=1ll*B[i-1]*iiv[i]%Mod;
A[0]=0;
}
inline void GetLn(int n,vector<int> &F,vector<int> &G)
{
vector<int> A,B;
A.resize(n), B.resize(n);
GetDao(n,A,G);
GetInv(n,B,G);
NTT(n,n,A,B);
GetJi(n,F,A);
}
void GetExp(int n,vector<int> &F,vector<int> &G)
{
if(n==1) { F[0]=1; return; }
GetExp((n+1)/2,F,G);
vector<int> C;
C.resize(n);
GetLn(n,C,F);
vector<int> A,B;
int T=1;
while(T<=n+n) T<<=1;
Get_Rev(T);
A.resize(T), B.resize(T);
for(ri int i=0;i<n;i++) A[i]=F[i], B[i]=(G[i]-C[i]+Mod)%Mod; B[0]++;
DFT(T,A,1), DFT(T,B,1);
for(ri int i=0;i<T;i++) A[i]=1ll*A[i]*B[i]%Mod;
DFT(T,A,0);
for(ri int i=0;i<n;i++) F[i]=A[i];
}
inline void GetSin()
{
for(ri int i=0;i<n;i++) b[i]=1ll*I*a[i]%Mod;
GetExp(n,Ans,b);
for(ri int i=0;i<n;i++) b[i]=1ll*(Mod-I)*a[i]%Mod;
GetExp(n,F,b);
for(ri int i=0;i<n;i++) Ans[i]=(Ans[i]-F[i]+Mod)%Mod;
for(ri int i=0,invI=ksc(I*2%Mod,Mod-2);i<n;i++) Ans[i]=1ll*Ans[i]*invI%Mod;
for(ri int i=0;i<n;i++) printf("%d ",Ans[i]);
puts("");
}
inline void GetCos()
{
for(ri int i=0;i<n;i++) b[i]=1ll*I*a[i]%Mod;
GetExp(n,Ans,b);
for(ri int i=0;i<n;i++) b[i]=1ll*(Mod-I)*a[i]%Mod;
GetExp(n,F,b);
for(ri int i=0;i<n;i++) Ans[i]=(Ans[i]+F[i])%Mod;
for(ri int i=0,invI=(Mod+1)/2;i<n;i++) Ans[i]=1ll*Ans[i]*invI%Mod;
for(ri int i=0;i<n;i++) printf("%d ",Ans[i]);
puts("");
}
signed main()
{
iiv[1]=1;
for(ri int i=2;i<=N;i++) iiv[i]=1ll*(Mod-Mod/i)*iiv[Mod%i]%Mod;
r[23][1]=ksc(3,119), r[23][0]=ksc(ksc(3,Mod-2),119);
for(ri int i=22;~i;i--) r[i][0]=1ll*r[i+1][0]*r[i+1][0]%Mod, r[i][1]=1ll*r[i+1][1]*r[i+1][1]%Mod;
n=read(), type=read(); I=ksc(3,(Mod-1)/4);
a.resize(n), b.resize(n), F.resize(n), Ans.resize(n);
for(ri int i=0;i<n;i++) a[i]=read();
if(!type) GetSin();
else GetCos();
return 0;
}
夜畔流离回,暗叹永无殿。
独隐万花翠,空寂亦难迁。
千秋孰能为,明灭常久见。
但得心未碎,踏遍九重天。