类欧几里得算法 重学笔记
Solution
以前学过,但当时学得很烂,写法也很有局限性,现在重新学一遍。
假设我们要求的是(其中 \(n,a,b\ge 0,\ c>0\),并约定 \(0^0=1\)):
\[\sum_{x=0}^{n} x^{k1}\lfloor\frac{ax+b}{c}\rfloor^{k2}
\]
可以发现可以分为几种情况进行讨论:
- \(a=0\) 或者 \(\lfloor\frac{an+b}{c}\rfloor=0\)
此时 \(\lfloor\frac{ax+b}{c}\rfloor\) 恒等于常数 \(t=\lfloor\frac{an+b}{c}\rfloor\),答案就是 \(t^{k2}\sum_{x=0}^{n}x^{k1}\),直接用 \(k1\) 次幂的前缀和即可。
- \(a\ge c\)
设 \(q=\lfloor\frac{a}{c}\rfloor,r=a\mod c\) ,那么可以得到答案就是:
\[\sum_{x=0}^{n} x^{k1}(qx+\lfloor\frac{xr+b}{c}\rfloor)^{k2}
\]
\[=\sum_{i=0}^{k2} q^i\binom{k2}{i}\sum_{x=0}^{n} x^{k1+i}\lfloor\frac{xr+b}{c}\rfloor^{k2-i}
\]
直接递归即可。注意右边用到了 \(k1+i\) 次的结果(\(k1\) 变大、\(k2\) 变小,但 \(k1+k2\) 不变),所以每一层都要把所有 \(k1+k2\le K\) 的答案一次性求出来。
- \(b\ge c\)
设 \(q=\lfloor\frac{b}{c}\rfloor,r=b\mod c\),同理可以得到答案就是:
\[\sum_{i=0}^{k2} \binom{k2}{i}q^i\sum_{x=0}^{n} x^{k1}\lfloor\frac{ax+r}{c}\rfloor^{k2-i}
\]
- \(\max(a,b)<c\)
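设 \(M=\lfloor\frac{an+b}{c}\rfloor\)。对 \(k2\ge 1\),利用裂项 \(y^{k2}=\sum_{j=0}^{y-1}((j+1)^{k2}-j^{k2})\)(\(y\ge 0\) 为整数)可以把 \(\lfloor\frac{ax+b}{c}\rfloor^{k2}\) 拆开,变成下式(\(k2=0\) 时答案就是 \(\sum_{x=0}^{n}x^{k1}\),需单独处理):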
\[\sum_{j=0}^{\lfloor\frac{ax+b}{c}\rfloor-1}((j+1)^{k2}-j^{k2})
\]
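交换求和顺序,并注意到 \(j<\lfloor\frac{ax+b}{c}\rfloor\iff c(j+1)\le ax+b\iff x>\lfloor\frac{cj+c-b-1}{a}\rfloor\)(而 \(j\) 最大只到 \(M-1\)),答案就是:
\[\sum_{j=0}^{M-1} ((j+1)^{k2}-j^{k2})\sum_{x=0}^{n} x^{k1}[x>\lfloor\frac{cj+c-b-1}{a}\rfloor]
\]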
\[=\sum_{j=0}^{M-1} ((j+1)^{k2}-j^{k2})\sum_{i=0}^{n} i^{k1}-\sum_{j=0}^{M-1} ((j+1)^{k2}-j^{k2})\times \sum_{i=0}^{\lfloor\frac{cj+c-b-1}{a}\rfloor}i^{k1}
\]
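前面那部分裂项求和后就是 \(M^{k2}\sum_{i=0}^{n}i^{k1}\)。考虑后面那一部分。由于 \(j\le M-1\) 时 \(\lfloor\frac{cj+c-b-1}{a}\rfloor\le n-1\),内层和不会越过 \(n\)。而 \(\sum_{i=0}^{y}i^{k1}\) 是关于 \(y\) 的 \(k1+1\) 次多项式,设其 \(j\) 次项系数为 \(B_j\)。再用二项式定理把 \((x+1)^{k2}-x^{k2}\) 展开成 \(\sum_{i=0}^{k2-1}\binom{k2}{i}x^i\)(这里把外层求和变量 \(j\) 改记为 \(x\))。于是后面那部分就可以写成: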
\[\sum_{i=0}^{k2-1}\binom{k2}{i}\sum_{j=0}^{k1+1} B_j\sum_{x=0}^{M-1} x^i\lfloor\frac{cx+c-b-1}{a}\rfloor^j
\]
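这正是参数为 \((M-1,\ c,\ c-b-1,\ a)\) 的同一个问题,可以递归了。\((a,c)\) 的变化与辗转相除相同:先交换,再在 \(a\ge c\) 的情况里取模。因此递归深度是 \(O(\log)\) 级别的。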
Code
#include <bits/stdc++.h>
using namespace std;
// Loop-counter shorthand. It used to expand to `register int`, but the
// `register` storage class was removed in C++17 (clang rejects it by default)
// and was at most an optimisation hint, so plain `int` is used instead.
#define Int int
// Prime modulus for all arithmetic below.
#define mod 1000000007
// From here on every `int` is really 64-bit (hence `signed main`), so the
// product of two residues (each < 2^30) fits without overflow.
#define int long long
// Bound for the (k1, k2) tables and the interpolation matrix: gen(10) writes
// column 12 of `mat`, so this must stay greater than 12.
#define MAXN 15
// Reads one integer from stdin into t.
// Any non-digit characters before the number are skipped. The value is
// negative iff the character immediately before its first digit is '-'.
// If end of input is reached before any digit, t is left as 0.
template <typename T> inline void read (T &t){
	t = 0;
	int c = getchar ();          // int, not char, so EOF stays distinguishable
	bool negative = false;
	while (c != EOF && (c < '0' || c > '9')){
		// Only a '-' directly in front of the digits counts; other separators
		// ('\r', '\t', ...) must not flip the sign.
		negative = (c == '-');
		c = getchar ();
	}
	while (c >= '0' && c <= '9'){
		t = t * 10 + (c - '0');
		c = getchar ();
	}
	if (negative) t = -t;
}
// Reads several integers in order: read (a, b, c) == read (a), read (b), read (c).
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
// Prints integer x to stdout in decimal, preceded by '-' when negative.
// No separator or newline is emitted.
template <typename T> inline void write (T x){
	if (x < 0){
		putchar ('-');
		x = -x;
	}
	char digits[40];
	int len = 0;
	do {
		digits[len ++] = char (x % 10 + '0');
		x /= 10;
	} while (x > 0);
	while (len > 0) putchar (digits[-- len]);
}
// Raises a to b when b is larger.
template <typename T> void chkmax (T &a,T b){
	if (a < b) a = b;
}
// Lowers a to b when b is smaller.
template <typename T> void chkmin (T &a,T b){
	if (b < a) a = b;
}
// Modular helpers; operands are expected to be residues in [0, mod).
// mul: (a * b) mod `mod`, with the product formed in 64-bit.
int mul (int a,int b){
	return static_cast<long long> (a) * b % mod;
}
// dec: (a - b) mod `mod`, kept in [0, mod).
int dec (int a,int b){
	const int diff = a - b;
	return diff < 0 ? diff + mod : diff;
}
// add: (a + b) mod `mod`, kept in [0, mod).
int add (int a,int b){
	const int sum = a + b;
	return sum < mod ? sum : sum - mod;
}
int qkpow (int a,int b){
int res = 1;for (;b;b >>= 1,a = mul (a,a)) if (b & 1) res = mul (res,a);
return res;
}
// Modular inverse of x via Fermat's little theorem (mod = 1e9+7 is prime):
// x^(mod-2). x must be nonzero modulo mod; inv(0) yields 0.
int inv (int x){
	const int exponent = mod - 2;
	return qkpow (x,exponent);
}
// In-place modular update: a <- (a + b) mod `mod`.
void Add (int &a,int b){
	const int updated = add (a,b);
	a = updated;
}
// In-place modular update: a <- (a - b) mod `mod`.
void Sub (int &a,int b){
	const int updated = dec (a,b);
	a = updated;
}
// A MAXN x MAXN table of residues; getit(n, a, b, c) fills entry [k1][k2].
// Every entry starts at zero.
struct node{
	int t[MAXN][MAXN];
	node(){
		for (auto &row : t)
			for (auto &cell : row) cell = 0;
	}
	// Row access, so that obj[k1][k2] reads like a plain 2-D array.
	int * operator [](const int key){return t[key];}
};
// Binomial coefficients C[n][k] modulo mod (main fills rows 0..10).
int C[MAXN][MAXN];
// Scratch augmented matrix shared by Func::gen and Func::Gauss.
int mat[MAXN][MAXN];
struct Func{//处理每个F(i,x) \sum_{k=0}^{x} k^i 的i+1次函数
int a[MAXN];
int & operator [](const int key){return a[key];}
void Gauss (int K){
for (Int i = 0;i <= K;++ i){
int tmp = i;
for (Int j = i + 1;j <= K;++ j) if (mat[j][i]){tmp = j;break;}
if (tmp ^ i) swap (mat[tmp],mat[i]);
for (Int j = i + 1,iv = inv (mat[i][i]);j <= K;++ j){
int del = mul (mat[j][i],iv);
for (Int k = i;k <= K + 1;++ k) Sub (mat[j][k],mul (del,mat[i][k]));
}
}
for (Int i = K;~i;-- i){
for (Int j = i + 1;j <= K;++ j) Sub (mat[i][K + 1],mul (a[j],mat[i][j]));
a[i] = mul (mat[i][K + 1],inv (mat[i][i]));
}
}
void gen(int k){
for (Int i = 0,res = 0;i <= k + 1;++ i) Add (res,qkpow (i,k)),mat[i][k + 2] = res;
for (Int i = 0;i <= k + 1;++ i) for (Int j = 0,res = 1;j <= k + 1;++ j,res = mul (res,i)) mat[i][j] = res;
Gauss (k + 1);
}
int getit (int k,int x){
int res = 0;
for (Int i = k + 1;i >= 0;-- i) res = add (a[i],mul (res,x));
return res;
}
}f[MAXN];
// Brute-force reference for a single (k1, k2):
//   \sum_{x=0}^{n} x^{k1} * floor((a*x + b) / c)^{k2}   (mod `mod`).
// Runs in O(n log k), so it is only practical for small n (e.g. to
// cross-check getit). main in this file does not call it.
int relans (int n,int a,int b,int c,int k1,int k2){
	int total = 0;
	int x = 0;
	while (x <= n){
		const int fl = (a * x + b) / c % mod;
		Add (total,mul (qkpow (x,k1),qkpow (fl,k2)));
		++ x;
	}
	return total;
}
// Generalized Euclidean recursion. Returns a table `ans` with
//   ans[k1][k2] = \sum_{x=0}^{n} x^{k1} * floor((a*x + b) / c)^{k2}   (mod `mod`)
// for every k1 + k2 <= 10, where 0^0 counts as 1.
// Presumably expects n >= 0, a >= 0, b >= 0 and c >= 1 (c is a divisor) —
// TODO confirm against the input specification.
// Cases, in the order they are tested:
//   1. a == 0 or floor((a*n+b)/c) == 0: the floor is the constant
//      t = floor((a*n+b)/c), so ans = t^{k2} * P_{k1}(n).
//   2. a >= c: floor((a*x+b)/c) = q*x + floor((r*x+b)/c) with q = a/c and
//      r = a%c; expand binomially and recurse on (n, r, b, c).
//   3. b >= c: same idea with q = b/c and r = b%c; recurse on (n, a, r, c).
//   4. a, b < c: with M = floor((a*n+b)/c), swap the roles of x and the floor
//      value and recurse on (M-1, c, c-b-1, a).
node getit (int n,int a,int b,int c){
	node ans;
	// Case 1: the floor value never changes.
	if (a == 0 || a * n + b < c){
		const int t = (a * n + b) / c % mod;
		for (int p = 0;p <= 10;++ p){
			const int prefix = f[p].getit (p,n);
			int tPow = 1;  // t^e
			for (int e = 0;p + e <= 10;++ e){
				ans[p][e] = mul (tPow,prefix);
				tPow = mul (tPow,t);
			}
		}
		return ans;
	}
	// Case 2: peel floor(a/c) * x off the floor.
	if (a >= c){
		const int q = a / c;
		node sub = getit (n,a % c,b,c);
		for (int p = 0;p <= 10;++ p)
			for (int e = 0;p + e <= 10;++ e){
				int qPow = 1;  // q^i
				for (int i = 0;i <= e;++ i){
					Add (ans[p][e],mul (mul (qPow,C[e][i]),sub[p + i][e - i]));
					qPow = mul (qPow,q);
				}
			}
		return ans;
	}
	// Case 3: peel the constant floor(b/c) off the floor.
	if (b >= c){
		const int q = b / c;
		node sub = getit (n,a,b % c,c);
		for (int p = 0;p <= 10;++ p)
			for (int e = 0;p + e <= 10;++ e){
				int qPow = 1;  // q^i
				for (int i = 0;i <= e;++ i){
					Add (ans[p][e],mul (mul (qPow,C[e][i]),sub[p][e - i]));
					qPow = mul (qPow,q);
				}
			}
		return ans;
	}
	// Case 4: a < c and b < c (and M >= 1, otherwise case 1 applied).
	const int M = (a * n + b) / c;
	node sub = getit (M - 1,c,c - b - 1,a);
	for (int p = 0;p <= 10;++ p){
		const int prefix = f[p].getit (p,n);
		ans[p][0] = prefix;
		for (int e = 1;p + e <= 10;++ e){
			// M^e * P_p(n) - \sum_{i<e} C(e,i) \sum_j B_j * sub[i][j],
			// where B_j = f[p][j] are the coefficients of P_p.
			int value = mul (qkpow (M,e),prefix);
			for (int i = 0;i < e;++ i)
				for (int j = 0;j <= p + 1;++ j)
					Sub (value,mul (mul (C[e][i],f[p][j]),sub[i][j]));
			ans[p][e] = value;
		}
	}
	return ans;
}
signed main(){
for (Int i = 0;i <= 10;++ i) f[i].gen (i);
for (Int i = 0;i <= 10;++ i){
C[i][0] = 1;
for (Int j = 1;j <= i;++ j) C[i][j] = add (C[i - 1][j],C[i - 1][j - 1]);
}
int T;read (T);
while (T --> 0){
int n,a,b,c,k1,k2;read (n,a,b,c,k1,k2);
node ans = getit (n,a,b,c);write (ans[k1][k2]),putchar ('\n');
}
return 0;
}