LY1087 [ 20230217 CQYC模拟赛VIII T2 ] 记忆

我们来看这样一道题:

请你维护一个序列 \(a\)

  • 1 k 将所有 \(a_i\) 变成 \(|a_i - k|\)
  • 2 l r 查询 \(\sum_{i = l} ^ {r} a_i\)

\(n, q \le 10 ^ 5\)

首先我们不难写出一个 \(naive\) 的代码。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#define int long long
using namespace std;

// Array capacity: slightly above the 1e5 bound on n given in the statement.
// Because of `#define int long long` above, N and the elements of s are long long.
const int N = 1e5 + 5;
// The sequence a; main() fills and reads indices 1..n.
array <int, N> s;

// Naive reference solution: a type-1 operation rewrites every element of s[1..n]
// as |s[i] - k|, and a type-2 query adds up s[l..r] one element at a time.
// `int` is long long here (see the #define above), hence `signed main`.
signed main() {
	int len, queries;
	cin >> len >> queries;
	for (int pos = 1; pos <= len; ++pos)
		cin >> s[pos];

	while (queries--) {
		int type;
		cin >> type;

		if (type != 1) {
			// Range-sum query over [lo, hi].
			int lo, hi;
			cin >> lo >> hi;
			int total = 0;
			for (int pos = lo; pos <= hi; ++pos)
				total += s[pos];
			cout << total << endl;
			continue;
		}

		// Global update: replace each element by its distance to k.
		int k;
		cin >> k;
		for (int pos = 1; pos <= len; ++pos) {
			const int diff = s[pos] - k;
			s[pos] = diff < 0 ? -diff : diff;
		}
	}
	return 0;
}

这份代码跑出了 \(27.62s\) 的惊人成绩。

考虑将读入换成快读,去掉 #define int long long

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#define ll long long
using namespace std;

// Buffered replacements for getchar/putchar, compiled only when ONLINE_JUDGE is
// defined; otherwise the standard <cstdio> versions are used.
#ifdef ONLINE_JUDGE

// Yields the next byte of stdin. When the read cursor p1 reaches the end p2, buf is
// refilled with one fread of up to 1 << 21 bytes; EOF is produced if that reads nothing.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
// Appends one byte to ubuf (no bounds check); nothing reaches stdout until main()
// flushes ubuf with fwrite.
#define putchar(x) *(u++) = (x)
// buf / p1 / p2: input buffer, read cursor, end of valid data.
// ubuf / u: output buffer and its write cursor.
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif

// Reads the next decimal integer from the input via getchar().
//
// Every non-digit character before the number is skipped; if any of them is '-',
// the result is negated. The first character after the digits is consumed.
// Returns 0 when the input ends before a digit is found (the previous version
// looped forever in that case because it never checked for EOF).
int read() {
	int sign = 1;
	// Kept as int so EOF stays distinguishable from every real character.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') sign = -1;
		ch = getchar();
	}

	int value = 0;
	while (ch >= '0' && ch <= '9') {
		// Subtract '0' before adding so values up to INT_MAX do not overflow midway.
		value = value * 10 + (ch - '0');
		ch = getchar();
	}
	return value * sign;
}

// Prints x in decimal through putchar(), with a leading '-' for negative values
// and no trailing separator.
//
// The digits are produced from the unsigned magnitude, so the most negative
// long long is printed correctly (negating it as a signed value would overflow),
// and the output is built iteratively instead of recursing once per digit.
void write(long long x) {
	unsigned long long magnitude = static_cast<unsigned long long>(x);
	if (x < 0) {
		putchar('-');
		// Unsigned negation is well defined and yields |x| even for LLONG_MIN.
		magnitude = 0 - magnitude;
	}

	// A 64-bit magnitude has at most 20 decimal digits.
	char digits[20];
	int count = 0;
	do {
		digits[count++] = static_cast<char>('0' + magnitude % 10);
		magnitude /= 10;
	} while (magnitude != 0);

	// Digits were collected least-significant first; emit them in reverse.
	while (count > 0) {
		putchar(digits[--count]);
	}
}

// Array capacity: slightly above the 1e5 bound on n given in the statement.
const int N = 1e5 + 5;
// The sequence a; main() fills and reads indices 1..n.
array <int, N> s;

int main() {
	int n = read(), q = read();
	for (int i = 1; i <= n; i++)
		s[i] = read();
	while (q--) {
		int op = read(), x, y;
		if (op == 1) {
			x = read();
			for (int i = 1; i <= n; i++)
				s[i] = abs(s[i] - x);
		}
		else {
			x = read(), y = read();
			ll ans = 0;
			for (int i = x; i <= y; i++)
				ans += s[i];
			write(ans), putchar(10);
		}
	}

#ifdef ONLINE_JUDGE
	fwrite(ubuf, 1, u - ubuf, stdout);
#endif
	return 0;
}

优化读入后,代码跑到了 \(26.15s\)

加上优化参数:

// GCC-specific: request aggressive optimization options for the functions defined
// after this line.
#pragma GCC optimize("Ofast", "inline", "-ffast-math")
// GCC-specific: allow the compiler to emit instructions from the listed x86 extensions.
#pragma GCC target("avx,sse2,sse3,sse4,mmx")

这样就能在 YCOJ 上跑到 \(90pts\) 的好成绩。

考虑循环展开,每次迭代处理 \(8\) 个元素。

// Manually unrolled form of the type-1 update (replaces the single loop over s[1..n]).
// `rg` is the `register` macro defined in the complete listing.
// ed is the largest multiple of 8 not exceeding n, so the first loop covers s[1..ed]
// in groups of eight elements per iteration.
rg int it = 1, ed = (n - (n % 8));
for (; it <= ed; it += 8) {
	s[it] = abs(s[it] - x);
	s[it + 1] = abs(s[it + 1] - x);
	s[it + 2] = abs(s[it + 2] - x);
	s[it + 3] = abs(s[it + 3] - x);
	s[it + 4] = abs(s[it + 4] - x);
	s[it + 5] = abs(s[it + 5] - x);
    s[it + 6] = abs(s[it + 6] - x);
	s[it + 7] = abs(s[it + 7] - x);
}
// Remainder: the at most 7 elements left in s[ed + 1..n].
for (; it <= n; it++)
	s[it] = abs(s[it] - x);

直接就在洛谷过了。用时:\(8.16s\)

但是在 YCOJ 上最大的点跑了 \(1.09s\)

只超了时限一点点。

考虑指令集优化:


// h is an int view of bk. With the +7 offset below (and assuming 32-bit int),
// h[8 * j - 7 .. 8 * j] are the eight 32-bit lanes of bk[j], so h[1..8] live in bk[1].
int* h;
// bk: the sequence packed 8 elements per 256-bit vector; k: the broadcast operand.
__m256i bk[N], k;

h = ((int *) & bk) + 7;

// Type-1 update, 8 elements per vector: broadcast -x into every lane,
// add it (a_i - x), then take the lane-wise absolute value (|a_i - x|).
k = _mm256_set1_epi32(-x);
for (rg int i = 1; i <= (n + 7) / 8; i++)
	bk[i] = _mm256_add_epi32(bk[i], k);
for (rg int i = 1; i <= (n + 7) / 8; i++)
	bk[i] = _mm256_abs_epi32(bk[i]);

洛谷上只跑了 126ms,直接狂暴拿到最优解,赢!

完整代码:

// GCC-specific: optimization options applied to the functions defined after this line.
#pragma GCC optimize("Ofast,no-stack-protector,unroll-loops,fast-math")
// GCC-specific: enable the listed x86 extensions. avx2 is the one the
// _mm256_add_epi32 / _mm256_abs_epi32 intrinsics used in main() require.
#pragma GCC target("sse,sse2,sse3,ssse3,sse4.1,sse4.2,avx,avx2,popcnt,tune=native")

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>

#include <immintrin.h>
#include <emmintrin.h>

#define ll long long
#define il inline
#define rg register
using namespace std;

// Buffered replacements for getchar/putchar, compiled only when ONLINE_JUDGE is
// defined; otherwise the standard <cstdio> versions are used.
#ifdef ONLINE_JUDGE

// Yields the next byte of stdin. When the read cursor p1 reaches the end p2, buf is
// refilled with one fread of up to 1 << 21 bytes; EOF is produced if that reads nothing.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
// Appends one byte to ubuf (no bounds check); nothing reaches stdout until main()
// flushes ubuf with fwrite.
#define putchar(x) *(u++) = (x)
// buf / p1 / p2: input buffer, read cursor, end of valid data.
// ubuf / u: output buffer and its write cursor.
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif

// Reads the next decimal integer from the input via getchar().
//
// Every non-digit character before the number is skipped; if any of them is '-',
// the result is negated. The first character after the digits is consumed.
// Returns 0 when the input ends before a digit is found (the previous version
// looped forever in that case because it never checked for EOF).
// The `register` hints were dropped: the keyword was removed in C++17.
inline int read() {
	int sign = 1;
	// Kept as int so EOF stays distinguishable from every real character.
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') sign = -1;
		ch = getchar();
	}

	int value = 0;
	while (ch >= '0' && ch <= '9') {
		// Subtract '0' before adding so values up to INT_MAX do not overflow midway.
		value = value * 10 + (ch - '0');
		ch = getchar();
	}
	return value * sign;
}

// Prints x in decimal through putchar(), with a leading '-' for negative values
// and no trailing separator.
//
// The digits are produced from the unsigned magnitude, so the most negative
// long long is printed correctly (negating it as a signed value would overflow),
// and the output is built iteratively instead of recursing once per digit.
inline void write(long long x) {
	unsigned long long magnitude = static_cast<unsigned long long>(x);
	if (x < 0) {
		putchar('-');
		// Unsigned negation is well defined and yields |x| even for LLONG_MIN.
		magnitude = 0 - magnitude;
	}

	// A 64-bit magnitude has at most 20 decimal digits.
	char digits[20];
	int count = 0;
	do {
		digits[count++] = static_cast<char>('0' + magnitude % 10);
		magnitude /= 10;
	} while (magnitude != 0);

	// Digits were collected least-significant first; emit them in reverse.
	while (count > 0) {
		putchar(digits[--count]);
	}
}

// Element capacity: slightly above the 1e5 bound on n given in the statement.
const int N = 1e5 + 5;

// int view into bk, set up in main() as (int *)bk + 7 so that h[1..8] are the
// eight 32-bit lanes of bk[1] (assuming 32-bit int).
int* h;
// bk: the sequence packed 8 elements per 256-bit vector. main() touches
// bk[1..(n + 7) / 8], so N / 8 + 2 vectors are enough for any n < N; the previous
// bk[N] reserved eight times that. The unused `array<int, N> s` was removed.
// k: the broadcast operand of a type-1 update.
__m256i bk[N / 8 + 2], k;

int main() {
	rg int n = read(), q = read();

	h = ((int *) & bk) + 7;

	for (rg int i = 1; i <= n; i++)
		h[i] = read();


	while (q--) {
		rg int op = read(), x, y;
		if (op == 1) {
			x = read();
			k = _mm256_set1_epi32(-x);
			for (rg int i = 1; i <= (n + 7) / 8; i++)
				bk[i] = _mm256_add_epi32(bk[i], k);
			for (rg int i = 1; i <= (n + 7) / 8; i++)
				bk[i] = _mm256_abs_epi32(bk[i]);
		}
		else {
			x = read(), y = read();
			rg ll ans = 0;
			rg int it = x, ed = y - (y - x + 1) % 8;
			for (; it <= ed; it += 8) {
				ans += h[it];
				ans += h[it + 1];
				ans += h[it + 2];
				ans += h[it + 3];
				ans += h[it + 4];
				ans += h[it + 5];
				ans += h[it + 6];
				ans += h[it + 7];
			}
			for (; it <= y; it++)
				ans += h[it];
			write(ans), putchar(10);
		}
	}

#ifdef ONLINE_JUDGE
	fwrite(ubuf, 1, u - ubuf, stdout);
#endif
	return 0;
}

posted @ 2024-01-16 15:41  cxqghzj  阅读(7)  评论(0)  编辑  收藏  举报