description

f(x)表示x的各个数位的乘积。
NN的矩阵每个矩阵上都有一块金子,一次变化后,金子从(i,j)变到(f(i),f(j))
问一次变化后矩阵上金子个数前k大的和为多少。
N<=1012

solution

  • 首先发现题目最终要的性质,N虽然很大,但f(x)一共只有至多14672种。
    因为是由多个一位数乘起来,f(x)=2a3b5c7d。打个表发现数量为14672种。
    相当于说最后有金子的位置很少。
    因此考虑预处理出所有可能的值并离散化(标个号),然后求多少个xf(x)得到i,设这个方案数为c(i)
  • 因为跟数位有关,而且离散化后的状态数很小,考虑数位dp求方案数。
    *这里的数位dp是递推(没有写记忆化)而且是从个位往高位推的。
    dp[i][j][0/1]表示填了后i位,乘积为j(当然是离散化后的标号),是否大于n的后i。(转移见代码)
    最后这个0/1在统计答案的时候有用,1的情况位数就必须小于n了。
  • 现在已经求得ci,最后位置(i,j)的金子数为cicj
    问题转化为:求cicj的前k大。
    二维的乘积最大,考虑定下一维,从大到小移动第二维(第一维固定,肯定第二维取还没用过的最大的),用维护(堆里面存二元组,按乘积最大排序)
    即先将ci从小到大排序。对于每个i,把cicm加入(cm是最后一个)
    每次取出堆顶,把第二维减一,执行k次可得到结果。

code:

点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=15001;
const ll mod=1e9+7;
map<ll,int>mp;
ll n,pm[N];
int m,k;

void init() {
	ll w=1;
	for(int i=0;i<=39;i++) {
		ll x=w;
		for(int j=0;j<=25;j++) {
			ll y=w;
			for(int l=0;l<=17;l++) {
				ll z=w;
				for(int r=0;r<=14;r++) {
					pm[mp[w]=++m]=w;
					w*=7;
					if(w>n) break;
				}
				w=z*5;
				if(w>n) break;
			}
			w=y*3;
			if(w>n)break;
		}
		w=x*2;
		if(w>n)break;
	}
}
ll dp[17][N][2];
int len,a[N];
ll c[N];		//乘积为$i$对应的方案数
void DP() {
	ll x=n;
	while(x) {a[++len]=x%10;x/=10;}
	for(int j=1;j<=9;j++) {dp[1][mp[j]][j>a[1]]=1;}
	for(int i=2;i<=len;i++) {
		for(int j=1;j<=m;j++) {
			ll w=pm[j];
			for(int k=1;k<=9;k++) {
				if(w%k)continue;
				int o=mp[w/k];
				if(k<a[i]) {dp[i][j][0]+=dp[i-1][o][0]+dp[i-1][o][1];}
				else if(k>a[i]) {dp[i][j][1]+=dp[i-1][o][0]+dp[i-1][o][1];}
				else {
					dp[i][j][0]+=dp[i-1][o][0];
					dp[i][j][1]+=dp[i-1][o][1];
				}
			}
		}
	}
	for(int i=1;i<=len;i++)for(int j=1;j<=m;j++) {
		c[j]+=dp[i][j][0]+((i==len)?0:dp[i][j][1]);
	}
}

struct pq {
	int x,y;
	bool operator <(const pq &u) const{return c[x]*c[y]<c[u.x]*c[u.y];}
};
priority_queue<pq> Q;

void solve() {
	k=min((ll)k,(ll)m*m);
	sort(c+1,c+1+m);
	for(int i=1;i<=m;i++)Q.push((pq){i,m});
	ll ans=0;
	while(!Q.empty()&&k--) {
		int x=Q.top().x,y=Q.top().y; Q.pop();
		ans=(ans+c[x]*c[y])%mod;
		if(y>1)Q.push((pq){x,y-1});
	}
	printf("%lld",ans);
}

int main() {
	scanf("%lld%d",&n,&k);
	init();
	DP();
	solve(); 
	return 0;
}