程式人生 > PredictionIO E-Commerce Recommendation 源碼分析

predictionIO E-Commerce Recommendation 源碼分析

predictionio e-commerce recommendation 源碼分析

Algorithm 類

 @Override
    public Model train(SparkContext sc, PreparedData preparedData) {
        // Trains the recommendation model from the prepared data:
        //   1. assigns integer indexes to user ids and item ids,
        //   2. turns view events into per-(user, item) view counts used as implicit ratings,
        //   3. runs MLlib ALS (implicit feedback) and extracts user / item factor vectors,
        //   4. counts buy events per item as a popularity score for the cold-start fallback,
        //   5. joins item ids with their factor vectors and packs everything into a Model.
        // NOTE(review): the `sc` parameter is not used anywhere in this method body.
        TrainingData data = preparedData.getTrainingData();

        // --- Model training ---

        // Build the user index: keep only the user id, then pair each id with its position
        // from zipWithIndex(). The Long index is narrowed to an int via intValue().
        JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() {
            @Override
            public String call(Tuple2<String, User> idUser) throws Exception {
                return idUser._1();
            }
        }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
                return new Tuple2<>(element._1(), element._2().intValue());
            }
        });
        // Collect the index into a plain java.util.Map so the closures below can use it.
        final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();

        // Resulting shape: userId -> index, e.g. u1 -> 0, u2 -> 1
        // (zipWithIndex numbering starts at 0 per the Spark API docs).



         // Build the item index the same way as the user index.
        JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() {
            @Override
            public String call(Tuple2<String, Item> idItem) throws Exception {
                return idItem._1();
            }
        }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
                return new Tuple2<>(element._1(), element._2().intValue());
            }
        });
        // Resulting shape: itemId -> index, e.g. i1 -> 0, i2 -> 1
        final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
        JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() {
            @Override
            public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception {
                return element.swap();
            }
        });

        // Reverse index (index -> itemId) so an item id can be recovered from its integer index later.
        final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();


        // Build the ratings: every view event becomes ((userIndex, itemIndex), 1).
        // Events whose user or item is missing from the index maps are mapped to null and
        // dropped by the filter; the remaining pairs are summed per key and turned into Rating objects.
        JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
            @Override
            public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception {
                Integer userIndex = userIndexMap.get(viewEvent.getUser());
                Integer itemIndex = itemIndexMap.get(viewEvent.getItem());

                return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
            }
        }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                return (element != null);
            }
        }).reduceByKey(new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception {
                return integer + integer2;
            }
        }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
            @Override
            public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception {
                return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue());
            }
        });
        // Resulting shape: (userIndex, itemIndex) -> view count, e.g. (u1,i1) -> 1, (u1,i2) -> 2


        // Run MLlib ALS for implicit feedback. Rank, iteration count, lambda and seed come from `ap`;
        // the positional -1 and 1.0 are presumably the `blocks` and `alpha` arguments of
        // ALS.trainImplicit -- confirm against the MLlib version in use.
        MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
        JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
            @Override
            public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
                return new Tuple2<>((Integer) element._1(), element._2());
            }
        });// user factor vectors keyed by integer user index
        JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
            @Override
            public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
                return new Tuple2<>((Integer) element._1(), element._2());
            }
        });// item factor vectors keyed by integer item index

         // Popularity scores for the cold-start case: count buy events per item.
         // Buy events with an unknown user or item are dropped, the user part of the key is
         // discarded, counts are summed per item index and mapped back to item ids via indexItemMap.
        JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
            @Override
            public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception {
                Integer userIndex = userIndexMap.get(buyEvent.getUser());
                Integer itemIndex = itemIndexMap.get(buyEvent.getItem());

                return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
            }
        }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                return (element != null);
            }
        }).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() {
            @Override
            public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                return new Tuple2<>(element._1()._2(), element._2());
            }
        }).reduceByKey(new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer integer, Integer integer2) throws Exception {
                return integer + integer2;
            }
        }).map(new Function<Tuple2<Integer, Integer>, ItemScore>() {
            @Override
            public ItemScore call(Tuple2<Integer, Integer> element) throws Exception {
                return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue());
            }
        });
        // Resulting shape: itemId -> buy count, e.g. i1 -> 1, i2 -> 2

         // Final item-side table: itemIndex -> (itemId, item factor vector).
         JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);

         // Training finished.
         // NOTE(review): `buyItemForUser` is not defined anywhere in the visible code -- confirm
         // where it is declared (or whether this argument belongs here) before relying on this call.
        return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap(),buyItemForUser);
    }
    
    
    //推薦算法
     @Override
    public PredictedResult predict(Model model, final Query query) {
        // Answers a recommendation query with one of three strategies, in order of preference:
        //   1. personalised scores, when the user has a factor vector in the model;
        //   2. items similar to what the user recently interacted with;
        //   3. the most popular items.
        double[] userFeature = null;

        // lookup() returns an empty list for an unknown key, so neither step can throw the
        // "empty collection" error that first() raises on an empty RDD. This matters for a user
        // who is present in the user index but has no factor vector: such a user must fall
        // through to the fallbacks below instead of failing the query.
        final List<Integer> matchedUserIndexes = model.getUserIndex().lookup(query.getUserEntityId());
        if (!matchedUserIndexes.isEmpty()) {
            final List<double[]> matchedUserFeatures = model.getUserFeatures().lookup(matchedUserIndexes.get(0));
            if (!matchedUserFeatures.isEmpty()) {
                userFeature = matchedUserFeatures.get(0);
            }
        }

        if (userFeature != null) {
            // The user has a factor vector: regular personalised recommendation.
            return new PredictedResult(topItemsForUser(userFeature, model, query));
        } else {
            // No factor vector: try the features of the items from the user's recent events.
            List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
            if (recentProductFeatures.isEmpty()) {
                // Nothing recent either: recommend the most popular items.
                return new PredictedResult(mostPopularItems(model, query));
            } else {
                // Recommend items similar to the recent ones.
                return new PredictedResult(similarItems(recentProductFeatures, model, query));
            }
        }
    }
    
    //正常推薦流程
    private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
        // Personalised recommendation: score every item by the dot product of the user's
        // factor vector with the item's factor vector, then filter and keep the best ones.
        final DoubleMatrix userVector = new DoubleMatrix(userFeature);

        // Maps (itemIndex, (itemId, itemFeatures)) to an ItemScore carrying the dot-product score.
        Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore> scoreItem =
                new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
                    @Override
                    public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> indexedItem) throws Exception {
                        Tuple2<String, double[]> idAndFeatures = indexedItem._2();
                        double score = userVector.dot(new DoubleMatrix(idAndFeatures._2()));
                        return new ItemScore(idAndFeatures._1(), score);
                    }
                };
        JavaRDD<ItemScore> rawScores = model.getIndexItemFeatures().map(scoreItem);

        // Apply the query's whitelist / blacklist / category constraints via validScores.
        JavaRDD<ItemScore> allowedScores = validScores(rawScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());

        // Order the remaining scores and keep the number of items requested by the query.
        return sortAndTake(allowedScores, query.getNumber());
    }
    
    //推薦最流行的商品,最流行的商品在訓練模型時,已經預置
     private List<ItemScore> mostPopularItems(Model model, Query query) {
        // Cold-start fallback: start from the popularity scores stored in the model,
        // apply the query's constraints, then keep the requested number of items.
        JavaRDD<ItemScore> popularityScores = model.getItemPopularityScore();
        JavaRDD<ItemScore> allowedScores = validScores(popularityScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
        List<ItemScore> topItems = sortAndTake(allowedScores, query.getNumber());
        return topItems;
    }
    
     //相似推薦,找到該用戶最近瀏覽的商品
     private List<double[]> getRecentProductFeatures(Query query, Model model) {
        // Collects the factor vectors of the items targeted by the querying user's recent events.
        // Events without a target entity, items that are not in the model's item index, and
        // items without a factor vector are skipped.
        // Returns a possibly empty list with one factor vector per resolvable event.
        // Throws RuntimeException wrapping any failure while reading events or querying the model.
        try {
            List<double[]> result = new ArrayList<>();
            // Fetch events for this user entity from the event store, restricted to the
            // configured similar-item event names and to target entities of type "item".
            List<Event> events = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "user",
                    query.getUserEntityId(),
                    OptionHelper.<String>none(),
                    OptionHelper.some(ap.getSimilarItemEvents()),
                    OptionHelper.some(OptionHelper.some("item")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.some(10),
                    true,
                    Duration.apply(10, TimeUnit.SECONDS));

            for (final Event event : events) {
                if (!event.targetEntityId().isDefined()) {
                    continue;
                }
                final String itemId = event.targetEntityId().get();

                // Resolve the item id to its integer index. lookup() yields an empty list for an
                // unknown id, so an event for an item the model has never seen is skipped.
                // Previously first() was called before the emptiness check and threw on such
                // events, failing the whole query.
                final List<Integer> itemIndexes = model.getItemIndex().lookup(itemId);
                if (itemIndexes.isEmpty()) {
                    continue;
                }

                // Fetch (itemId, features) for that index; an item may have no factor vector.
                final List<Tuple2<String, double[]>> itemFeatures = model.getIndexItemFeatures().lookup(itemIndexes.get(0));
                if (!itemFeatures.isEmpty()) {
                    // Keep the ALS factor vector; it is later compared against other items' vectors.
                    result.add(itemFeatures.get(0)._2());
                }
            }

            return result;
        } catch (Exception e) {
            logger.error("Error reading recent events for user " + query.getUserEntityId());
            throw new RuntimeException(e.getMessage(), e);
        }
    }
    
    //具體的相似算法,根據上一個方法返回的item打分向量來計算
     private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
        // Similar-item recommendation: each item's score is the sum of its cosine similarity
        // to every recently seen item's factor vector.
        Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore> scoreBySimilarity =
                new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
                    @Override
                    public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> indexedItem) throws Exception {
                        final String itemId = indexedItem._2()._1();
                        final double[] candidateFeatures = indexedItem._2()._2();

                        // Accumulate similarity against each recent item, in list order.
                        double total = 0.0;
                        for (int i = 0; i < recentProductFeatures.size(); i++) {
                            total += cosineSimilarity(candidateFeatures, recentProductFeatures.get(i));
                        }

                        return new ItemScore(itemId, total);
                    }
                };
        JavaRDD<ItemScore> rawScores = model.getIndexItemFeatures().map(scoreBySimilarity);

        // Apply the query's constraints, then keep the requested number of items.
        JavaRDD<ItemScore> allowedScores = validScores(rawScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
        return sortAndTake(allowedScores, query.getNumber());
    }
    
    //如何判斷相似
     private double cosineSimilarity(double[] a, double[] b) {
        // Cosine similarity of two vectors: dot(a, b) / (|a| * |b|).
        // Returns 0.0 when either vector has zero Euclidean norm; dividing by zero there would
        // yield NaN, which would then poison any score that sums these similarities.
        DoubleMatrix matrixA = new DoubleMatrix(a);
        DoubleMatrix matrixB = new DoubleMatrix(b);

        double denominator = matrixA.norm2() * matrixB.norm2();
        if (denominator == 0.0) {
            return 0.0;
        }

        return matrixA.dot(matrixB) / denominator;
    }


由此來看,該例子還是比較簡單,適合用於二次開發。下面是一些基礎知識:

技術分享

predictionIO E-Commerce Recommendation 源碼分析