package com.rongyi.platform.game.web.data;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import java.util.Arrays;
import java.util.List;
public class EuclideanDistance {

    /**
     * Computes the <em>squared</em> Euclidean distance between two equally sized vectors.
     *
     * <p>Despite the method name, no square root is taken: the result is
     * {@code sum((list1[i] - list2[i])^2)}. This is kept for backward compatibility with existing
     * callers; apply {@link Math#sqrt(double)} to the result to obtain the true distance.
     *
     * @param list1 first vector; must not be {@code null} nor contain {@code null} elements
     * @param list2 second vector; must have the same size as {@code list1}
     * @return the squared Euclidean distance ({@code 0.0} for two empty lists)
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public static double calculateEuclideanDistance(List<Double> list1, List<Double> list2) {
        requireSameSize(list1, list2);
        return squaredDistance(toArray(list1), toArray(list2));
    }

    /**
     * Computes the Euclidean distance between two vectors, normalised by the sum of their
     * standard deviations: {@code sqrt(sum((a[i] - b[i])^2)) / (sd(a) + sd(b))}.
     *
     * <p>Standard deviations are computed with Commons Math {@link StandardDeviation} using its
     * default settings.
     *
     * <p>Degenerate inputs:
     * <ul>
     *   <li>identical vectors (including two empty lists) yield {@code 0.0} rather than the
     *       {@code NaN} that {@code 0 / 0} would produce;</li>
     *   <li>distinct vectors whose standard deviations are both zero (e.g. each list is constant,
     *       or single-element lists) yield {@link Double#POSITIVE_INFINITY}.</li>
     * </ul>
     *
     * @param list1 first vector; must not be {@code null} nor contain {@code null} elements
     * @param list2 second vector; must have the same size as {@code list1}
     * @return the normalised Euclidean distance
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public static double calculateEuclideanDistance2(List<Double> list1, List<Double> list2) {
        requireSameSize(list1, list2);
        double[] array1 = toArray(list1);
        double[] array2 = toArray(list2);

        double distance = Math.sqrt(squaredDistance(array1, array2));
        // Identical vectors are at distance zero whatever their spread; without this guard,
        // constant or empty inputs would produce 0 / 0 = NaN.
        if (distance == 0.0) {
            return 0.0;
        }

        // The summed standard deviations act purely as a scale factor for the distance.
        StandardDeviation stdDev = new StandardDeviation();
        double spread = stdDev.evaluate(array1) + stdDev.evaluate(array2);
        return distance / spread;
    }

    /** Throws if the two vectors cannot be compared element-wise. */
    private static void requireSameSize(List<Double> list1, List<Double> list2) {
        if (list1.size() != list2.size()) {
            throw new IllegalArgumentException("欧式距离集合大小必选相同");
        }
    }

    /**
     * Unboxes a list into a primitive array so indexed access is O(1) regardless of the
     * {@link List} implementation (e.g. {@code LinkedList.get(i)} is O(n)).
     */
    private static double[] toArray(List<Double> values) {
        return values.stream().mapToDouble(Double::doubleValue).toArray();
    }

    /** Returns the sum of squared element-wise differences of two arrays of equal length. */
    private static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }
}