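dl4j doc-03: Deeplearning4j official template, local test and verification (getting-started MNIST hands-on test)
Quickstart template
Now that you have learned how to run the different examples, we provide a template that contains a basic MNIST trainer with simple evaluation code.
The quickstart template is available at https://github.com/eclipse/deeplearning4j-examples/tree/master/mvn-project-template.
You can also download the zip archive and then import it.
The project is fairly simple.
Backup
Alternatively, download it directly from https://github.com/houbb/dl4j-template.
Full test
Complete Maven pom.xml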
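<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>com.github.houbb</groupId>
    <artifactId>dl4j-template</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
        <logback.version>1.2.3</logback.version>
        <java.version>11</java.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.5.1</version>
                <configuration>
                    <source>${java.version}</source>
                    <target>${java.version}</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>${maven-shade-plugin.version}</version>
                <configuration>
                    <shadedArtifactAttached>true</shadedArtifactAttached>
                    <shadedClassifierName>bin</shadedClassifierName>
                    <createDependencyReducedPom>true</createDependencyReducedPom>
                    <filters>
                        <filter>
                            <artifact>*:*</artifact>
                            <excludes>
                                <exclude>org/datanucleus/**</exclude>
                                <exclude>META-INF/*.SF</exclude>
                                <exclude>META-INF/*.DSA</exclude>
                                <exclude>META-INF/*.RSA</exclude>
                            </excludes>
                        </filter>
                    </filters>
                </configuration>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                                    <resource>reference.conf</resource>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>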
Complete Java code
package com.github.houbb.dl4j.template;
/*******************************************************************************
*
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy)
* [LeCun et al., 1998. Gradient based learning applied to document recognition]
* Some minor changes are made to the architecture like using ReLU and identity activation instead of
* sigmoid/tanh, max pooling instead of avg pooling and softmax output layer.
*
* This example will download 15 Mb of data on the first run.
*
* @author hanlon
* @author agibsonccc
* @author fvaleri
* @author dariuszzbyrad
*/
public class LeNetMNISTReLu {
    private static final Logger LOGGER = LoggerFactory.getLogger(LeNetMNISTReLu.class);
    // private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";
    // Download the archive manually and extract it to this path.
    private static final String BASE_PATH = "C:\\Users\\dh\\.deeplearning4j\\data\\MNIST";

    public static void main(String[] args) throws Exception {
        int height = 28; // height of the picture in px
        int width = 28; // width of the picture in px
        int channels = 1; // single channel for grayscale images
        int outputNum = 10; // 10 digits classification (0-9)
        int batchSize = 54; // number of samples that will be propagated through the network in each iteration
        int nEpochs = 1; // number of training epochs
        int seed = 1234; // number used to initialize a pseudorandom number generator.
        Random randNumGen = new Random(seed);

        LOGGER.info("Data vectorization...");
        // vectorization of train data
        File trainData = new File(BASE_PATH + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        // iterator over the MNIST training data
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
        // pixel values from 0-255 to 0-1 (min-max scaling)
        DataNormalization imageScaler = new ImagePreProcessingScaler();
        imageScaler.fit(trainIter);
        trainIter.setPreProcessor(imageScaler);

        // vectorization of test data
        File testData = new File(BASE_PATH + "/mnist_png/testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
        testIter.setPreProcessor(imageScaler); // same normalization for better results

        LOGGER.info("Network configuration and training...");
        // reduce the learning rate as the number of training epochs increases
        // iteration #, learning rate
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        learningRateSchedule.put(0, 0.06);
        learningRateSchedule.put(200, 0.05);
        learningRateSchedule.put(600, 0.028);
        learningRateSchedule.put(800, 0.0060);
        learningRateSchedule.put(1000, 0.001);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .l2(0.0005) // ridge regression value
            .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(new ConvolutionLayer.Builder(5, 5)
                .nIn(channels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1) // nIn need not specified in later layers
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(10));
        LOGGER.info("Total num of params: {}", net.numParams());
        // evaluation while training (the score should go down)
        for (int i = 0; i
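The layer sizes in the configuration above can be checked by hand. The sketch below is plain Java, not DL4J API: the class LeNetShapeCheck and its methods are hypothetical names, and it assumes unpadded convolution and pooling, so each output side is (in - kernel) / stride + 1. Under those assumptions the 28x28 input shrinks to a 4x4 feature map with 50 channels before the dense layer, and the summed parameter count should agree with the value logged by net.numParams().

```java
/** Hand-checks the layer sizes of the LeNet configuration (hypothetical helper, not DL4J API). */
final class LeNetShapeCheck {

    /** Output side of a convolution or pooling layer, assuming no padding. */
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Convolution parameters: one k x k filter per (input, output) channel pair, plus one bias per output channel. */
    static long convParams(int kernel, int inChannels, int outChannels) {
        return (long) kernel * kernel * inChannels * outChannels + outChannels;
    }

    /** Fully connected parameters: weights plus one bias per output unit. */
    static long denseParams(int nIn, int nOut) {
        return (long) nIn * nOut + nOut;
    }

    /** Feature map side after conv(5x5, stride 1) and pool(2x2, stride 2), applied twice. */
    static int finalFeatureMapSize(int inputSize) {
        int side = outSize(inputSize, 5, 1); // first convolution
        side = outSize(side, 2, 2);          // first max pooling
        side = outSize(side, 5, 1);          // second convolution
        return outSize(side, 2, 2);          // second max pooling
    }

    /** Sum of parameters over both convolutions, the dense layer, and the output layer. */
    static long totalParams(int inputSize, int channels, int outputNum) {
        int side = finalFeatureMapSize(inputSize);
        int flattened = side * side * 50; // 50 channels leave the second convolution
        return convParams(5, channels, 20)
            + convParams(5, 20, 50)
            + denseParams(flattened, 500)
            + denseParams(500, outputNum);
    }

    public static void main(String[] args) {
        System.out.println("feature map side: " + finalFeatureMapSize(28)); // prints 4
        System.out.println("total params: " + totalParams(28, 1, 10));      // prints 431080
    }
}
```

Pooling layers contribute no parameters, which is why only four layers appear in the sum.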
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-datasets</artifactId>
    <version>1.0.0-M1</version>
</dependency>
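- Set the Java compiler version: if you use a build tool such as Maven or Gradle, you can try specifying the compiler version in the project configuration. For example, for a Maven project, add the following configuration to pom.xml: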
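<properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
</properties>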
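This tells Maven to compile with the Java 8 compiler, so that it matches the Java version DeepLearning4J was compiled with.
Choose whichever of these approaches fits your situation.
Here I chose to upgrade the JDK to JDK 11.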
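To confirm which JDK actually runs the project, a small check like the following may help. It is only a sketch: JavaVersionCheck is a hypothetical helper, and the minimum of 11 is taken from the java.version value in the pom above.

```java
/** Checks the running JVM against a minimum Java version (hypothetical helper). */
final class JavaVersionCheck {

    /** Parses a java.specification.version value, e.g. "1.8" gives 8 and "11" gives 11. */
    static int majorVersion(String specVersion) {
        String v = specVersion.trim();
        if (v.startsWith("1.")) {
            v = v.substring(2); // legacy scheme used up to Java 8
        }
        int dot = v.indexOf('.');
        if (dot >= 0) {
            v = v.substring(0, dot);
        }
        return Integer.parseInt(v);
    }

    /** True if the given specification version is at least the required major version. */
    static boolean meetsMinimum(String specVersion, int required) {
        return majorVersion(specVersion) >= required;
    }

    public static void main(String[] args) {
        String spec = System.getProperty("java.specification.version");
        String verdict = meetsMinimum(spec, 11) ? "is sufficient" : "is too old, JDK 11 or later is needed";
        System.out.println("Java " + majorVersion(spec) + " " + verdict);
    }
}
```

Reading the system property instead of calling Runtime.version() keeps the check compilable on Java 8, which is the case it is meant to detect.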
Error 2: file does not exist
c.g.h.d.t.LeNetMNIST - Load data....
o.n.c.r.Downloader - Error extracting train-images-idx3-ubyte.gz files from file C:\Users\dh\.deeplearning4j\train-images-idx3-ubyte.gz - retrying...
java.io.EOFException: Unexpected end of ZLIB input stream
at java.base/java.util.zip.InflaterInputStream.fill(InflaterInputStream.java:245)
at java.base/java.util.zip.InflaterInputStream.read(InflaterInputStream.java:159)
at java.base/java.util.zip.GZIPInputStream.read(GZIPInputStream.java:118)
at java.base/java.io.FilterInputStream.read(FilterInputStream.java:107)
at org.apache.commons.io.IOUtils.copyLarge(IOUtils.java:1127)
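The EOFException in this trace means the .gz file ends before the compressed stream does, which is typical of an interrupted download. A minimal sketch for testing a downloaded archive before DL4J tries to extract it follows; GzipIntegrityCheck is a hypothetical helper built only on the JDK, and the file path is supplied by the caller.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.zip.GZIPInputStream;

/** Detects truncated or corrupt gzip files (hypothetical helper, JDK only). */
final class GzipIntegrityCheck {

    /** Returns true if the stream decompresses to its end without an I/O error. */
    static boolean isCompleteGzip(InputStream raw) {
        byte[] buffer = new byte[8192];
        try (GZIPInputStream in = new GZIPInputStream(raw)) {
            while (in.read(buffer) != -1) {
                // discard the data, only the successful read to the end matters
            }
            return true;
        } catch (IOException e) {
            // EOFException (truncated file) and ZipException (corrupt data) both end up here
            return false;
        }
    }

    /** Same check for a file on disk; an unreadable file counts as incomplete. */
    static boolean isCompleteGzip(Path file) {
        try (InputStream raw = Files.newInputStream(file)) {
            return isCompleteGzip(raw);
        } catch (IOException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        if (args.length != 1) {
            System.out.println("usage: java GzipIntegrityCheck.java <file.gz>");
            return;
        }
        Path file = Paths.get(args[0]);
        System.out.println(file + (isCompleteGzip(file) ? " is complete" : " is truncated or corrupt, delete it and download again"));
    }
}
```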
Solution 1
The default path can be printed with:
System.out.println(DL4JResources.getDirectory(ResourceType.DATASET, "MNIST").getAbsolutePath());
Result:
C:\Users\dh\.deeplearning4j\data\MNIST
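Put the downloaded MNIST archive in this folder and extract it:
Directory of C:\Users\dh\.deeplearning4j\data\MNIST\mnist_png
2015/12/11 08:55 <DIR> .
2024/03/27 16:06 <DIR> ..
2015/12/11 08:55 <DIR> testing
2015/12/11 08:55 <DIR> training
training contains the training dataset.
testing contains the test dataset.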
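Before starting a training run, it can be worth confirming that the extracted folders are where the code expects them. The sketch below is an assumption-laden helper, not part of DL4J: MnistLayoutCheck is a hypothetical name, the default base path is assumed to be .deeplearning4j/data/MNIST under the user's home directory (matching the path printed above), and each split is assumed to contain one subfolder per digit, 0 through 9, since the trainer takes the label from the parent directory name.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Verifies the extracted mnist_png folder layout (hypothetical helper, JDK only). */
final class MnistLayoutCheck {

    /** Assumed default dataset folder: <user.home>/.deeplearning4j/data/MNIST. */
    static Path defaultBasePath() {
        return Paths.get(System.getProperty("user.home"), ".deeplearning4j", "data", "MNIST");
    }

    /** True if base/mnist_png/{training,testing}/{0..9} all exist as directories. */
    static boolean hasExpectedLayout(Path base) {
        for (String split : new String[] {"training", "testing"}) {
            for (int digit = 0; digit <= 9; digit++) {
                Path dir = base.resolve("mnist_png").resolve(split).resolve(Integer.toString(digit));
                if (!Files.isDirectory(dir)) {
                    return false;
                }
            }
        }
        return true;
    }

    public static void main(String[] args) {
        Path base = args.length > 0 ? Paths.get(args[0]) : defaultBasePath();
        System.out.println(base + (hasExpectedLayout(base) ? " looks complete" : " is missing training/testing digit folders"));
    }
}
```

Building the path from user.home rather than hard-coding C:\Users\dh keeps the check usable on other machines.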
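Solution 2 (not verified)
On Windows, you can obtain the MNIST dataset as follows:
Training dataset: download the training images and labels from the official website and put them in a folder inside your project directory.
Test dataset: likewise, download the test images and labels from the same website and put them in another folder.
In your code, you can then specify the paths to these files. Assuming the training and test datasets are in a folder named data under the project directory, you can load them like this:
String trainDataPath = "data/train-images.idx3-ubyte"; // path to the training image file
String trainLabelsPath = "data/train-labels.idx1-ubyte"; // path to the training label file
String testDataPath = "data/t10k-images.idx3-ubyte"; // path to the test image file
String testLabelsPath = "data/t10k-labels.idx1-ubyte"; // path to the test label file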
/*
 Create an iterator using the batch size for one iteration
 */
log.info("Load data....");
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345,
    new MnistDataSetIterator.Builder()
        .useNormalizedData(true)
        .trainFilePath(trainDataPath)
        .labelsFilePath(trainLabelsPath)
        .build());
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345,
    new MnistDataSetIterator.Builder()
        .useNormalizedData(true)
        .testFilePath(testDataPath)
        .labelsFilePath(testLabelsPath)
        .build());
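In this example, trainDataPath and trainLabelsPath are the paths to the training image and label files, and testDataPath and testLabelsPath are the paths to the test image and label files. Replace them with the paths where you actually store the dataset. Because this snippet is unverified, the Builder methods it calls (trainFilePath, testFilePath, labelsFilePath) may not exist in your DL4J version; check the MnistDataSetIterator API before relying on it.
The MNIST dataset is publicly available and can be downloaded without an account or registration. You can get it from the official website: http://yann.lecun.com/exdb/mnist/
The site provides direct links to each part of the dataset: training images, training labels, test images, and test labels.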
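The files from the official site use the IDX format, which starts with big-endian 32-bit integers: a magic number (2051 for image files, 2049 for label files), the item count, and for images the row and column counts. A small sketch for checking that a decompressed download is what it claims to be follows; IdxHeader is a hypothetical helper that uses only the JDK, and the file path is supplied by the caller.

```java
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;

/** Reads the header of an uncompressed MNIST IDX file (hypothetical helper, JDK only). */
final class IdxHeader {
    static final int IMAGE_MAGIC = 2051; // 0x00000803
    static final int LABEL_MAGIC = 2049; // 0x00000801

    final int magic;
    final int count;
    final int rows; // 0 for label files
    final int cols; // 0 for label files

    private IdxHeader(int magic, int count, int rows, int cols) {
        this.magic = magic;
        this.count = count;
        this.rows = rows;
        this.cols = cols;
    }

    /** Parses the header; DataInputStream reads big-endian integers, as IDX requires. */
    static IdxHeader read(InputStream in) throws IOException {
        DataInputStream data = new DataInputStream(in);
        int magic = data.readInt();
        int count = data.readInt();
        if (magic == IMAGE_MAGIC) {
            return new IdxHeader(magic, count, data.readInt(), data.readInt());
        }
        if (magic == LABEL_MAGIC) {
            return new IdxHeader(magic, count, 0, 0);
        }
        throw new IOException("Not an MNIST IDX file, magic number: " + magic);
    }

    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            System.out.println("usage: java IdxHeader.java <uncompressed idx file>");
            return;
        }
        try (InputStream in = Files.newInputStream(Paths.get(args[0]))) {
            IdxHeader h = read(in);
            System.out.println("items=" + h.count + " rows=" + h.rows + " cols=" + h.cols);
        }
    }
}
```

For the training image file, the header should report 60000 items of 28 by 28 pixels; a different magic number usually means the file is still gzip-compressed or is not an MNIST file.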
References
DL4J cannot download the MNIST dataset: fixing "Server returned HTTP response code: 403 for URL"
MNIST data download: http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
https://deeplearning4j.konduit.ai/multi-project/tutorials/quickstart