由于业务需求,需要分别在 Java 端和 MySQL 存储函数中各实现一个余弦相似度算法,两者的输入格式保持一致(逗号分隔的数值字符串)。
JAVA端
/**
 * Computes the cosine similarity of two vectors given as comma-separated strings.
 *
 * <p>If the vectors differ in length, the shorter one is treated as if it were
 * padded with zeros, which matches the behavior of the MySQL
 * {@code getCosineSimilarity} function.
 *
 * @param strOne first vector as comma-separated numbers, e.g. {@code 1,2,3}
 * @param strTwo second vector as comma-separated numbers, e.g. {@code 1,2,3}
 * @return the cosine similarity rounded half-up to 10 decimal places and clamped
 *         to [-1, 1]; {@code 0.0} when either input is null or blank, or when
 *         either vector has zero magnitude
 * @throws NumberFormatException if a component cannot be parsed as a double
 */
private double cosineSimilarity(String strOne, String strTwo) {
    // NOTE(review): the original comment here read "a similarity above 0.6 basically
    // means different people". For cosine similarity a higher value means the vectors
    // are closer, so that wording looks inverted; confirm the threshold with the caller.

    // Null or blank input has no components; report "no similarity" instead of throwing.
    if (strOne == null || strTwo == null
            || strOne.trim().isEmpty() || strTwo.trim().isEmpty()) {
        return 0.0;
    }

    double[] vectorA = Arrays.stream(strOne.split(",")).mapToDouble(Double::parseDouble).toArray();
    double[] vectorB = Arrays.stream(strTwo.split(",")).mapToDouble(Double::parseDouble).toArray();

    double dotProduct = 0.0;
    double normA = 0.0;
    double normB = 0.0;
    // Iterate over the longer vector so that no component is dropped from its norm;
    // the missing components of the shorter vector count as 0.
    int length = Math.max(vectorA.length, vectorB.length);
    for (int i = 0; i < length; i++) {
        double a = i < vectorA.length ? vectorA[i] : 0.0;
        double b = i < vectorB.length ? vectorB[i] : 0.0;
        dotProduct += a * b;
        normA += a * a;
        normB += b * b;
    }

    double magnitude = Math.sqrt(normA) * Math.sqrt(normB);
    // A zero vector has no direction; avoid dividing by zero.
    if (magnitude == 0.0) {
        return 0.0;
    }

    // HALF_UP instead of the deprecated ROUND_UP: rounding away from zero turned
    // floating-point noise such as 1.00000000000000013 into 1.0000000001.
    double rounded = BigDecimal.valueOf(dotProduct)
            .divide(BigDecimal.valueOf(magnitude), 10, java.math.RoundingMode.HALF_UP)
            .doubleValue();
    // Guard the mathematical range against any remaining rounding error.
    return Math.max(-1.0, Math.min(1.0, rounded));
}
Mysql存储函数
CREATE DEFINER=`root`@`%` FUNCTION `getCosineSimilarity`(str1 text, str2 text) RETURNS varchar(1000) CHARSET utf8mb4
    DETERMINISTIC
    NO SQL
BEGIN
    -- Cosine similarity of two vectors passed as comma-separated numbers, e.g. '1,2,3'.
    --
    -- Returns 0 when either argument is NULL or empty, or when either vector has
    -- zero magnitude. If the vectors differ in length, the shorter one is treated
    -- as padded with zeros. A component that cannot be parsed counts as 0.
    -- The result is clamped to [-1, 1] and returned as text (declared return type).
    --
    -- DETERMINISTIC / NO SQL: the function only does arithmetic on its arguments;
    -- declaring this lets it be created when binary logging is enabled.
    -- DOUBLE is used throughout so results agree with double-precision callers.
    DECLARE i INT DEFAULT 1;
    DECLARE len1 INT DEFAULT 0;
    DECLARE len2 INT DEFAULT 0;
    DECLARE dot_product DOUBLE DEFAULT 0;
    DECLARE norm1 DOUBLE DEFAULT 0;
    DECLARE norm2 DOUBLE DEFAULT 0;
    DECLARE val1 DOUBLE DEFAULT 0;
    DECLARE val2 DOUBLE DEFAULT 0;
    DECLARE max_len INT;
    DECLARE result DOUBLE DEFAULT 0;

    -- NULL input: nothing to compare.
    IF str1 IS NULL OR str2 IS NULL THEN
        RETURN 0;
    END IF;

    -- Number of components = number of commas + 1 (0 for an empty string).
    SET len1 = IF(TRIM(str1) = '', 0, LENGTH(str1) - LENGTH(REPLACE(str1, ',', '')) + 1);
    SET len2 = IF(TRIM(str2) = '', 0, LENGTH(str2) - LENGTH(REPLACE(str2, ',', '')) + 1);

    -- Empty input: nothing to compare.
    IF len1 = 0 OR len2 = 0 THEN
        RETURN 0;
    END IF;

    SET max_len = GREATEST(len1, len2);

    WHILE i <= max_len DO
        -- Reset before parsing: if a SET below is aborted by an error the handler
        -- swallows, the component must count as 0, not as the previous
        -- iteration's value.
        SET val1 = 0;
        SET val2 = 0;

        BEGIN
            -- Swallow parse errors/warnings; the affected component keeps its 0.
            DECLARE CONTINUE HANDLER FOR SQLEXCEPTION, SQLWARNING
            BEGIN
            END;

            -- Take the i-th comma-separated token. COALESCE (not NULLIF) so that
            -- a NULL parse result becomes 0 while a legitimate 0 stays 0.
            IF i <= len1 THEN
                SET val1 = COALESCE(
                    CAST(TRIM(SUBSTRING_INDEX(SUBSTRING_INDEX(str1, ',', i), ',', -1)) AS DOUBLE),
                    0);
            END IF;
            IF i <= len2 THEN
                SET val2 = COALESCE(
                    CAST(TRIM(SUBSTRING_INDEX(SUBSTRING_INDEX(str2, ',', i), ',', -1)) AS DOUBLE),
                    0);
            END IF;
        END;

        SET dot_product = dot_product + (val1 * val2);
        SET norm1 = norm1 + (val1 * val1);
        SET norm2 = norm2 + (val2 * val2);
        SET i = i + 1;
    END WHILE;

    -- A zero vector has no direction; leave the result at 0 rather than divide by 0.
    IF norm1 > 0 AND norm2 > 0 THEN
        -- Clamp to the mathematical range to absorb floating-point rounding noise.
        SET result = LEAST(1, GREATEST(-1, dot_product / (SQRT(norm1) * SQRT(norm2))));
    END IF;

    RETURN result;
END
672

被折叠的 条评论
为什么被折叠?



