luogu「EZEC-4.5」子序列

题意

  • 定义一个长度为\(n\)的序列\(P\)的价值为\(\displaystyle \sum_{i = 1}^n p_i \prod_{i = 1}^n p_i\)
  • 给定一个序列\(A\),求出子序列中最大和最小位置跨度不超过\(k\)的子序列价值之和

考虑将每个子序列的价值算在子序列的最初的位置上,那么我们只需要求出每个位置\(i\)之后\(k\)个位置任选,强制选\(i\)的贡献就行了
那么考虑将一个数字\(x\)加入选的集合:

\[(\sum_{i = 1}^n p_i + x) \prod_{i = 1}^n p_i \cdot x \\ \sum_{i = 1}^n p_i \prod_{i = 1}^n p_i \cdot x + \prod_{i = 1}^n p_i \cdot x^2 \]

注意到区间\([L,R]\)之内的\(\displaystyle \sum_p \prod_{i = 1}^n p_i = \prod_{i = L}^R (a_i + 1)\)
所以我们可以考虑序列中每个数的贡献:
\(M = \prod_{i = 1}^n (p_i + 1)\)
那么贡献为

\[\displaystyle \sum_{i = 1}^n p_i^2 \frac{M}{p_i + 1} \]

只可惜这样涉及到逆元,无法处理,我们考虑递推:
\(f\)表示当前的贡献,\(g\)表示当前的\(\displaystyle \prod_i (a_i +1)\),那么:

\[\begin{cases} f^{'} \gets f \cdot (a_i +1) + g \cdot a_i^2\\ g^{'} \gets g \cdot (a_i + 1) \end{cases} \]

显然可以矩阵优化,这里有个小trick,就是每次只会使用\(k\)个矩阵之积,于是将数组分为\(k\)个一组,维护前缀积和后缀积,这样即可\(o(1)\)查询
代码

#include<bits/stdc++.h>
#define For(i, a, b) for(int i = (a), en = (b); i <= en; ++i)
#define Rof(i, a, b) for(int i = (a), en = (b); i >= en; --i)
#define Tra(u, i) for(int i = hd[u]; ~i; i = e[i].net)
#define LL long long
#define DD double
#define LD long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define inf 0x3f3f3f3f
#define eps 1e-12
#define maxn 1000000
using namespace std;

int n, k, mod, as = 0, s[maxn + 5], sum = 0;
struct Mat{int a[2][2];} a[2 * maxn + 5], f[2 * maxn + 5], g[2 * maxn + 5], uni = (Mat){{{1, 0}, {0, 1}}};
Mat operator * (Mat x, Mat y){
	Mat asi = (Mat){0, 0, 0, 0};
	asi.a[0][0] = (1ll * x.a[0][0] * y.a[0][0] + 1ll * x.a[0][1] * y.a[1][0]) % mod;
	asi.a[0][1] = (1ll * x.a[0][0] * y.a[0][1] + 1ll * x.a[0][1] * y.a[1][1]) % mod;
	asi.a[1][0] = (1ll * x.a[1][0] * y.a[0][0] + 1ll * x.a[1][1] * y.a[1][0]) % mod;
	asi.a[1][1] = (1ll * x.a[1][0] * y.a[0][1] + 1ll * x.a[1][1] * y.a[1][1]) % mod;
	return asi;
}

Mat sol(int l, int r){
	if(l > r) return uni;
	if(l % k == 0) return f[r];
	return g[l] * f[r];
}

int main(){
	//freopen("in", "r", stdin);
	scanf("%d %d %d", &n, &k, &mod);
	For(i, 1, n){
		scanf("%d", &s[i]);
		sum = (sum + 1ll * s[i] * s[i] % mod) % mod;
		a[i] = (Mat){{{s[i] + 1, 0}, {1ll * s[i] * s[i] % mod, s[i] + 1}}};
	}
	if(!k){printf("%d\n", sum); return 0;}
	For(i, n + 1, n + k) a[i] = uni;
	For(i, 1, n + k){
		if(i % k != 0) f[i] = f[i - 1] * a[i];
		else f[i] = a[i];
	}
	Rof(i, n + k, 1){
		if(i % k != k - 1) g[i] = g[i + 1] * a[i];
		else g[i] = a[i];
	}
	For(i, 1, n){
		Mat tem = sol(i + 1, i + k);
		int f = tem.a[1][0], g = tem.a[1][1];
		as = (as + 1ll * s[i] * f % mod + 1ll * s[i] * s[i] % mod * g % mod) % mod;
	}
	printf("%d\n", as);
	return 0;
}
posted @ 2020-10-04 23:40  lprdsb  阅读(154)  评论(0编辑  收藏  举报