[luogu 5300][bzoj 5502] [GXOI/GZOI2019] 与或和

题面

思路还是挺容易想的, 只是由于我还是太\(naive\)了一点不会做只会打暴力吧......

题目要我们求所有子矩阵的\(and\)值之和与\(or\)值之和, 一看之下似乎不好入手, 我们慢慢来.

由于\(and\)\(or\)运算都是对于每个数的同一位二进制位进行运算的, 所以我们考虑将每个数拆成二进制数, 一位一位地统计答案, 这样的话, 原矩阵就变成了\(log{n}\)个01矩阵, 我们考虑先从\(and\)值入手.

关于\(and\)

由于是01矩阵, 对于最终答案有贡献的子矩阵必然所有元素都为1, 也就是说, 题目让我们求在一个\(n*m\)的矩阵中全为1的子矩阵有多少个. 这不是极大子矩阵的裸题吗, 考虑使用单调栈, 没有学过的左转百度吧(我也是昨天听\(lzf\)同志讲了我才知道的).

关于\(or\)

其实差不多, 对答案没有贡献的矩形就是那些全部为0的矩形, 用所有的情况减掉就是了, 还是单调栈(熟悉的配方)

然后大体上题目分析就到这里了, 每一位复杂度是\(O(n ^ 2)\), 总复杂度为\(O(n ^ 2log{n})\), 下面是代码的具体实现:

#include <iostream>
#include <cstring>
#include <cstdio>
#define N 1005
#define mod 1000000007
using namespace std;

int n, mp[N][N], s[N], w[N], h[N], top = 0; 
long long ans[2]; 

inline int read()
{
	int x = 0, w = 1;
	char c = getchar();
	while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
	while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
	return x * w;
}

long long get_num(int x) { return (1ll * x * (x + 1)) >> 1; }

long long get_ans(int x)
{
	long long res = 0;
	top = 0; 
	for(int i = 1; i <= n; i++) h[i] = 0;
	for(int i = 1; i <= n; i++)
	{
		for(int j = 1; j <= n; j++)
		{
			if((mp[i][j] & 1) == x) h[j]++;
			else h[j] = 0;
			if(h[j] > s[top]) { s[++top] = h[j]; w[top] = 1; }
			else
			{
				int k = 0;
				while(s[top] > h[j])
				{
					k += w[top];
					res = (res + (s[top] - max(s[top - 1], h[j])) * get_num(k) % mod) % mod;
					top--; 
				}
				s[++top] = h[j]; w[top] = k + 1; 
			}
		}
		int k = 0;
		while(top)
		{
			k += w[top];
			res = (res + (s[top] - s[top - 1]) * get_num(k) % mod) % mod; 
			top--; 
		}
	}
	return (res % mod + mod) % mod; 
    //单调栈实现极大子矩阵, 其实用悬线法也可以, 但是我既不会单调栈, 悬线法也没用过几次, 于是就成功Oho了
}

void work(int x)
{
	if(x == 32) return; 
	ans[0] = (ans[0] + get_ans(1) * (1 << x) % mod) % mod;
	ans[1] = (ans[1] + (get_num(n) * get_num(n) % mod - get_ans(0)) * (1 << x) % mod) % mod;
	//注意这个地方1左移x位是因为若不把这个数拆位的话其实这个数的这一位贡献是(1 << x)的.
    for(int i = 1; i <= n; i++)
		for(int j = 1; j <= n; j++)
			mp[i][j] >>= 1;
	work(x + 1); 
}

int main()
{
	n = read();
	for(int i = 1; i <= n; i++)
		for(int j = 1; j <= n; j++)
			mp[i][j] = read();
	work(0);
	printf("%lld %lld\n", (ans[0] + mod) % mod, (ans[1] + mod) % mod); 
	return 0;
}

posted @ 2019-04-29 21:15  ztlztl  阅读(99)  评论(0编辑  收藏  举报