F.Four-tuples
原题链接
https://ac.nowcoder.com/acm/contest/123/F
思路
四个集合的容斥原理:|A∪B∪C∪D|=|A|+|B|+|C|+|D|-|A∩B|-|A∩C|-|A∩D|- |B∩C| - |B∩D| - |C∩D|+|A∩B∩C|+|A∩B∩D|+|A∩C∩D|+|B∩C∩D| -|A∩B∩C∩D|,只要想出每个集合的表示方法即可。
代码
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const LL mod = 1e9 + 7;
int main() {
int t;
cin >> t;
while (t -- ) {
LL l1, l2, l3, l4, r1, r2, r3, r4;
// 读入较大,不用scanf会超时
scanf("%lld%lld%lld%lld%lld%lld%lld%lld", &l1, &r1, &l2, &r2, &l3, &r3, &l4, &r4);
// 所有的选法
LL ans = (r1 - l1 + 1) * (r2 - l2 + 1) % mod * (r3 - l3 + 1) % mod * (r4 - l4 + 1) % mod;
// 1个集合,注意减去可能为负,要先加上mod在取模
LL l = max(l1, l2), r = min(r1, r2); // 1 == 2
if (r >= l) ans = (ans - ((r - l + 1) * (r3 - l3 + 1) % mod * (r4 - l4 + 1) % mod) % mod + mod) % mod;
l = max(l3, l2), r = min(r3, r2); // 2 == 3
if (r >= l) ans = (ans - ((r - l + 1) * (r1 - l1 + 1) % mod * (r4 - l4 + 1) % mod) % mod + mod) % mod;
l = max(l3, l4), r = min(r3, r4); // 3 == 4
if (r >= l) ans = (ans - ((r - l + 1) * (r1 - l1 + 1) % mod * (r2 - l2 + 1) % mod) % mod + mod) % mod;
l = max(l1, l4), r = min(r1, r4); // 1 == 4
if (r >= l) ans = (ans - ((r - l + 1) * (r3 - l3 + 1) % mod * (r2 - l2 + 1) % mod) % mod + mod) % mod;
// 2个集合
l = max(l1, max(l2, l3)), r = min(r1, min(r2, r3)); // 1 == 2 && 2 == 3
if (r >= l) ans = (ans + (r - l + 1) * (r4 - l4 + 1) % mod) % mod;
l = max(l1, max(l2, l4)), r = min(r1, min(r2, r4)); // 1 == 2 && 2 == 4
if (r >= l) ans = (ans + (r - l + 1) * (r3 - l3 + 1) % mod) % mod;
l = max(l1, max(l4, l3)), r = min(r1, min(r4, r3)); // 1 == 3 && 3 == 4
if (r >= l) ans = (ans + (r - l + 1) * (r2 - l2 + 1) % mod) % mod;
l = max(l4, max(l2, l3)), r = min(r4, min(r2, r3)); // 2 == 3 && 3 == 4
if (r >= l) ans = (ans + (r - l + 1) * (r1 - l1 + 1) % mod) % mod;
l = max(l1, l2), r = min(r1, r2); // 1 == 2 && 3 == 4
LL l_ = max(l3, l4), r_ = min(r3, r4);
if (r >= l && r_ >= l_) ans = (ans + (r - l + 1) * (r_ - l_ + 1) % mod) % mod;
l = max(l1, l4), r = min(r1, r4); // 1 == 4 && 2 == 3
l_ = max(l2, l3), r_ = min(r2, r3);
if (r >= l && r_ >= l_) ans = (ans + (r - l + 1) * (r_ - l_ + 1) % mod) % mod;
// 3个集合(4个集合相等) 1 == 2 && 2 == 3 && 3 == 4
l = max(max(l1, l2), max(l3, l4)), r = min(min(r1, r2), min(r3, r4));
if (r >= l) ans = ((ans - 3 * (r - l + 1) % mod) + mod) % mod; // 注意这里做了简化 * 3,因为最后几个集合的表示一样
cout << ans << endl;
}
return 0;
}