收藏官网首页
查看: 8087|回复: 0

机器学习之推荐算法实战

36

主题

69

帖子

265

积分

中级会员

Rank: 3Rank: 3

积分
265
跳转到指定楼层
楼主
发表于 2016-4-6 11:41:10 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
注册成为机智云开发者,手机加虚拟设备快速开发

自从Amazone公布了协同过滤算法后,在推荐系统领域,它就占据了很重要的地位。不像传统的内容推荐,协同过滤不需要考虑物品的属性问题,用户的行为,行业问题等,只需要建立用户与物品的关联关系即可,可以物品之间更多的内在关系,类似于经典的啤酒与尿不湿的营销案例。所以,讲到推荐必须要首先分享协同过滤。下面代码实战基于sparkMLlib ASL 算法实战
  1. package com.gizwits.mllib

  2. import org.apache.log4j._
  3. import org.apache.spark.mllib.recommendation._
  4. import org.apache.spark.rdd.RDD
  5. import org.apache.spark.{SparkConf, SparkContext}

  6. /**
  7.   * Created by feel
  8.   *
  9.   * moivelens 电影推荐  协同过滤算法实现电影推荐.目前spark 实现的算法有(交替最小二乘法(ALS))
  10.   * 数据下载:http://grouplens.org/datasets/movielens/
  11.   *
  12.   */
  13. object MoiveRecommenderALS {

  14.   /**
  15.     *
  16.     * @param input            电影评分数据
  17.     * @param numIterations    迭代的次数
  18.     * @param lambda           ALS的正则化参数。
  19.     * @param rank             模型中隐语义因子的个数。
  20.     * @param numUserBlocks    用于并行化计算的分块个数 (设置为-1为自动配置)。
  21.     * @param numProductBlocks 用于并行化计算的分块个数 (设置为-1为自动配置)。
  22.     * @param implicitPrefs    决定了是用显性反馈ALS的版本还是用适用隐性反馈数据集的版本
  23.     * @param userDataInput    用户数据输入
  24.     */
  25.   case class Params(
  26.                      input: String = null,
  27.                      numIterations: Int = 20,
  28.                      lambda: Double = 1.0,
  29.                      rank: Int = 10,
  30.                      numUserBlocks: Int = -1,
  31.                      numProductBlocks: Int = -1,
  32.                      implicitPrefs: Boolean = false,
  33.                      userDataInput: String = null)

  34.   val numRecommender = 10

  35.   def main(args: Array[String]) {
  36.     //  设置日志级别
  37.     val rootLogger = Logger.getRootLogger()

  38.     Logger.getLogger("com.gizwits").setLevel(Level.ERROR)

  39.     rootLogger.setLevel(Level.ERROR)
  40.     val conf = new SparkConf()
  41.       .setAppName("MoiveRecommenderALS")
  42.     conf.setMaster("local[4]")
  43.     val context = new SparkContext(conf)

  44.     val inputDataPath = "file:///Users/feel/githome/idea/spark-exercise/src/main/resources/u.data"
  45.     val userInputPath = "file:///Users/feel/githome/idea/spark-exercise/src/main/resources/u.user"

  46.     //可以调整这些参数,不断优化结果,使均方差变小。比如iterations越多,lambda较小,均方差会较小,推荐结果较优
  47.     val defaultParams = Params(
  48.       inputDataPath, 20, 0.01, 10, -1, -1, false, userInputPath
  49.     )
  50.     //加载数据
  51.     val data = context.textFile(inputDataPath)

  52.     /**
  53.       * *MovieLens ratings are on a scale of 1-5:
  54.       * 5: Must see
  55.       * 4: Will enjoy
  56.       * 3: It's okay
  57.       * 2: Fairly bad
  58.       * 1: Awful
  59.       */
  60.     val ratings = data.map(_.split("\t") match {
  61.       case Array(user, item, rate, time) => Rating(user.toInt, item.toInt, rate.toDouble)
  62.     })



  63.     //使用ALS建立推荐模型
  64.     //也可以使用简单模式    val model = ALS.train(ratings, ranking, numIterations)


  65.     val model = new ALS()
  66.       .setRank(defaultParams.rank)
  67.       .setIterations(defaultParams.numIterations)
  68.       .setLambda(defaultParams.lambda)
  69.       .setImplicitPrefs(defaultParams.implicitPrefs)
  70.       .setUserBlocks(defaultParams.numUserBlocks)
  71.       .setProductBlocks(defaultParams.numProductBlocks)
  72.       .run(ratings)

  73.     //预测
  74.     predictMoive(defaultParams, context, model)


  75.     //模型评估
  76.     evaluateMode(ratings, model)

  77.     //clean up
  78.     context.stop()
  79.     //end  main

  80.   }

  81.   /**
  82.     * 模型评估
  83.     */
  84.   private def evaluateMode(ratings: RDD[Rating], model: MatrixFactorizationModel) {

  85.     //使用训练数据训练模型
  86.     val usersProducets = ratings.map(r => r match {
  87.       case Rating(user, product, rate) => (user, product)
  88.     })

  89.     //预测数据
  90.     val predictions = model.predict(usersProducets).map(u => u match {
  91.       case Rating(user, product, rate) => ((user, product), rate)
  92.     })

  93.     //将真实分数与预测分数进行合并
  94.     val ratesAndPreds = ratings.map(r => r match {
  95.       case Rating(user, product, rate) =>
  96.         ((user, product), rate)
  97.     }).join(predictions)

  98.     //计算均方差
  99.     val MSE = ratesAndPreds.map(r => r match {
  100.       case ((user, product), (r1, r2)) =>
  101.         val err = (r1 - r2)
  102.         err * err
  103.     }).mean()

  104.     //打印出均方差值
  105.     println("Mean Squared Error = " + MSE)
  106.   }

  107.   /**
  108.     * 预测数据并保存到HBase中或其他存储引擎
  109.     */
  110.   private def predictMoive(params: Params, context: SparkContext, model: MatrixFactorizationModel) {


  111.     val recommenders = new scala.collection.mutable.ArrayBuffer[scala.collection.mutable.HashMap[String, String]]();

  112.     //读取需要进行电影推荐的用户数据
  113.     val userData = context.textFile(params.userDataInput)

  114.     userData.map(_.split("\\|") match {
  115.       case Array(id, age, sex, job, x) => (id)
  116.     }).collect().foreach(id => {
  117.       //为用户推荐电影
  118.       val rs = model.recommendProducts(id.toInt, numRecommender)
  119.       var value = ""
  120.       var key = 0

  121.       //保存推荐数据到hbase中
  122.       rs.foreach(r => {
  123.         key = r.user
  124.         value = value + r.product + ":" + r.rating + ","
  125.       })

  126.       //成功,则封装put对象,等待插入到Hbase中
  127.       if (!value.equals("")) {
  128.         val put = new scala.collection.mutable.HashMap[String, String]
  129.         put += ("rowKey" -> key.toString)
  130.         put += ("t:info" -> value)
  131.         recommenders.+=(put)

  132.       }
  133.     })

  134.     recommenders.foreach(println _)

  135.   }
  136. }
复制代码



您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

加入Q群 返回顶部

版权与免责声明 © 2006-2024 Gizwits IoT Technology Co., Ltd. ( 粤ICP备11090211号 )

快速回复 返回顶部 返回列表