[luogu 5300][bzoj 5502] [GXOI/GZOI2019] 与或和
题面
思路还是挺容易想的, 只是由于我还是太\(naive\)了一点不会做只会打暴力吧......
题目要我们求所有子矩阵的\(and\)值之和与\(or\)值之和, 一看之下似乎不好入手, 我们慢慢来.
由于\(and\)和\(or\)运算都是对于每个数的同一位二进制位进行运算的, 所以我们考虑将每个数拆成二进制数, 一位一位地统计答案, 这样的话, 原矩阵就变成了\(log{n}\)个01矩阵, 我们考虑先从\(and\)值入手.
关于\(and\)值
由于是01矩阵, 对于最终答案有贡献的子矩阵必然所有元素都为1, 也就是说, 题目让我们求在一个\(n*m\)的矩阵中全为1的子矩阵有多少个. 这不是极大子矩阵的裸题吗, 考虑使用单调栈, 没有学过的左转百度吧(我也是昨天听\(lzf\)同志讲了我才知道的).
关于\(or\)值
其实差不多, 对答案没有贡献的矩形就是那些全部为0的矩形, 用所有的情况减掉就是了, 还是单调栈(熟悉的配方)
然后大体上题目分析就到这里了, 每一位复杂度是\(O(n ^ 2)\), 总复杂度为\(O(n ^ 2log{n})\), 下面是代码的具体实现:
#include <iostream>
#include <cstring>
#include <cstdio>
#define N 1005
#define mod 1000000007
using namespace std;
int n, mp[N][N], s[N], w[N], h[N], top = 0;
long long ans[2];
inline int read()
{
int x = 0, w = 1;
char c = getchar();
while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
return x * w;
}
long long get_num(int x) { return (1ll * x * (x + 1)) >> 1; }
long long get_ans(int x)
{
long long res = 0;
top = 0;
for(int i = 1; i <= n; i++) h[i] = 0;
for(int i = 1; i <= n; i++)
{
for(int j = 1; j <= n; j++)
{
if((mp[i][j] & 1) == x) h[j]++;
else h[j] = 0;
if(h[j] > s[top]) { s[++top] = h[j]; w[top] = 1; }
else
{
int k = 0;
while(s[top] > h[j])
{
k += w[top];
res = (res + (s[top] - max(s[top - 1], h[j])) * get_num(k) % mod) % mod;
top--;
}
s[++top] = h[j]; w[top] = k + 1;
}
}
int k = 0;
while(top)
{
k += w[top];
res = (res + (s[top] - s[top - 1]) * get_num(k) % mod) % mod;
top--;
}
}
return (res % mod + mod) % mod;
//单调栈实现极大子矩阵, 其实用悬线法也可以, 但是我既不会单调栈, 悬线法也没用过几次, 于是就成功Oho了
}
void work(int x)
{
if(x == 32) return;
ans[0] = (ans[0] + get_ans(1) * (1 << x) % mod) % mod;
ans[1] = (ans[1] + (get_num(n) * get_num(n) % mod - get_ans(0)) * (1 << x) % mod) % mod;
//注意这个地方1左移x位是因为若不把这个数拆位的话其实这个数的这一位贡献是(1 << x)的.
for(int i = 1; i <= n; i++)
for(int j = 1; j <= n; j++)
mp[i][j] >>= 1;
work(x + 1);
}
int main()
{
n = read();
for(int i = 1; i <= n; i++)
for(int j = 1; j <= n; j++)
mp[i][j] = read();
work(0);
printf("%lld %lld\n", (ans[0] + mod) % mod, (ans[1] + mod) % mod);
return 0;
}