UserCF的MapReduce实现

本文介绍了一种基于用户协同过滤(UserCF)的推荐系统优化方案,重点在于如何通过计算用户间的Jaccard相似度来提高推荐精度,并通过合理划分item组来减少计算复杂度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

用户协同过滤(UserCF)在推荐系统刚出来的时候,使用的比较多,但随着技术的发展以及业务场景的限制,现在只使用协同过滤的推荐系统已经不多了,具体的原理和细节网上有很多,也就不在这里赘述了。我最近在实现一个推荐系统的时候,需要用到UserCF,而在UserCF中最核心的东西就是计算每个用户的相似度,得出相似用户之后再根据这些用户看过的或者点击过的东西去做推荐。

计算相似度的方法有很多,就不在这一一列举了,我使用了比较简单的Jaccard相似度:两个集合的交集作为分子、两个集合的并集作为分母,计算出来的数值即可。放到具体的业务中,也就是以两个用户共同点击、浏览或者发生其他一些行为的item数作为分子,以两个用户各自点击过的item并集作为分母,即可认为是这两个用户的相似度。

这个简单的逻辑用代码实现起来一点也不难,而难的是数据量大了之后,计算的性能会很差,因为计算过程中需要先得到用户的笛卡尔积,然后再算相似度,所以当用户数上万、甚至上百万时,积之后的这个数量是相当恐怖的。当然了,稍稍用脚趾头想想就知道,不用计算全部用户的笛卡尔积,只需要叉乘点击过相同item的用户就行了,因为根据定义,没有点击过相同item的用户相似度肯定为0,也就不用计算了。整个流程用mapreduce实现起来还是挺容易的,具体参见以下代码。

我用的是阿里的odps,三年前就已经用过了,那时的界面还是比较简陋的,现在貌似已经叫Max Compute了,功能也增加了很多,整体看着还是不错的。另外,mr脚本虽然是用java写的,但java不是我的主要编程语言,所以有写的不好的地方,拍砖还请轻点。

import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;

import com.aliyun.odps.data.Record;
import com.aliyun.odps.mapred.*;


public class UserCF {

    public static class UserCFMapper extends MapperBase {
        Record k;
        Record v;
        Pattern pattern = Pattern.compile("^[-\\+]?[\\d]*$");

        @Override
        public void setup(TaskContext context) throws IOException {
            k = context.createMapOutputKeyRecord();
            v = context.createMapOutputValueRecord();
        }

        @Override
        public void map(long recordNum, Record record, TaskContext context) throws IOException {
            if (record != null) {
                String itemId = record.getString(1);
                String userId = record.getString(2);
                if (pattern.matcher(itemId).matches()) {
                    k.setString(0, userId);
                    v.setString(0, itemId);
                    context.write(k, v);
                }
            }
        }
    }

    public static class UserFilterReducer extends ReducerBase {
        Record k;
        Record v;
        static final long MIN = 1;
        static final long MAX = 3000;

        @Override
        public void setup(TaskContext context) throws IOException {
            k = context.createOutputKeyRecord();
            v = context.createOutputValueRecord();
        }

        @Override
        public void reduce(Record key, Iterator<Record> values, TaskContext context) throws IOException {
            Set<String> items = new HashSet<>();
            while (values.hasNext()) {
                Record record = values.next();
                items.add(record.getString(0));
            }

            long cnt = items.size();
            if (cnt > MIN && cnt < MAX) {
                for (String item : items) {
                    k.setString(0, item);                                   //一个item就放到一个reducer
                    v.setString(0, Helper.encode(key.getString(0), cnt));   //这里就简单的直接把userid和点击过的item数拼成一个字符串,后面计算相似度会用到
                    context.write(k, v);
                }
            }
        }
    }

    public static class ItemFilterReducer extends ReducerBase {
        Record k;
        Record v;
        static final int MAXUSER = 1000000;     //同一item被点击过的用户数,设置阈值主要目的是减小数据集

        @Override
        public void setup(TaskContext context) throws IOException {
            k = context.createOutputKeyRecord();
            v = context.createOutputValueRecord();
        }

        @Override
        public void reduce(Record key, Iterator<Record> values, TaskContext context) throws IOException {
            Set<String> users = new HashSet<>();
            while (values.hasNext()) {
                Record record = values.next();
                users.add(record.getString(0));
            }

            long cnt = users.size();
            if (cnt < MAXUSER) {
                for (String user : users) {
                    k.setString(0, key.getString(0));
                    v.setString(0, user);
                    context.write(k, v);
                }
            }
        }
    }

    public static class ProductReducer extends ReducerBase {
        Record k;
        Record v;

        @Override
        public void setup(TaskContext context) throws IOException {
            k = context.createOutputKeyRecord();
            v = context.createOutputValueRecord();
        }

        @Override
        public void reduce(Record key, Iterator<Record> values, TaskContext context) throws IOException {
            Set<String> users = new HashSet<>();
            while (values.hasNext()) {
                Record value = values.next();
                String user = value.getString(0);
                users.add(user);
            }

            List<String> list = new ArrayList<>(users);
            for (int i = 0; i < list.size(); ++i) {
                String curUser = list.get(i);
                String[] val = Helper.decode(curUser);
                // 此处可以加入些业务过滤条件,只对满足条件的用户计算相似度

                for (int j = i + 1; j < list.size(); ++j) {
                    String otherUser = list.get(j);
                    if (otherUser.equals(curUser))
                        continue;

                    k.setString(0, curUser);
                    v.setString(0, otherUser);
                    v.setBigint(1, 1L);
                    context.write(k, v);

                    
                    k.setString(0, otherUser);
                    v.setString(0, curUser);
                    v.setBigint(1, 1L);
                    context.write(k, v);

                }

            }

        }

    }

    public static class SimUserReducer extends ReducerBase {
        private Record v;

        @Override
        public void setup(TaskContext context) throws IOException {
            v = context.createOutputRecord();
        }

        @Override
        public void reduce(Record key, Iterator<Record> values, TaskContext context) throws IOException {
            Map<String, Long> map = new HashMap<>();
            while (values.hasNext()) {
                Record record = values.next();
                String other = record.getString(0);

                if (map.containsKey(other)) {
                    map.put(other, map.get(other) + record.getBigint(1));
                } else {
                    map.put(other, record.getBigint(1));
                }
            }

            /*
                用jaccard计算相似度:分子为共同点过的item数,
                                   分母的并集数可以这样计算:两个用户各自点击item数 - 共同的点击item数
            */
            String[] val = Helper.decode(key.getString(0));
            Map<String, Double> userSimilarity = new TreeMap<>();
            for (Map.Entry<String, Long> entry : map.entrySet()) {
                String[] deviceCnt = entry.getKey().split(":");
                String userId = deviceCnt[0];
                long cnt = Long.parseLong(deviceCnt[1]);
                long numerator = entry.getValue(); 
                long denominator = Long.parseLong(val[1]) + cnt - numerator;    
                double sim = numerator * 1. / denominator;
                userSimilarity.put(userId, sim);
            }

            List<Map.Entry<String, Double>> list = new ArrayList<>(userSimilarity.entrySet());
            list.sort((Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) -> o2.getValue().compareTo(o1.getValue()));

            List<String> res = new ArrayList<>();
            if (list.size() > 50) {             //只取最相似的50个用户
                list = list.subList(0, 50);
            }

            for (Map.Entry<String, Double> li : list) {
                res.add(li.getKey() + ":" + String.valueOf(li.getValue()));
            }

            v.setString(0, val[0]);
            v.setString(1, String.join("\t", res));
            context.write(v);

        }

    }

}

实际运行中,发现product的reducer发射出去的记录数量还是很大,而且整个运行时间在3个小时左右,还需要再优化。

我想到的一个方法是:将所有的item按照一定规则分组,比如根据hashcode,这样每个组内的item数会比较少,然后就先在这个组内对每个用户计算和其他用户的jaccard相似度,然后按照相似度大小排序,只取排名前面一部分的用户发射出去,这样一来可以减少output的记录数,参考的思想是apriori的生成频繁项集的过程,两个在全集相似的用户在子集中也应该是大概率的相似,这里只是说大概率,并不是一定。所以,这样做的缺点是可能会损失一些真正相似的用户,取决与分组的规则是否合理了。以下只给出优化部分的代码,其他代码同上。

首先是将item按照hashcode分组发射出去,也就是将几个item放到一个reducer中,而不是像原来的一个item一个reducer。这里用了模1000,也就是理论上来说,最好的情况是(所有item数/1000)个item会分到同一个reducer中,当然了,这个数还需要再对reducer个数取模,这样,每个reducer中的item就是一个分组,其中的个数应该是远大于1 的。

    public static class ItemFilterReducer extends ReducerBase {
        Record k;
        Record v;
        static final int MAXUSER = 500000;

        @Override
        public void setup(TaskContext context) throws IOException {
            k = context.createOutputKeyRecord();
            v = context.createOutputValueRecord();
        }

        @Override
        public void reduce(Record key, Iterator<Record> values, TaskContext context) throws IOException {
            Set<String> users = new HashSet<>();
            while (values.hasNext()) {
                Record record = values.next();
                users.add(record.getString(0));
            }

            long cnt = users.size();
            if (cnt < MAXUSER) {
                String hash_key = key.getString(0);
                hash_key = String.valueOf(hash_key.hashCode()%1000);
                for (String user : users) {
                    k.setString(0, hash_key);
                    v.setString(0, user);
                    v.setString(1, key.getString(0));
                    context.write(k, v);
                }
            }
        }
    }

然后是计算相似度的reducer,这里做的优化就是在每个reducer中对每个用户粗略计算一下其他用户的相似度,然后只输出部分用户。这里的输出阈值还是需要尝试几个值,以期做到损失尽量少的真正相似用户。

    public static class ProductReducer extends ReducerBase {
        Record k;
        Record v;

        @Override
        public void setup(TaskContext context) throws IOException {
            k = context.createOutputKeyRecord();
            v = context.createOutputValueRecord();
        }

        /*
            先在当前reducer中,计算一遍相似度、排序,然后输出排名靠前的一部分用户,相似度计算方式同上
        */
        @Override
        public void reduce(Record key, Iterator<Record> values, TaskContext context) throws IOException {
            Map<String, Set<Integer>> userItemMap = new HashMap<>();
            Map<String, Integer> itemInd = new HashMap<>();
            while (values.hasNext()) {
                Record value = values.next();
                String otherUser = value.getString(0);
                String item = value.getString(1);
                if (!itemInd.containsKey(item)) {
                    itemInd.put(item, itemInd.size()+1);
                }
                if (!userItemMap.containsKey(otherUser)) {
                    userItemMap.put(otherUser, new HashSet<>());
                }
                userItemMap.get(otherUser).add(itemInd.get(item));
            }

            List<Map.Entry<String, Set<Integer>>> l = new ArrayList<>(userItemMap.entrySet());
            itemInd.clear();
            userItemMap.clear();

            //计算
            for (int i = 0; i < l.size(); ++i) {
                String curUser = l.get(i).getKey();
                String[] val = Helper.decode(curUser);
                // 此处可以加入些业务过滤条件,只对满足条件的用户计算相似度

                long cnt1 = Long.valueOf(val[1]);
                Map<String, Double> similarity = new HashMap<>(); 
                Set<Integer> items1 = l.get(i).getValue();
                for (int j = 0; j < l.size(); ++j) {
                    String otherUser = l.get(j).getKey();
                    long cnt2 = Long.valueOf(Helper.decode(otherUser)[1]);
                    Set<Integer> items2 = l.get(j).getValue();
                    items2.retainAll(items1);
                    int common = items2.size();
                    if (common > 0) {
                        similarity.put(otherUser, common*1./(cnt1+cnt2-common));
                    }
                }

                //排序  
                List<Map.Entry<String, Double>> simFilter = new ArrayList<>(similarity.entrySet());
                simFilter.sort((Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) -> o2.getValue().compareTo(o1.getValue()));

                //输出排名靠前50%的用户
                int filerSize = (int)(simFilter.size()*0.5);
                simFilter = simFilter.subList(0, filerSize);
                for (Map.Entry<String, Double> li : simFilter) {
                    k.setString(0, curUser);
                    v.setString(0, li.getKey());
                    v.setBigint(1, 1L);
                    context.write(k, v);

                    k.setString(0, li.getKey());
                    v.setString(0, curUser);
                    v.setBigint(1, 1L);
                    context.write(k, v);
                    
                }
            }
        }
    }
实际运行后,运行时间减少了差不多一半,结果数据量减少了10%左右。当然了,这个优化办法还不是最优的,仍然需要继续改进。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值