Loading

[题解]CF1774D Same Count One

思路

首先记所有 \(1\) 的数量为 \(num\),那么显然有当 \(n \bmod num \neq 0\) 时无解。那么考虑有解的时候该怎么办。

显然对于每一个 \(a_i\) 序列中,最终 \(1\) 的数量为 \(\frac{num}{n}\),记作 \(t\);并记 \(cnt_i\) 表示 \(a_i\) 序列中 \(1\) 的数量。

我们希望最终所有的 \(cnt_i\) 都等于 \(t\),并且希望操作步数最小,我们考虑一个显然的贪心:将 \(cnt_i > t\) 的序列中的 \(1\)\(cnt_j < t\) 缺失的 \(1\)

这样我们每一次的操作都会使 \(\sum_{i = 1}^{n}|cnt_i - t|\) 减少 \(2\),显然是最优的方案。

注意:如果你在交换的时候,一定需要更新 \(a_{i,k}\)\(a_{j,k}\),否则有一个很简单的 Hack。因为你不更新,你的程序会认为 \(a_{3,1}\) 在第一次操作后还是 \(0\) 可以交换。

Code

#include <bits/stdc++.h>
#define fst first
#define snd second
#define re register

using namespace std;

typedef pair<int,int> pii;
const int N = 1e5 + 10;
int n,m;
int cnt[N];
pii del[N];

struct answer{
    int a,b,pos;
};

inline int read(){
    int r = 0,w = 1;
    char c = getchar();
    while (c < '0' || c > '9'){
        if (c == '-') w = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        r = (r << 3) + (r << 1) + (c ^ 48);
        c = getchar();
    }
    return r * w;
}

inline void solve(){
    int num = 0;
    n = read();
    m = read();
    bool vis[n + 10][m + 10];
    vector<answer> ans;
    for (re int i = 1;i <= n;i++){
        cnt[i] = 0;
        for (re int j = 1;j <= m;j++){
            int x;
            x = read();
            if (x) vis[i][j] = true;
            else vis[i][j] = false;
            num += x;
            cnt[i] += x;
        }
    }
    if (num % n) return puts("-1"),void();
    num /= n;
    for (re int i = 1;i <= n;i++) del[i] = {cnt[i] - num,i};
    sort(del + 1,del + n + 1);
    for (re int i = 1,j = n;i < j;){
        int p = del[i].snd,q = del[j].snd;
        for (re int k = 1;k <= m && del[i].fst && del[j].fst;k++){
            if (!vis[p][k] && vis[q][k]){
                del[i].fst++;
                del[j].fst--;
                vis[p][k] = true;
                vis[q][k] = false;
                ans.push_back({p,q,k});
            }
        }
        if (!del[i].fst) i++;
        if (!del[j].fst) j--;
    }
    printf("%d\n",ans.size());
    for (auto p:ans) printf("%d %d %d\n",p.a,p.b,p.pos);
}

int main(){
    int T;
    T = read();
    while (T--) solve();
    return 0;
}
posted @ 2024-01-13 16:35  BeautifulWish  阅读(16)  评论(0编辑  收藏  举报