十进制矩阵乘法优化DP

十进制矩乘优化DP

P1397

[NOI2013] 矩阵游戏

题目描述

婷婷是个喜欢矩阵的小朋友,有一天她想用电脑生成一个巨大的 \(n\)\(m\) 列的矩阵(你不用担心她如何存储)。她生成的这个矩阵满足一个神奇的性质:若用 \(F[i][j]\) 来表示矩阵中第 \(i\) 行第 \(j\) 列的元素,则 \(F[i][j]\) 满足下面的递推式:

\[F[1][1]=1 \]

\[F[i,j]=a\times F[i][j-1]+b (j\neq 1) \]

\[F[i,1]=c\times F[i-1][m]+d (i\neq 1) \]

递推式中 \(a,b,c,d\) 都是给定的常数。

现在婷婷想知道 \(F[n][m]\) 的值是多少,请你帮助她。由于最终结果可能很大,你只需要输出 \(F[n][m]\) 除以 \(1,000,000,007\) 的余数。

输入格式

包含一行有六个整数 \(n,m,a,b,c,d\)。意义如题所述。

输出格式

包含一个整数,表示 \(F[n][m]\)除以 \(1,000,000,007\) 的余数。

样例 #1

样例输入 #1

3 4 1 3 2 6

样例输出 #1

85

提示

【样例1说明】

样例中的矩阵为:

1 4 7 10

26 29 32 35

76 79 82 85

数据范围

测试点编号 数据范围
1 \(1 \le n,m \le 10\)\(1 \le a,b,c,d \le 1000\)
2 \(1 \le n,m \le 100\)\(1 \le a,b,c,d \le 1000\)
3 \(1 \le n,m \le 10^3\)\(1 \le a,b,c,d \le 10^9\)
4 \(1 \le n,m \le 10^3\)\(1 \le a,b,c,d \le 10^9\)
5 \(1 \le n,m \le 10^9\)\(1 \le a = c \le 10^9\)\(1 \le b = d \le 10^9\)
6 \(1 \le n,m \le 10^9\)\(a = c = 1\)\(1 \le b,d \le 10^9\)
7 \(1 \le n,m,a,b,c,d \le 10^9\)
8 \(1 \le n,m,a,b,c,d \le 10^9\)
9 \(1 \le n,m,a,b,c,d \le 10^9\)
10 \(1 \le n,m,a,b,c,d \le 10^9\)
11 \(1 \le n,m \le 10^{1\,000}\)\(a = c = 1\)\(1 \le b,d \le 10^9\)
12 \(1 \le n,m \le 10^{1\,000}\)\(1 \le a = c \le 10^9\)\(1 \le b = d \le 10^9\)
13 \(1 \le n,m \le 10^{1\,000}\)\(1 \le a,b,c,d \le 10^9\)
14 \(1 \le n,m \le 10^{1\,000}\)\(1 \le a,b,c,d \le 10^9\)
15 \(1 \le n,m \le 10^{20\,000}\)\(1 \le a,b,c,d \le 10^9\)
16 \(1 \le n,m \le 10^{20\,000}\)\(1 \le a,b,c,d \le 10^9\)
17 \(1 \le n,m \le 10^{1\,000\,000}\)\(a = c = 1\)\(1 \le b,d \le 10^9\)
18 \(1 \le n,m \le 10^{1\,000\,000}\)\(1 \le a = c \le 10^9\)\(1 \le b = d \le 10^9\)
19 \(1 \le n,m \le 10^{1\,000\,000}\)\(1 \le a,b,c,d \le 10^9\)
20 \(1 \le n,m \le 10^{1\,000\,000}\)\(1 \le a,b,c,d \le 10^9\)

分析

对于此题,不难发现肯定是矩阵乘法
而我们的矩阵乘法加速递推一般而言是加速的一维数组的递推,对于这个二维数组,我们可以将其编号为一维数组,具体的原来的\(F[i,j]\)对应新的\(F[(i-1)m+j]\)

那么这道题的递推只与上一个状态有关,所以我们大可以设一个长度为2的数组\(f_{i}=[1,F[i]]\),那么考虑分情况转移

  1. \(i\bmod m\neq 1\)

此时就是\(F[i]=aF[i-1]+b\),可以构造辅助矩阵\(A=\begin{bmatrix}1 & b \\ 0 & a\\ \end{bmatrix}\),使得\(f[i-1]\times A=f[i]\)

类似的,在\(i\bmod m\equiv1\)时,此时可以构造矩阵\(B=\begin{bmatrix}1 & d \\ 0 & b\\ \end{bmatrix}\),使得\(f[i-1]\times B =f[i]\)
所以说,由\(F[1]->F[m+1]\)就有\(f[1]\times A^{m-1}B\)

所以,\(F[nm]=f[1]\times(A^{m-1}B)^{n-1}\times A^{m-1}\)

至此,便可以使用矩阵快速幂求解

另外,由于\(n,m\)太大,需要存成字符串

//P1397
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define int long long
//#define scanf scanf_s
using namespace std;
int f[3], a[3][3], b[3][3], a1, b1, c1, d1, n, m;
int c[3][3], d[3][3], e[3][3];
char p[1000500], q[1000500];
#define mod 1000000007
void mul(int a[3][3], int b[3][3]) {
	int c[3][3];
	memset(c, 0, sizeof c);
	for (int i = 1; i <= 2; i++) {
		for (int j = 1; j <= 2; j++) {
			for (int k = 1; k <= 2; k++) {
				c[i][j] += a[i][k] % mod * b[k][j] % mod;
				c[i][j] %= mod;
			}
		}
	}
	for (int i = 1; i <= 2; i++) {
		for (int j = 1; j <= 2; j++) {
			a[i][j] = c[i][j] % mod;
		}
	}
}
void mul2(int a[3], int b[3][3]) {
	int f[3];
	memset(f, 0, sizeof f);
	for (int i = 1; i <= 2; i++) {
		for (int j = 1; j <= 2; j++) {
			f[j] += a[i] % mod * b[i][j] % mod;
			f[j] %= mod;
		}
	}
	for (int i = 1; i <= 2; i++)a[i] = f[i] % mod;
}
void power(int a[3][3], int b) {
	int c[3][3];
	for (int i = 1; i <= 2; i++)for (int j = 1; j <= 2; j++)c[i][j] = a[i][j] % mod;
	b--;
	while (b) {
		if (b & 1)mul(c, a);
		mul(a, a);
		b >>= 1;
	}
	for (int i = 1; i <= 2; i++)for (int j = 1; j <= 2; j++)a[i][j] = c[i][j];
}
void update(char s, int a[3][3]) {
	int x = s - '0';
	for (int i = 1; i <= x; i++) {
		mul(c, a);
	}
	//	printf("%d\n",s-'0');
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", c[i][j]);
	//		puts("");
	//	}
	power(a, 10);
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", a[i][j]);
	//		puts("");
	//	}
}
signed main() {
	f[1] = f[2] = 1;
	scanf("%s %s", p + 1, q + 1);
	scanf("%lld%lld%lld%lld", &a1, &b1, &c1, &d1);
	a[1][1] = 1, a[1][2] = b1, a[2][2] = a1;
	b[1][1] = 1, b[1][2] = d1, b[2][2] = c1;
	for (int i = 1; i <= 2; i++)c[i][i] = d[i][i] = 1;
	int len1 = strlen(p + 1), len2 = strlen(q + 1);
	reverse(p + 1, p + len1 + 1);
	reverse(q + 1, q + len2 + 1);
	for (int i = 1; i <= len1; i++) {
		if (p[i] != '0') {
			p[i] -= 1;
			break;
		}
		else {
			p[i] = '9';
		}
	}
	if (p[len1] == '0')len1--;
	for (int i = 1; i <= len2; i++) {
		if (q[i] != '0') {
			q[i] -= 1;
			break;
		}
		else {
			q[i] = '9';
		}
	}
	if (q[len2] == '0')len2--;
	//	printf("%d %d\n", len1, len2);
	//	printf("%c %c\n", q[1], p[1]);
	for (int i = 1; i <= len2; i++) {
		update(q[i], a);
	}
	for (int i = 1; i <= 2; i++)
		for (int j = 1; j <= 2; j++)d[i][j] = c[i][j];
	mul(c, b);
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", c[i][j]);
	//		puts("");
	//	}
	for (int i = 1; i <= 2; i++)for (int j = 1; j <= 2; j++)a[i][j] = c[i][j], c[i][j] = 0;
	for (int i = 1; i <= 2; i++)c[i][i] = 1;
	for (int i = 1; i <= len1; i++) {
		update(p[i], a);
	}
	mul(c, d);
	mul2(f, c);
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", c[i][j]);
	//		puts("");
	//	}
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", d[i][j]);
	//		puts("");
	//	}
	printf("%lld\n", f[2]);
}
posted @ 2022-11-30 22:45  spdarkle  阅读(18)  评论(0编辑  收藏  举报