【luogu CF1109E】Sasha and a Very Easy Test(线段树)

Sasha and a Very Easy Test

题目链接:luogu CF1109E

题目大意

维护一个长度为 n 的序列,有区间乘,单点除(保证能整除),区间求和答案对 p 取模。
p 不一定是质数。

思路

麻了,考场上被卡了常数。重构了一遍之后才发现时限改了,考场的代码其实也能过。
难泵。

首先如果 \(p\) 是质数,那直接上逆元乱搞都行。(线段树)
但 \(p\) 不一定是质数,这时与 \(p\) 不互质的数没有逆元。出问题的只是 \(p\) 的质因子(即出现它们的倍数时),所以我们考虑把 \(p\) 的质因子直接暴力拆出来。

把 \(p\) 的质因子拆出来之后,我们维护的每个值都可以用一个余数和一组次数表示:
一部分是与 \(p\) 互质的部分(对 \(p\) 取模,可以求逆元),另一部分是 \(p\) 的每个质因子各自出现的次数。

然后维护就好了。
需要注意的是,这种分解形式只需要在最下层(叶子)维护:除法是单点操作,而查询只是区间求和。
所以非叶子节点直接把实际值相加即可(乘法懒标记也以同样的分解形式下传),分解后的值只需维护最下层的 \(n\) 个位置。

代码

考场代码(比较慢:被考场的 3s 时限卡了;不过在 luogu 上跑得很快,而且考后时限改成 8s,我跑了 4s)

#include<map>
#include<cstdio>
#include<vector>
  
using namespace std;
  
// Array capacity for positions 1..n.
const int N = 5e5 + 100;
// n: sequence length, p: modulus, q: number of operations, a[1..n]: initial values.
// zs[0..tot-1]: distinct prime factors of p, filled by trial division in main.
int n, p, q, a[N], zs[35], tot;
// Prime factor of p -> its index in zs (written in main; not read anywhere else in this listing).
map <int, int> pla;
// mic[i][k] = zs[i]^k mod p, cached for k = 0..R[i] and extended lazily by msm().
vector <int> mic[35]; int R[35];
  
// Modular helpers; add/dec expect both operands already in [0, p).
int add(int x, int y) { const int s = x + y; return s < p ? s : s - p; }
int dec(int x, int y) { return x >= y ? x - y : x - y + p; }
int mul(int x, int y) { return static_cast<long long>(x) * y % p; }
  
// Scratch state for read(). c is an int (not char) so it can hold EOF distinctly.
int re, zf, c;
// Reads the next (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; each '-' among them flips the sign.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
int read() {
    re = 0; zf = 1; c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') zf = -zf;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        re = (re << 3) + (re << 1) + c - '0';  // re = re * 10 + digit
        c = getchar();
    }
    return re * zf;
}
  
// A value split relative to the modulus p:
//   x    - the factor coprime to p (getVAL stores it unreduced; products via mul() reduce it mod p)
//   a[i] - exponent of the prime zs[i]; only indices 0..tot-1 are meaningful
struct Val {
    int x;
    int a[35];
};
  
// Product of two factored values: coprime parts multiply mod p, prime exponents add.
Val operator *(Val x, Val y) {
    x.x = mul(x.x, y.x);
    for (int i = 0; i != tot; ++i) x.a[i] += y.a[i];
    return x;
}
  
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and sets x, y so that a * x + b * y == g.
// With the canonical base case (x = 1, y = 0) the coefficients stay within |x| <= b/g, |y| <= a/g.
int exgcd(int a, int b, int &x, int &y) {
    if (!b) {
        x = 1; y = 0; return a;  // a * 1 + 0 * 0 == a
    }
    int re = exgcd(b, a % b, y, x);  // b * y + (a % b) * x == g
    y -= a / b * x;                  // rewrite in terms of a and b
    return re;
}
  
// Modular inverse of `now` modulo p via exgcd, normalized into [0, p).
// Only a true inverse when gcd(now, p) == 1.
int get_inv(int now, int p) {
    int x, y;
    exgcd(now, p, x, y);
    int r = x % p;
    if (r < 0) r += p;
    return r;
}
  
// Quotient of two factored values: multiply x's coprime part by the inverse of y's
// and subtract exponents (relies on the division being exact, as the problem guarantees).
Val operator /(Val x, Val y) {
    const int invY = get_inv(y.x, p);
    x.x = mul(x.x, invY);
    for (int i = 0; i != tot; ++i) x.a[i] -= y.a[i];
    return x;
}
  
// Factors x into a Val: strips every prime zs[i] of p from x, recording its exponent.
// The remaining coprime part is stored as-is (not reduced mod p).
Val getVAL(int x) {
    Val res;
    res.x = x;
    for (int i = 0; i < tot; ++i) {
        int cnt = 0;
        for (; res.x % zs[i] == 0; res.x /= zs[i]) ++cnt;
        res.a[i] = cnt;
    }
    return res;
}
  
// Returns zs[id]^k mod p from the memo table mic[id], extending it on demand.
// Powers are memoized instead of using fast exponentiation: per the author's estimate the
// exponents reach at most about 30 * (m + 1), i.e. on the order of 1e7 cached entries.
int msm(int id, int k) {
    auto &pw = mic[id];
    while (R[id] < k) {
        pw.push_back(mul(pw[R[id]], zs[id]));
        ++R[id];
    }
    return pw[k];
}
  
// Rebuilds the real value of a factored Val: x.x times zs[i]^a[i] for every prime of p, via mul().
int getval(Val x) {
    int acc = x.x;
    for (int i = 0; i != tot; ++i) {
        const int pw = msm(i, x.a[i]);
        acc = mul(acc, pw);
    }
    return acc;
}
  
// True iff x is the multiplicative identity (coprime part 1, every exponent 0),
// i.e. a lazy tag with nothing to push down.
bool check_empty(Val x) {
    if (x.x != 1) return false;
    int i = 0;
    while (i < tot && x.a[i] == 0) ++i;
    return i == tot;
}
  
// Segment tree over positions 1..n (first solution).
//   sum[]  - range sum of the real values mod p, for every node
//   lzy[]  - pending range-multiplication tag, kept in factored Val form
//   val[]  - factored value of each position, indexed by position (not by node id)
//   leaf[] - for a leaf node, the position it stores; 0 for internal nodes (T is global, so zero-initialized)
// NOTE(review): sizeof(Val) is 36 ints = 144 bytes, so lzy[N << 2] + val[N] alone take about
// (2,000,400 + 500,100) * 144 bytes ~= 360 MB; check this against the judge's memory limit.
struct XD_Tree {
    Val lzy[N << 2], val[N];
    int sum[N << 2], leaf[N << 2];
      
    // Recompute a node's sum from its two children.
    void up(int now) {
        sum[now] = add(sum[now << 1], sum[now << 1 | 1]);
    }
      
    // Apply multiplier a to node `now`: scale its sum by a's real value, compose a into the
    // node's tag, and if `now` is a leaf also fold a into that position's factored value.
    void downm(int now, Val a) {
        sum[now] = mul(sum[now], getval(a));
        lzy[now] = lzy[now] * a;
        if (leaf[now]) val[leaf[now]] = val[leaf[now]] * a;
    }
      
    // Push the pending tag of `now` to both children, then reset it to the identity.
    void down(int now) {
        if (check_empty(lzy[now])) return ;
        downm(now << 1, lzy[now]); downm(now << 1 | 1, lzy[now]);
        lzy[now].x = 1; for (int i = 0; i < tot; i++) lzy[now].a[i] = 0;
    }
      
    // Build over [l, r]: identity tags everywhere; each leaf records its position,
    // stores the factored a[l], and sets its sum to a[l] % p.
    void build(int now, int l, int r) {
        lzy[now].x = 1; for (int i = 0; i < tot; i++) lzy[now].a[i] = 0;
        if (l == r) {
            leaf[now] = l;
            val[l] = getVAL(a[l]); sum[now] = a[l] % p; return ;
        }
        int mid = (l + r) >> 1;
        build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
        up(now);
    }
      
    // Sum of positions [L, R] mod p; `now` covers [l, r].
    int query(int now, int l, int r, int L, int R) {
        if (L <= l && r <= R) return sum[now];
        down(now); int mid = (l + r) >> 1, re = 0;
        if (L <= mid) re = add(re, query(now << 1, l, mid, L, R));
        if (mid < R) re = add(re, query(now << 1 | 1, mid + 1, r, L, R));
        return re;
    }
      
    // Multiply every position in [L, R] by the factored value a.
    void times(int now, int l, int r, int L, int R, Val a) {
        if (L <= l && r <= R) {
            downm(now, a); return ;
        }
        down(now); int mid = (l + r) >> 1;
        if (L <= mid) times(now << 1, l, mid, L, R, a);
        if (mid < R) times(now << 1 | 1, mid + 1, r, L, R, a);
        up(now);
    }
      
    // Divide position pl by the factored value x. Tags on the root-to-leaf path are pushed
    // down first, so val[pl] is up to date before it is divided and its sum recomputed.
    void inv(int now, int l, int r, int pl, Val x) {
        if (l == r) {
            val[l] = val[l] / x; sum[now] = getval(val[l]);
            return ;
        }
        down(now); int mid = (l + r) >> 1;
        if (pl <= mid) inv(now << 1, l, mid, pl, x);
            else inv(now << 1 | 1, mid + 1, r, pl, x);
        up(now);
    }
}T;
  
// Entry point of the first solution.
// Reads n, p, the n array values, then q operations:
//   1 l r x : multiply a[l..r] by x
//   2 P x   : divide a[P] by x
//   3 l r   : print (a[l] + ... + a[r]) mod p
int main() {
    n = read(); p = read();
      
    // Factor p by trial division into its distinct primes zs[0..tot-1] (pla: prime -> index).
    int tmp = p;
    for (int i = 2; i * i <= tmp; i++)
        if (tmp % i == 0) {
            zs[tot++] = i; pla[i] = tot - 1;
            while (tmp % i == 0) tmp /= i;
        }
    if (tmp > 1) zs[tot++] = tmp, pla[tmp] = tot - 1;
    // Seed each power cache with zs[i]^0 = 1 (R[i] is a global, so it starts at 0).
    for (int i = 0; i < tot; i++) mic[i].push_back(1);
      
    for (int i = 1; i <= n; i++) a[i] = read();
    T.build(1, 1, n);
    q = read();
    while (q--) {
        int op = read();
        if (op == 1) {
            int l = read(), r = read(), x = read();
            T.times(1, 1, n, l, r, getVAL(x));
        }
        if (op == 2) {
            int P = read(), x = read();
            T.inv(1, 1, n, P, getVAL(x));
        }
        if (op == 3) {
            int l = read(), r = read();
            printf("%d\n", T.query(1, 1, n, l, r));
        }
//      Debug output (disabled): dump the factored value stored at position 5.
//      printf("%d\n", T.val[5].x);
//      for (int i = 0; i < tot; i++) printf("%d ", T.val[5].a[i]);
//      printf("\n");
    }
      
    return 0;
}

改进代码(参考别人的写法改进了一下,考后测试只需要 1s 左右)

#include<map>
#include<cstdio>
#include<vector>

using namespace std;

// Array capacity for positions 1..n.
const int N = 5e5 + 100;
// n: sequence length, p: modulus, q: number of operations, a[1..n]: initial values.
// phi = Euler's phi(p); prime[1..tot]: distinct prime factors of p (1-based), filled by Init().
// NOTE(review): prime[10] fits at most 9 primes; enough if p <= 1e9 as the problem presumably bounds it.
int n, p, q, a[N], phi, prime[10], tot;
// tmp[1..tot]: exponent of each prime[i] extracted by the most recent work() call.
int tmp[10];

// Modular helpers; add/dec expect both operands already in [0, p).
int add(int x, int y) { const int s = x + y; return s < p ? s : s - p; }
int dec(int x, int y) { return x >= y ? x - y : x - y + p; }
int mul(int x, int y) { return static_cast<long long>(x) * y % p; }

// Scratch state for read(). c is an int (not char) so it can hold EOF distinctly.
int re, zf, c;
// Reads the next (optionally negative) decimal integer from stdin.
// Non-digit characters before the number are skipped; each '-' among them flips the sign.
// Returns 0 if EOF is reached before any digit, instead of looping forever.
int read() {
	re = 0; zf = 1; c = getchar();
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-') zf = -zf;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		re = (re << 3) + (re << 1) + c - '0';  // re = re * 10 + digit
		c = getchar();
	}
	return re * zf;
}

// Euler's totient of n, computed by trial division over n's distinct prime factors.
int get_phi(int n) {
	int result = n;
	for (int d = 2; d * d <= n; ++d) {
		if (n % d != 0) continue;
		result = result / d * (d - 1);
		do n /= d; while (n % d == 0);
	}
	if (n > 1) result = result / n * (n - 1);  // leftover prime factor > sqrt
	return result;
}

// Precomputes phi = phi(p) and the distinct prime factors of p into prime[1..tot].
void Init() {
	phi = get_phi(p);
	int rest = p;
	for (int d = 2; d * d <= rest; ++d) {
		if (rest % d) continue;
		prime[++tot] = d;
		while (rest % d == 0) rest /= d;
	}
	if (rest > 1) prime[++tot] = rest;
}

// Binary exponentiation: x^y mod p (y is expected to be non-negative).
int ksm(int x, int y) {
	int acc = 1;
	for (; y != 0; y >>= 1) {
		if (y & 1) acc = mul(acc, x);
		x = mul(x, x);
	}
	return acc;
}

// Strips every prime[i] of p from n, storing the exponents in the global tmp[1..tot],
// and returns the remaining part (coprime to p) reduced mod p.
int work(int n) {
	for (int i = 1; i <= tot; ++i) {
		int cnt = 0;
		for (; n % prime[i] == 0; n /= prime[i]) ++cnt;
		tmp[i] = cnt;
	}
	return n % p;
}

// Segment tree over positions 1..n (improved solution).
//   sum[now]  - range sum of the real values mod p
//   lzy[now]  - pending range multiplier as a real value mod p (used to scale sums)
//   res[now]  - internal node: pending product of the multipliers' coprime parts mod p;
//               leaf: that position's current coprime part mod p
//   s[now][i] - internal node: pending exponent of prime[i]; leaf: that position's exponent of prime[i]
struct XD_tree {
	int sum[N << 2], lzy[N << 2], res[N << 2], s[N << 2][10];
	
	// Recompute a node's sum from its two children.
	void up(int now) {
		sum[now] = add(sum[now << 1], sum[now << 1 | 1]);
	}
	
	// Apply a multiplier to node `now`: real value `val`, coprime part `re`, prime exponents tmp[1..tot].
	// (The parameter tmp shadows the global tmp array.)
	void downm(int now, int val, int re, int *tmp) {
		lzy[now] = mul(lzy[now], val); sum[now] = mul(sum[now], val);
		res[now] = mul(res[now], re);
		for (int i = 1; i <= tot; i++) s[now][i] += tmp[i];
	}
	
	// Push the pending tag of `now` to both children and reset it. Checking lzy/res suffices:
	// a multiplier sharing a prime with p can never be congruent to 1 mod p, so a tag with
	// nonzero exponents always has lzy[now] != 1.
	void down(int now) {
		if (lzy[now] != 1 || res[now] != 1) {
			downm(now << 1, lzy[now], res[now], s[now]);
			downm(now << 1 | 1, lzy[now], res[now], s[now]);
			lzy[now] = res[now] = 1;
			for (int i = 1; i <= tot; i++) s[now][i] = 0;
		}
	}
	
	// Build over [l, r]: identity tags; each leaf stores a[l] % p as its sum and the
	// factored form of a[l] (coprime part in res, exponents in s). Requires Init() first.
	void build(int now, int l, int r) {
		lzy[now] = res[now] = 1;
		if (l == r) {
			sum[now] = a[l] % p;
			res[now] = work(a[l]);
			for (int i = 1; i <= tot; i++) s[now][i] = tmp[i];
			return ;
		}
		int mid = (l + r) >> 1;
		build(now << 1, l, mid); build(now << 1 | 1, mid + 1, r);
		up(now);
	}
	
	// Multiply positions [L, R] by a value whose real value is `val` and coprime part is `re`;
	// its prime exponents are read from the global tmp[], so work(x) must be called just before.
	void times(int now, int l, int r, int L, int R, int val, int re) {
		if (L <= l && r <= R) {
			downm(now, val, re, tmp);
			return ;
		}
		down(now); int mid = (l + r) >> 1;
		if (L <= mid) times(now << 1, l, mid, L, R, val, re);
		if (mid < R) times(now << 1 | 1, mid + 1, r, L, R, val, re);
		up(now);
	}
	
	// Divide position pl by a value with coprime part `val` (exponents in the global tmp[]).
	// val^(phi - 1) is val's inverse by Euler's theorem, since work() strips all of p's primes,
	// leaving val coprime to p. The leaf's real value is then rebuilt from its factored form.
	void inv(int now, int l, int r, int pl, int val) {
		if (l == r) {
			res[now] = mul(res[now], ksm(val, phi - 1));
			sum[now] = res[now];
			for (int i = 1; i <= tot; i++) s[now][i] -= tmp[i], sum[now] = mul(sum[now], ksm(prime[i], s[now][i]));
			return ;
		}
		down(now); int mid = (l + r) >> 1;
		if (pl <= mid) inv(now << 1, l, mid, pl, val);
			else inv(now << 1 | 1, mid + 1, r, pl, val);
		up(now);
	}
	
	// Sum of positions [L, R] mod p; `now` covers [l, r].
	int query(int now, int l, int r, int L, int R) {
		if (L <= l && r <= R) return sum[now];
		down(now); int mid = (l + r) >> 1, re = 0;
		if (L <= mid) re = add(re, query(now << 1, l, mid, L, R));
		if (mid < R) re = add(re, query(now << 1 | 1, mid + 1, r, L, R));
		return re;
	}
}T;

// Entry point of the improved solution. Reads n, p and the array, factors p, builds the tree,
// then processes q operations:
//   1 l r x : multiply a[l..r] by x
//   2 pl x  : divide a[pl] by x
//   3 l r   : print the range sum mod p
int main() {
	n = read(); p = read();
	for (int i = 1; i <= n; i++) a[i] = read();
	// Init must come first: build() calls work(), which needs prime[] and tot.
	Init(); T.build(1, 1, n);
	q = read();
	while (q--) {
		int op = read();
		if (op == 1) {
			int l = read(), r = read(), x = read();
			// work(x) fills the global tmp[] with x's exponents; times() reads them from there.
			T.times(1, 1, n, l, r, x, work(x));
		}
		if (op == 2) {
			int pl = read(), x = read();
			// Likewise, inv() reads the exponents that work(x) left in tmp[].
			T.inv(1, 1, n, pl, work(x));
		}
		if (op == 3) {
			int l = read(), r = read();
			printf("%d\n", T.query(1, 1, n, l, r));
		}
	}
	
	return 0;
}
posted @ 2022-09-27 10:08  あおいSakura  阅读(17)  评论(0)  编辑  收藏  举报