十进制矩阵乘法优化DP

十进制矩乘优化DP

P1397

[NOI2013] 矩阵游戏

题目描述

婷婷是个喜欢矩阵的小朋友,有一天她想用电脑生成一个巨大的 nm 列的矩阵(你不用担心她如何存储)。她生成的这个矩阵满足一个神奇的性质:若用 F[i][j] 来表示矩阵中第 i 行第 j 列的元素,则 F[i][j] 满足下面的递推式:

F[1][1]=1

F[i,j]=a×F[i][j1]+b(j1)

F[i,1]=c×F[i1][m]+d(i1)

递推式中 a,b,c,d 都是给定的常数。

现在婷婷想知道 F[n][m] 的值是多少,请你帮助她。由于最终结果可能很大,你只需要输出 F[n][m] 除以 1,000,000,007 的余数。

输入格式

包含一行有六个整数 n,m,a,b,c,d。意义如题所述。

输出格式

包含一个整数,表示 F[n][m]除以 1,000,000,007 的余数。

样例 #1

样例输入 #1

3 4 1 3 2 6

样例输出 #1

85

提示

【样例1说明】

样例中的矩阵为:

1 4 7 10

26 29 32 35

76 79 82 85

数据范围

测试点编号 数据范围
1 1n,m101a,b,c,d1000
2 1n,m1001a,b,c,d1000
3 1n,m1031a,b,c,d109
4 1n,m1031a,b,c,d109
5 1n,m1091a=c1091b=d109
6 1n,m109a=c=11b,d109
7 1n,m,a,b,c,d109
8 1n,m,a,b,c,d109
9 1n,m,a,b,c,d109
10 1n,m,a,b,c,d109
11 1n,m101000a=c=11b,d109
12 1n,m1010001a=c1091b=d109
13 1n,m1010001a,b,c,d109
14 1n,m1010001a,b,c,d109
15 1n,m10200001a,b,c,d109
16 1n,m10200001a,b,c,d109
17 1n,m101000000a=c=11b,d109
18 1n,m1010000001a=c1091b=d109
19 1n,m1010000001a,b,c,d109
20 1n,m1010000001a,b,c,d109

分析

对于此题,不难发现肯定是矩阵乘法
而我们的矩阵乘法加速递推一般而言是加速的一维数组的递推,对于这个二维数组,我们可以将其编号为一维数组,具体的原来的F[i,j]对应新的F[(i1)m+j]

那么这道题的递推只与上一个状态有关,所以我们大可以设一个长度为2的数组fi=[1,F[i]],那么考虑分情况转移

  1. imodm1

此时就是F[i]=aF[i1]+b,可以构造辅助矩阵A=[1b0a],使得f[i1]×A=f[i]

类似的,在imodm1时,此时可以构造矩阵B=[1d0b],使得f[i1]×B=f[i]
所以说,由F[1]>F[m+1]就有f[1]×Am1B

所以,F[nm]=f[1]×(Am1B)n1×Am1

至此,便可以使用矩阵快速幂求解

另外,由于n,m太大,需要存成字符串

//P1397
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define int long long
//#define scanf scanf_s
using namespace std;
int f[3], a[3][3], b[3][3], a1, b1, c1, d1, n, m;
int c[3][3], d[3][3], e[3][3];
char p[1000500], q[1000500];
#define mod 1000000007
void mul(int a[3][3], int b[3][3]) {
	int c[3][3];
	memset(c, 0, sizeof c);
	for (int i = 1; i <= 2; i++) {
		for (int j = 1; j <= 2; j++) {
			for (int k = 1; k <= 2; k++) {
				c[i][j] += a[i][k] % mod * b[k][j] % mod;
				c[i][j] %= mod;
			}
		}
	}
	for (int i = 1; i <= 2; i++) {
		for (int j = 1; j <= 2; j++) {
			a[i][j] = c[i][j] % mod;
		}
	}
}
void mul2(int a[3], int b[3][3]) {
	int f[3];
	memset(f, 0, sizeof f);
	for (int i = 1; i <= 2; i++) {
		for (int j = 1; j <= 2; j++) {
			f[j] += a[i] % mod * b[i][j] % mod;
			f[j] %= mod;
		}
	}
	for (int i = 1; i <= 2; i++)a[i] = f[i] % mod;
}
void power(int a[3][3], int b) {
	int c[3][3];
	for (int i = 1; i <= 2; i++)for (int j = 1; j <= 2; j++)c[i][j] = a[i][j] % mod;
	b--;
	while (b) {
		if (b & 1)mul(c, a);
		mul(a, a);
		b >>= 1;
	}
	for (int i = 1; i <= 2; i++)for (int j = 1; j <= 2; j++)a[i][j] = c[i][j];
}
void update(char s, int a[3][3]) {
	int x = s - '0';
	for (int i = 1; i <= x; i++) {
		mul(c, a);
	}
	//	printf("%d\n",s-'0');
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", c[i][j]);
	//		puts("");
	//	}
	power(a, 10);
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", a[i][j]);
	//		puts("");
	//	}
}
signed main() {
	f[1] = f[2] = 1;
	scanf("%s %s", p + 1, q + 1);
	scanf("%lld%lld%lld%lld", &a1, &b1, &c1, &d1);
	a[1][1] = 1, a[1][2] = b1, a[2][2] = a1;
	b[1][1] = 1, b[1][2] = d1, b[2][2] = c1;
	for (int i = 1; i <= 2; i++)c[i][i] = d[i][i] = 1;
	int len1 = strlen(p + 1), len2 = strlen(q + 1);
	reverse(p + 1, p + len1 + 1);
	reverse(q + 1, q + len2 + 1);
	for (int i = 1; i <= len1; i++) {
		if (p[i] != '0') {
			p[i] -= 1;
			break;
		}
		else {
			p[i] = '9';
		}
	}
	if (p[len1] == '0')len1--;
	for (int i = 1; i <= len2; i++) {
		if (q[i] != '0') {
			q[i] -= 1;
			break;
		}
		else {
			q[i] = '9';
		}
	}
	if (q[len2] == '0')len2--;
	//	printf("%d %d\n", len1, len2);
	//	printf("%c %c\n", q[1], p[1]);
	for (int i = 1; i <= len2; i++) {
		update(q[i], a);
	}
	for (int i = 1; i <= 2; i++)
		for (int j = 1; j <= 2; j++)d[i][j] = c[i][j];
	mul(c, b);
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", c[i][j]);
	//		puts("");
	//	}
	for (int i = 1; i <= 2; i++)for (int j = 1; j <= 2; j++)a[i][j] = c[i][j], c[i][j] = 0;
	for (int i = 1; i <= 2; i++)c[i][i] = 1;
	for (int i = 1; i <= len1; i++) {
		update(p[i], a);
	}
	mul(c, d);
	mul2(f, c);
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", c[i][j]);
	//		puts("");
	//	}
	//	for (int i = 1; i <= 2; i++) {
	//		for (int j = 1; j <= 2; j++)printf("%d ", d[i][j]);
	//		puts("");
	//	}
	printf("%lld\n", f[2]);
}
posted @   spdarkle  阅读(21)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 阿里最新开源QwQ-32B,效果媲美deepseek-r1满血版,部署成本又又又降低了!
· 单线程的Redis速度为什么快?
· SQL Server 2025 AI相关能力初探
· AI编程工具终极对决:字节Trae VS Cursor,谁才是开发者新宠?
· 展开说说关于C#中ORM框架的用法!
点击右上角即可分享
微信分享提示