let fat tension(推公式,交换计算顺序,预处理)

题意

\(n\)个人,每个人有两种属性,分别是\(X_i\)\(Y_i\)。其中\(X_i\)\(k\)维向量,\(Y_i\)\(d\)维向量。

定义\(le(i,j) = \frac{X_i \cdot X_j}{|X_i||X_j|}\),即\(X_i\)\(X_j\)的余弦相似度。

\(i = 1,2,\dots, n\),求\(Y_i^{new} = \sum\limits_{j=1}^n le(i,j)Y_j\)

题目链接:https://ac.nowcoder.com/acm/contest/33187/I

数据范围

\(3 \leq n \leq 10^4\)
\(1 \leq k, d \leq 50\)

思路

如果直接暴力计算,那么肯定会超时,因此考虑对公式进行变形,期望能够找到快速计算的方法。

为了方便起见,我们设\(X_i\)的维度为\(m_1\)\(Y_i\)的维度为\(m_2\)

考察每个元素\(Y_{i,o}^{new} = \sum\limits_{j=1}^n \frac{X_i \cdot X_j}{|X_i| |X_j|}Y_{j,o} = \sum\limits_{j=1}^n \sum\limits_{k=1}^{m_1} \frac{X_{i,k}X_{j,k}}{|X_i||X_j|}Y_{j,o}\)

我们令\(X_{i,k}' = \frac{X_{i,k}}{|X_i|}\)\(X_{j,k}' = \frac{X_{j,k}}{|X_j|}\),其中\(|X_i|\)\(|X_j|\)是可以预处理的。

\(Y_{i,o}^{new} = \sum\limits_{j=1}^n \sum\limits_{k=1}^{m_1} X_{i,k}' X_{j,k}' Y_{j,o}\)

其实到这一步为止,还都是一些化简,对解决问题没有帮助。后面需要做的就是通过预处理,对上面的式子快速计算。

通过分析时间复杂度,正确的做法很可能是\(O(nm_1m_2)\)的,但是计算上面那个式子是\(O(n^2m_1m_2)\)的,因此我们要想办法优化掉一个\(O(n)\)

通过观察,我们可以交换求和次序,得到\(Y_{i,o}^{new} = \sum\limits_{k=1}^{m_1} X_{i,k}' (\sum\limits_{j=1}^n X_{j,k}' Y_{j,o})\)

然后我们发现,括号里面的内容可以\(O(nm_1m_2)\)预处理。

最终,我们可以在\(O(nm_1m_2)\)的时间复杂度内解决问题。

代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

typedef long long ll;

const int N = 10010, M = 60;

int n, m1, m2;
ll X2[N][M], Y[N][M];
double X[N][M];
double C[M][M];

int main()
{
    scanf("%d%d%d", &n, &m1, &m2);
    for(int i = 1; i <= n; i ++) {
        ll tmp = 0;
        for(int j = 1; j <= m1; j ++) {
            scanf("%lld", &X2[i][j]);
            tmp += X2[i][j] * X2[i][j];
        }
        double s = sqrt(tmp);
        for(int j = 1; j <= m1; j ++) {
            X[i][j] = X2[i][j] * 1.0 / s;
        }
    }
    for(int i = 1; i <= n; i ++) {
        for(int j = 1; j <= m2; j ++) {
            scanf("%lld", &Y[i][j]);
        }
    }
    for(int i = 1; i <= m1; i ++) {
        for(int j = 1; j <= m2; j ++) {
            for(int k = 1; k <= n; k ++) {
                C[i][j] += X[k][i] * Y[k][j];
            }
        }
    }
    for(int i = 1; i <= n; i ++) {
        for(int j = 1; j <= m2; j ++) {
            double ans = 0;
            for(int k = 1; k <= m1; k ++) {
                ans += X[i][k] * C[k][j];
            }
            printf("%.8f ", ans);
        }
        printf("\n");
    }
    return 0;
}
posted @ 2022-08-20 23:01  pbc的成长之路  阅读(20)  评论(0编辑  收藏  举报