DL4J實戰之二:鳶尾花分類
歡迎訪問我的GitHub
http://github.com/zq2599/blog_demos
內容:所有原創文章分類彙總及配套原始碼,涉及Java、Docker、Kubernetes、DevOPS等;
本篇概覽
- 本文是《DL4J》實戰的第二篇,前面做好了準備工作,接下來進入正式實戰,本篇內容是經典的入門例子:鳶尾花分類
- 下圖是一朵鳶尾花,我們可以測量到它的四個特徵:花瓣(petal)的寬和高,花萼(sepal)的 寬和高:
- 鳶尾花有三種:Setosa、Versicolor、Virginica
- 今天的實戰是用前饋神經網路Feed-Forward Neural Network (FFNN)就行鳶尾花分類的模型訓練和評估,在拿到150條鳶尾花的特徵和分類結果後,我們先訓練出模型,再評估模型的效果:
原始碼下載
- 本篇實戰中的完整原始碼可在GitHub下載到,地址和連結資訊如下表所示( http://github.com/zq2599/blo... ):
名稱 | 連結 | 備註 |
---|---|---|
專案主頁 | http://github.com/zq2599/blo... | 該專案在GitHub上的主頁 |
git倉庫地址(https) | http://github.com/zq2599/blo... | 該專案原始碼的倉庫地址,https協議 |
git倉庫地址(ssh) | [email protected] :zq2599/blog_demos.git | 該專案原始碼的倉庫地址,ssh協議 |
- 這個git專案中有多個資料夾,《DL4J實戰》系列的原始碼在<font color="blue">dl4j-tutorials</font>資料夾下,如下圖紅框所示:
- <font color="blue">dl4j-tutorials</font>資料夾下有多個子工程,本次實戰程式碼在<font color="blue">dl4j-tutorials</font>目錄下,如下圖紅框:
編碼
- 在<font color="blue">dl4j-tutorials</font>工程下新建子工程<font color="red">classifier-iris</font>,其pom.xml如下:
<?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <parent> <artifactId>dlfj-tutorials</artifactId> <groupId>com.bolingcavalry</groupId> <version>1.0-SNAPSHOT</version> </parent> <modelVersion>4.0.0</modelVersion> <artifactId>classifier-iris</artifactId> <properties> <maven.compiler.source>8</maven.compiler.source> <maven.compiler.target>8</maven.compiler.target> </properties> <dependencies> <dependency> <groupId>com.bolingcavalry</groupId> <artifactId>commons</artifactId> <version>${project.version}</version> </dependency> <dependency> <groupId>org.projectlombok</groupId> <artifactId>lombok</artifactId> </dependency> <dependency> <groupId>org.nd4j</groupId> <artifactId>${nd4j.backend}</artifactId> </dependency> <dependency> <groupId>ch.qos.logback</groupId> <artifactId>logback-classic</artifactId> </dependency> </dependencies> </project>
- 上述pom.xml有一處需要注意的地方,就是<font color="blue">${nd4j.backend}</font>引數的值,該值在決定了後端線性代數計算是用CPU還是GPU,本篇為了簡化操作選擇了CPU(因為個人的顯示卡不同,程式碼裡無法統一),對應的配置就是<font color="red">nd4j-native</font>;
- 原始碼全部在Iris.java檔案中,並且程式碼中已新增詳細註釋,就不再贅述了:
package com.bolingcavalry.classifier; import com.bolingcavalry.commons.utils.DownloaderUtility; import lombok.extern.slf4j.Slf4j; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.SplitTestAndTrain; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; /** * @author will ([email protected]) * @version 1.0 * @description: 鳶尾花訓練 * @date 2021/6/13 17:30 */ @SuppressWarnings("DuplicatedCode") @Slf4j public class Iris { public static void main(String[] args) throws Exception { //第一階段:準備 // 跳過的行數,因為可能是表頭 int numLinesToSkip = 0; // 分隔符 char delimiter = ','; // CSV讀取工具 RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter); // 下載並解壓後,得到檔案的位置 String dataPathLocal = DownloaderUtility.IRISDATA.Download(); log.info("鳶尾花資料已下載並解壓至 : {}", dataPathLocal); // 讀取下載後的檔案 recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt"))); // 每一行的內容大概是這樣的:5.1,3.5,1.4,0.2,0 // 一共五個欄位,從零開始算的話,標籤在第四個欄位 int labelIndex = 4; // 鳶尾花一共分為三類 int numClasses = 3; // 一共150個樣本 int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets) // 載入到資料集迭代器中 DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); DataSet allData = iterator.next(); // 洗牌(打亂順序) allData.shuffle(); // 設定比例,150個樣本中,百分之六十五用於訓練 SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training // 訓練用的資料集 DataSet trainingData = testAndTrain.getTrain(); // 驗證用的資料集 DataSet testData = testAndTrain.getTest(); // 指定歸一化器:獨立地將每個特徵值(和可選的標籤值)歸一化為0平均值和1的標準差。 DataNormalization normalizer = new NormalizerStandardize(); // 先擬合 normalizer.fit(trainingData); // 對訓練集做歸一化 normalizer.transform(trainingData); // 對測試集做歸一化 normalizer.transform(testData); // 每個鳶尾花有四個特徵 final int numInputs = 4; // 共有三種鳶尾花 int outputNum = 3; // 隨機數種子 long seed = 6; //第二階段:訓練 log.info("開始配置..."); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .activation(Activation.TANH) // 啟用函式選用標準的tanh(雙曲正切) .weightInit(WeightInit.XAVIER) // 權重初始化選用XAVIER:均值 0, 方差為 2.0/(fanIn + fanOut)的高斯分佈 .updater(new Sgd(0.1)) // 更新器,設定SGD學習速率排程器 .l2(1e-4) // L2正則化配置 .list() // 配置多層網路 .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3) // 隱藏層 .build()) .layer(new DenseLayer.Builder().nIn(3).nOut(3) // 隱藏層 .build()) .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // 損失函式:負對數似然 .activation(Activation.SOFTMAX) // 輸出層指定啟用函式為:SOFTMAX .nIn(3).nOut(outputNum).build()) .build(); // 模型配置 MultiLayerNetwork model = new MultiLayerNetwork(conf); // 初始化 model.init(); // 每一百次迭代列印一次分數(損失函式的值) model.setListeners(new ScoreIterationListener(100)); long startTime = System.currentTimeMillis(); log.info("開始訓練"); // 訓練 for(int i=0; i<1000; i++ ) { model.fit(trainingData); } log.info("訓練完成,耗時[{}]ms", System.currentTimeMillis()-startTime); // 第三階段:評估 // 在測試集上評估模型 Evaluation eval = new Evaluation(numClasses); INDArray output = model.output(testData.getFeatures()); eval.eval(testData.getLabels(), output); log.info("評估結果如下\n" + eval.stats()); } }
- 編碼完成後,執行main方法,可見順利完成訓練並輸出了評估結果,還有混淆矩陣用於輔助分析:
- 至此,咱們的第一個實戰就完成了,通過經典例項體驗的DL4J訓練和評估的常規步驟,對重要API也有了初步認識,接下來會繼續實戰,接觸到更多的經典例項;
你不孤單,欣宸原創一路相伴
「其他文章」
- 為了生成唯一id,React18專門引入了新Hook:useId
- Java SPI機制從原理到實戰
- Rust 稽核團隊“一夜之間”集體辭職:開源社群治理話題再被熱議
- PHP 基金會,是個好事 (PHP Foundation)
- 以 Vuex 為引,一窺狀態管理全貌
- 0.99M,150FPS,移動端超輕量目標檢測演算法PP-PicoDet來了!
- Shell 指令碼避坑指南(一)
- 一文讀懂層次聚類(Python程式碼)
- KVO原理分析
- Markdown語法基礎
- 監聽鍵盤事件
- 升級到Java 17沒這麼簡單
- 正則表示式例項蒐集,通過例項來學習正則表示式。
- mqtt訊息推送(vue前端篇)
- 詳解電子表格中的json資料:序列化與反序列化
- 第八期:前端九條啟發分享
- MySQL事務的多版本併發控制(MVCC)實現原理
- 無處不在的 Kubernetes,難用的問題解決了嗎?
- 技術乾貨 | Flutter線上程式設計實踐總結
- Spark 架構設計與原理思想