description

\(f(x)\)表示\(x\)的各个数位的乘积。
\(N*N\)的矩阵每个矩阵上都有一块金子,一次变化后,金子从\((i,j)\)变到\((f(i),f(j))\)
问一次变化后矩阵上金子个数前\(k\)大的和为多少。
\(N<=10^{12}\)

solution

  • 首先发现题目最终要的性质,\(N\)虽然很大,但\(f(x)\)一共只有至多\(14672\)种。
    因为是由多个一位数乘起来,\(f(x)=2^a*3^b*5^c*7^d\)。打个表发现数量为\(14672\)种。
    相当于说最后有金子的位置很少。
    因此考虑预处理出所有可能的值并离散化(标个号),然后求多少个\(x\)\(f(x)\)得到\(i\),设这个方案数为\(c(i)\)
  • 因为跟数位有关,而且离散化后的状态数很小,考虑数位dp求方案数。
    *这里的数位dp是递推(没有写记忆化)而且是从个位往高位推的。
    \(dp[i][j][0/1]\)表示填了后\(i\)位,乘积为\(j\)(当然是离散化后的标号),是否大于\(n\)的后\(i\)。(转移见代码)
    最后这个\(0/1\)在统计答案的时候有用,\(1\)的情况位数就必须小于\(n\)了。
  • 现在已经求得\(c_i\),最后位置\((i,j)\)的金子数为\(c_i*c_j\)
    问题转化为:求\(c_i*c_j\)的前\(k\)大。
    二维的乘积最大,考虑定下一维,从大到小移动第二维(第一维固定,肯定第二维取还没用过的最大的),用维护(堆里面存二元组,按乘积最大排序)
    即先将\(c_i\)从小到大排序。对于每个\(i\),把\(c_i*c_m\)加入(\(c_m\)是最后一个)
    每次取出堆顶,把第二维减一,执行\(k\)次可得到结果。

code:

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=15001;
const ll mod=1e9+7;
map<ll,int>mp;
ll n,pm[N];
int m,k;

void init() {
	ll w=1;
	for(int i=0;i<=39;i++) {
		ll x=w;
		for(int j=0;j<=25;j++) {
			ll y=w;
			for(int l=0;l<=17;l++) {
				ll z=w;
				for(int r=0;r<=14;r++) {
					pm[mp[w]=++m]=w;
					w*=7;
					if(w>n) break;
				}
				w=z*5;
				if(w>n) break;
			}
			w=y*3;
			if(w>n)break;
		}
		w=x*2;
		if(w>n)break;
	}
}
ll dp[17][N][2];
int len,a[N];
ll c[N];		//乘积为$i$对应的方案数
void DP() {
	ll x=n;
	while(x) {a[++len]=x%10;x/=10;}
	for(int j=1;j<=9;j++) {dp[1][mp[j]][j>a[1]]=1;}
	for(int i=2;i<=len;i++) {
		for(int j=1;j<=m;j++) {
			ll w=pm[j];
			for(int k=1;k<=9;k++) {
				if(w%k)continue;
				int o=mp[w/k];
				if(k<a[i]) {dp[i][j][0]+=dp[i-1][o][0]+dp[i-1][o][1];}
				else if(k>a[i]) {dp[i][j][1]+=dp[i-1][o][0]+dp[i-1][o][1];}
				else {
					dp[i][j][0]+=dp[i-1][o][0];
					dp[i][j][1]+=dp[i-1][o][1];
				}
			}
		}
	}
	for(int i=1;i<=len;i++)for(int j=1;j<=m;j++) {
		c[j]+=dp[i][j][0]+((i==len)?0:dp[i][j][1]);
	}
}

struct pq {
	int x,y;
	bool operator <(const pq &u) const{return c[x]*c[y]<c[u.x]*c[u.y];}
};
priority_queue<pq> Q;

void solve() {
	k=min((ll)k,(ll)m*m);
	sort(c+1,c+1+m);
	for(int i=1;i<=m;i++)Q.push((pq){i,m});
	ll ans=0;
	while(!Q.empty()&&k--) {
		int x=Q.top().x,y=Q.top().y; Q.pop();
		ans=(ans+c[x]*c[y])%mod;
		if(y>1)Q.push((pq){x,y-1});
	}
	printf("%lld",ans);
}

int main() {
	scanf("%lld%d",&n,&k);
	init();
	DP();
	solve(); 
	return 0;
}