收藏官网首页
查看: 12342|回复: 0

机器学习实战之SVM

36

主题

69

帖子

265

积分

中级会员

Rank: 3Rank: 3

积分
265
跳转到指定楼层
楼主
发表于 2016-4-6 11:46:50 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
免费使用STM32、APP自动代码生成工具
支持向量机 主要用在分类问题上,其效果也是非常好的。下面代码是基于sparkML lib svm 算法的实现
  1. package com.gizwits.mllib

  2. import org.apache.log4j.{Level, Logger}
  3. import org.apache.spark.mllib.classification.SVMWithSGD
  4. import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
  5. import org.apache.spark.mllib.optimization.{HingeGradient, L1Updater}
  6. import org.apache.spark.mllib.util.MLUtils
  7. import org.apache.spark.{SparkConf, SparkContext}

  8. /**
  9.   * Created by feel
  10.   * <p>
  11.   *
  12.   * SVM(线性支持向量机)是一个有监督的学习模型,通常用来进行模式识别、分类、以及回归分析。SVM只支持二分类
  13.   *
  14.   * </p>
  15.   *
  16.   */
  17. object SparkSVM {


  18.   def main(args: Array[String]) {
  19.     //  设置日志级别
  20.     val rootLogger = Logger.getRootLogger()

  21.     Logger.getLogger("com.gizwits").setLevel(Level.ERROR)


  22.     val conf = new SparkConf().setAppName("SparkSVM").setMaster("local[2]")
  23.     val sc = new SparkContext(conf)

  24.     /*
  25.       数据解释,如下,  第一个是标签也就是分类,0 表示1类,1  表示另一类.  冒号2边 表示 特征维度:特征值, 如 128:51  表示 第128维度, 特征值为51,比如单词出现的 次数
  26.       0 128:51 129:159
  27.       1 159:124 160:253b
  28.       1 159:124 160:253
  29.      */
  30.     val dataInputPath = "file:///Users/feel/githome/idea/spark-exercise/src/main/resources/sample_libsvm_data.txt"
  31.     // 加载 LIBSVM 格式的数据
  32.     val data = MLUtils.loadLibSVMFile(sc, dataInputPath)

  33.     // 切分数据,60%用于训练,40%用于测试
  34.     val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
  35.     // 60% 用于训练  训练集,RDD[LabeledPoint]
  36.     val training = splits(0).cache()

  37.     //  40% 用于测试
  38.     val test = splits(1)

  39.     // 运行训练算法 来构建模型
  40.     val numIterations = 200 // 迭代次数,默认是100
  41.     /**
  42.       * SVMWithSGD类中参数说明:
  43.       *
  44.       * stepSize: 迭代步长,默认为1.0
  45.       *
  46.       * numIterations: 迭代次数,默认为100
  47.       *
  48.       * regParam: 正则化参数,默认值为0.0
  49.       *
  50.       * miniBatchFraction: 每次迭代参与计算的样本比例,默认为1.0
  51.       *
  52.       * gradient:HingeGradient (),梯度下降; 其损失函数是 hinge los
  53.       *
  54.       * updater:SquaredL2Updater (),正则化,L2范数;线性SVM使用L2正则化做训练。也可以替换为L1正则化,这样就成了线性优化问题
  55.       *
  56.       * optimizer:GradientDescent (gradient, updater),梯度下降最优化计算。
  57.       */
  58.    // val model =  SVMWithSGD.train(data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType=’l2′, intercept=False)
  59.     //  val model = SVMWithSGD.train(training, numIterations) // 随机梯度下降法 简单模式


  60.     val svmAlg = new SVMWithSGD()

  61.     /**
  62.       * 为了防止过拟合,需要在loss function后面加入一个正则化项一起求最小值。
  63.       * 正则化项相当于对weights向量的惩罚,期望求出一个更简单的模型。
  64.       * MLlib目前支持两种正则化方法L1和L2。
  65.       * L2正则化假设模型参数服从高斯分布,L2正则化函数比L1更光滑,所以更容易计算;
  66.       * L1假设模型参数服从拉普拉斯分布,L1正则化具备产生稀疏解的功能,从而具备feature selection的能力。
  67.       */
  68.     svmAlg.optimizer
  69.       .setGradient(new HingeGradient()) // HingeGradient,支持ANNGradient,LogisticGradient,LeastSquaresGradient
  70.       .setNumIterations(numIterations)
  71.       .setRegParam(0.1). //设置正则化参数
  72.       setUpdater(new L1Updater) //  目前支持 SquaredL2Updater,ANNUpdater,SimpleUpdater


  73.     val model = svmAlg.run(training)


  74.     // 清空默认值
  75.     model.clearThreshold()

  76.     // 用训练集计算原始分数
  77.     val scoreAndLabels = test.map { point =>
  78.       val score = model.predict(point.features)
  79.       (score, point.label)
  80.     }
  81.     //训练数据集用RDD[LabeledPoint]表示,其中label是分类类型的索引,从0开始,即0, 1, 2, …。
  82.     val result = scoreAndLabels.map { t =>
  83.       val str = "point.label=" + t._2 + " score= " + t._1
  84.       println(str)
  85.       str
  86.     }
  87.     result.collect()

  88.     // 以下是2种分类模型的评估
  89.     // 获得评价指标,二分类评估

  90.     /**
  91.       * ROC的全名叫做Receiver Operating Characteristic,是一个画在二维平面上的经过(0, 0),(1, 1)的曲线,
  92.       * 一般情况下,这个曲线都应该处于(0, 0)和(1, 1)连线的上方,如果不幸的出现在了下方,说明分类器的结果反了。
  93.       * 程序里的结果其实是计算的AUC,也就是Area Under roc Curve,
  94.       * 也就是处于ROC curve下方的那部分面积的大小,一般是0.5至1,越大就说明分类效果越好
  95.       */
  96.     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
  97.     val auROC = metrics.areaUnderROC() //受试者操作特征
  98.     println("Area under ROC = " + auROC)

  99.     // 计算测试误差,多分类

  100.     val metricsLabel = new MulticlassMetrics(scoreAndLabels)

  101.     val precision = metricsLabel.precision

  102.     println("Precision = " + precision)
  103.     sc.stop()
  104.   }
  105. }
复制代码



您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

加入Q群 返回顶部

版权与免责声明 © 2006-2024 Gizwits IoT Technology Co., Ltd. ( 粤ICP备11090211号 )

快速回复 返回顶部 返回列表