十次方人工智能笔记二:人工智能
完成十次方文章智能分类

人工智能与机器学习

什么是人工智能

​ 人工智能(Artificial Intelligence),英文缩写为AI。它是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门新的技术科学。

​ 人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器,该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。人工智能从诞生以来,理论和技术日益成熟,应用领域也不断扩大,可以设想,未来人工智能带来的科技产品,将会是人类智慧的“容器”。人工智能可以对人的意识、思维的信息过程的模拟。人工智能不是人的智能,但能像人那样思考、也可能超过人的智能。

什么是机器学习

​ 机器学习,它正是这样一门学科,它致力于研究如何通过计算(CPU和GPU计算)的手段,利用经验来改善(计算机)系统自身的性能。
​ 它是人工智能的核心,是使计算机具有智能的根本途径,应用遍及人工智能各领域。
​ 数据 + 机器学习算法 = 机器学习模型

​ 有了学习算法就可以把经验数据提供给它,它就能基于这些数据产生模型。

AI、机器学习和深度学习的关系

机器学习是人工智能的一个分支,深度学习是实现机器学习的一种技术。

智能分类

通过机器学习,当用户录入一篇文章或从互联网爬取一篇文章时可以预测其归属的类型。

智能分类流程图:

构建IK分词资料库

  1. tensquare_common中引入IK分词器依赖

    1
    2
    3
    4
    5
    <dependency>
    <groupId>com.janeluo</groupId>
    <artifactId>ikanalyzer</artifactId>
    <version>2012_u6</version>
    </dependency>
  2. tensquare_common中新增分词工具类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    package util;

    import org.wltea.analyzer.core.IKSegmenter;
    import org.wltea.analyzer.core.Lexeme;

    import java.io.IOException;
    import java.io.StringReader;

    public class IkUtil {


    /**
    * 对制定文本进行中文分词
    *
    * @param content 源文本
    * @param splitChar 分词后结果的间隔符
    * @return 分词后的文本
    */
    public static String split(String content, String splitChar) throws IOException {
    StringReader reader = new StringReader(content);
    IKSegmenter ikSegmenter = new IKSegmenter(reader, true);
    Lexeme lexeme;
    StringBuilder result = new StringBuilder();
    while ((lexeme = ikSegmenter.next()) != null) {
    result.append(lexeme.getLexemeText()).append(splitChar);
    }
    return result.toString();
    }

    }
  3. 新增消除HTML代码中的标签工具类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    package util;

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    /**
    * html标签处理工具类
    */
    public class HTMLUtil {

    public static String delHTMLTag(String htmlStr){
    String regEx_script="<script[^>]*?>[\\s\\S]*?<\\/script>"; //定义script的正则表达式
    String regEx_style="<style[^>]*?>[\\s\\S]*?<\\/style>"; //定义style的正则表达式
    String regEx_html="<[^>]+>"; //定义HTML标签的正则表达式

    Pattern p_script=Pattern.compile(regEx_script,Pattern.CASE_INSENSITIVE);
    Matcher m_script=p_script.matcher(htmlStr);
    htmlStr=m_script.replaceAll(""); //过滤script标签

    Pattern p_style=Pattern.compile(regEx_style,Pattern.CASE_INSENSITIVE);
    Matcher m_style=p_style.matcher(htmlStr);
    htmlStr=m_style.replaceAll(""); //过滤style标签

    Pattern p_html=Pattern.compile(regEx_html,Pattern.CASE_INSENSITIVE);
    Matcher m_html=p_html.matcher(htmlStr);
    htmlStr=m_html.replaceAll(""); //过滤html标签

    return htmlStr.trim(); //返回文本字符串
    }
    }
  4. 修改tensquare_article_crawler,新增将爬取的网页进行分词并存储到文本文件中的持久化类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    package com.tensquare.crawler.pipeline;

    import lombok.extern.slf4j.Slf4j;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.stereotype.Component;
    import us.codecraft.webmagic.ResultItems;
    import us.codecraft.webmagic.Task;
    import us.codecraft.webmagic.pipeline.Pipeline;
    import util.HTMLUtil;
    import util.IkUtil;

    import java.io.File;
    import java.io.IOException;
    import java.io.PrintWriter;
    import java.util.UUID;

    @Slf4j
    @Component
    public class ArticleTxtPipeline implements Pipeline {

    @Value("${ai.dataPath}")
    private String dataPath;

    private String channelId;

    public void setChannelId(String channelId) {
    this.channelId = channelId;
    }

    @Override
    public void process(ResultItems resultItems, Task task) {
    String title = resultItems.get("title");
    String content = HTMLUtil.delHTMLTag(resultItems.get("content"));
    log.info("文章名 [{}]", title);
    // 输出到文本文件
    try {
    // 输出流
    PrintWriter printWriter = new PrintWriter(new File(dataPath + "/" + channelId + "/" + UUID.randomUUID() + ".txt"));
    printWriter.print(IkUtil.split(content, " "));
    } catch (IOException e) {
    log.error("持久化网页内如至文件失败, e = ", e);
    }
    }
    }
  5. application.yml中新增配置

    1
    2
    ai:
    dataPath: E:/article
  6. 修改爬去任务类,添加新增的持久化类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    package com.tensquare.crawler.task;

    import com.tensquare.crawler.pipeline.ArticlePipeline;
    import com.tensquare.crawler.pipeline.ArticleTxtPipeline;
    import com.tensquare.crawler.processor.ArticleProcessor;
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.scheduling.annotation.Scheduled;
    import org.springframework.stereotype.Component;
    import us.codecraft.webmagic.Spider;
    import us.codecraft.webmagic.scheduler.RedisScheduler;

    @Slf4j
    @Component
    public class ArticleCrawlerTask {

    @Autowired
    private ArticlePipeline articlePipeline;

    @Autowired
    private ArticleTxtPipeline articleTxtPipeline;

    @Autowired
    private RedisScheduler redisScheduler;

    @Autowired
    private ArticleProcessor articleProcessor;

    @Scheduled(cron = "0 0 0 * * *")
    public void aiTask() {
    log.info("开始爬取AI文章");
    articlePipeline.setChannelId("ai");
    articleTxtPipeline.setChannelId("ai");
    Spider spider = Spider.create(articleProcessor);
    spider.addUrl("https://blog.csdn.net/nav/ai/")
    .addPipeline(articlePipeline)
    .addPipeline(articleTxtPipeline)
    .setScheduler(redisScheduler)
    .start();
    }

    @Scheduled(cron = "0 0 1 * * *")
    public void blockChainTask() {
    log.info("开始爬取区块链文章");
    articlePipeline.setChannelId("blockchain");
    articleTxtPipeline.setChannelId("blockchain");
    Spider spider = Spider.create(articleProcessor);
    spider.addUrl("https://blog.csdn.net/nav/blockchain/")
    .addPipeline(articlePipeline)
    .addPipeline(articleTxtPipeline)
    .setScheduler(redisScheduler)
    .run();
    }

    @Scheduled(cron = "0 0 2 * * *")
    public void dbTask() {
    log.info("开始爬取区数据库文章");
    articlePipeline.setChannelId("db");
    articleTxtPipeline.setChannelId("db");
    Spider spider = Spider.create(articleProcessor);
    spider.addUrl("https://blog.csdn.net/nav/db/")
    .addPipeline(articlePipeline)
    .addPipeline(articleTxtPipeline)
    .setScheduler(redisScheduler)
    .run();
    }

    @Scheduled(cron = "0 0 3 * * *")
    public void langTask() {
    log.info("开始爬取编程语言文章");
    articlePipeline.setChannelId("lang");
    articleTxtPipeline.setChannelId("lang");
    Spider spider = Spider.create(articleProcessor);
    spider.addUrl("https://blog.csdn.net/nav/lang/")
    .addPipeline(articlePipeline)
    .addPipeline(articleTxtPipeline)
    .setScheduler(redisScheduler)
    .run();
    }

    @Scheduled(cron = "0 0 4 * * *")
    public void newsTask() {
    log.info("开始爬取资讯文章");
    articlePipeline.setChannelId("news");
    articleTxtPipeline.setChannelId("news");
    Spider spider = Spider.create(articleProcessor);
    spider.addUrl("https://blog.csdn.net/nav/news/")
    .addPipeline(articlePipeline)
    .addPipeline(articleTxtPipeline)
    .setScheduler(redisScheduler)
    .run();
    }

    @Scheduled(cron = "0 0 5 * * *")
    public void webTask() {
    log.info("开始爬取前端文章");
    articlePipeline.setChannelId("web");
    articleTxtPipeline.setChannelId("web");
    Spider spider = Spider.create(articleProcessor);
    spider.addUrl("https://blog.csdn.net/nav/web/")
    .addPipeline(articlePipeline)
    .addPipeline(articleTxtPipeline)
    .setScheduler(redisScheduler)
    .run();
    }


    }
  7. 生成的分词资料库展示

合并分词库

  1. 创建Module(省略)

  2. pom.xml

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    <?xml version="1.0" encoding="UTF-8"?>
    <project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
    <artifactId>tensquare_parent</artifactId>
    <groupId>com.tensquare</groupId>
    <version>1.0.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>tensquare_ai</artifactId>

    <dependencies>
    <dependency>
    <groupId>com.tensquare</groupId>
    <artifactId>tensquare_common</artifactId>
    <version>${tensquare.version}</version>
    </dependency>
    </dependencies>
    </project>
  3. application.yml

    1
    2
    3
    4
    5
    ai:
    # 分词库汇总文件
    wordLib: E:/article.txt
    # 分词库目录
    dataPath: E:/article
  4. 启动类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    package com.tensquare.ai;

    import org.springframework.boot.SpringApplication;
    import org.springframework.boot.autoconfigure.SpringBootApplication;
    import org.springframework.scheduling.annotation.EnableScheduling;

    @EnableScheduling
    @SpringBootApplication
    public class AIApplication {
    public static void main(String[] args) {
    SpringApplication.run(AIApplication.class, args);
    }
    }
  5. tensquare_common中新增工具类,用于文件内容合并

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    package util;

    import java.io.*;
    import java.util.ArrayList;
    import java.util.List;

    /**
    * 文件工具类
    */
    public class FileUtil {


    /**
    * 将多个文本文件合并为一个文本文件
    * @param outFileName
    * @param inFileNames
    * @throws IOException
    */
    public static void merge(String outFileName ,List<String> inFileNames) throws IOException {

    FileWriter writer = new FileWriter(outFileName, true);

    for(String inFileName :inFileNames ){
    try {
    String txt= readToString(inFileName);
    writer.write(txt);
    System.out.println(txt);
    }catch (Exception e){
    }
    }
    writer.close();
    }


    /**
    * 查找某目录下的所有文件名称
    * @param path
    * @return
    */
    public static List<String> getFiles(String path) {
    List<String> files = new ArrayList<String>();
    File file = new File(path);
    File[] tempList = file.listFiles();

    for (int i = 0; i < tempList.length; i++) {
    if (tempList[i].isFile()) {//如果是文件
    files.add(tempList[i].toString());
    }
    if (tempList[i].isDirectory()) {//如果是文件夹
    files.addAll( getFiles(tempList[i].toString()) );
    }
    }
    return files;
    }


    /**
    * 读取文本文件内容到字符串
    * @param fileName
    * @return
    */
    public static String readToString(String fileName) {
    String encoding = "UTF-8";
    File file = new File(fileName);
    Long filelength = file.length();
    byte[] filecontent = new byte[filelength.intValue()];
    try {
    FileInputStream in = new FileInputStream(file);
    in.read(filecontent);
    in.close();
    } catch (FileNotFoundException e) {
    e.printStackTrace();
    } catch (IOException e) {
    e.printStackTrace();
    }
    try {
    return new String(filecontent, encoding);
    } catch (UnsupportedEncodingException e) {
    System.err.println("The OS does not support " + encoding);
    e.printStackTrace();
    return null;
    }
    }

    }
  6. Service

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    package com.tensquare.ai.service;

    import lombok.extern.slf4j.Slf4j;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.stereotype.Service;
    import util.FileUtil;

    import java.io.IOException;
    import java.util.List;

    @Slf4j
    @Service
    public class Word2VecService {

    @Value("${ai.wordLib}")
    private String wordLib;

    @Value("${ai.dataPath}")
    private String dataPath;

    /**
    * 文件合并
    */
    public void mergeWord() {
    List<String> fileNames = FileUtil.getFiles(dataPath);
    try {
    FileUtil.merge(wordLib, fileNames);
    } catch (IOException e) {
    log.error("合并分词资料库失败,e = ", e);
    }
    }
    }
  7. 任务类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    package com.tensquare.ai.task;

    import com.tensquare.ai.service.Word2VecService;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.scheduling.annotation.Scheduled;
    import org.springframework.stereotype.Component;

    @Component
    public class TrainTask {

    @Autowired
    private Word2VecService word2VecService;

    @Scheduled(cron = "0 30 16 * * *")
    public void trainModel() {
    // 合并资料库
    word2VecService.mergeWord();
    }

    }

构建词向量模型

  1. tensquare_ai中新增依赖

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta</version>
    </dependency>
    <dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-nlp</artifactId>
    <version>1.0.0-beta</version>
    </dependency>
    <dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native-platform</artifactId>
    <version>1.0.0-beta</version>
    </dependency>
  2. application.yml新增配置

    1
    2
    3
    ai:
    # 词向量模型
    vecModel: E:/article.vecmodel
  3. 修改Word2VecService,新增构建词向量模型的方法

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    @Value("${ai.vecModel}")
    private String vecModel;

    /**
    * 构建词向量模型
    */
    public void build() {
    try {
    // 加载分词库
    LineSentenceIterator sentenceIterator = new LineSentenceIterator(new File(wordLib));

    Word2Vec vec = new Word2Vec.Builder()
    .minWordFrequency(5)// 一个词在语料中最少出现的次数,若低于该值将不予学习
    .iterations(1)// 处理数据时允许系数更新的次数
    .layerSize(100)// 指定词向量中特征数量
    .seed(42)// 随机数发生器
    .windowSize(5)// 当前词与预测词在句中最大距离
    .iterate(sentenceIterator)// 指定数据集
    .build();
    // 构建
    vec.fit();
    // 删除原有模型
    Paths.get(vecModel).toFile().delete();
    // 保存新生成的模型
    WordVectorSerializer.writeWordVectors(vec, vecModel);
    } catch (Exception e) {
    log.error("词向量模型构建失败, e = ", e);
    }

    }
  1. 合并后的词向量模型

卷积神经网络模型

神经网络

卷积

构建卷积神经网络模型

  1. application.yml新增配置

    1
    2
    # 卷积神经网络模型
    cnnModel: E:/article.cnnModel
  2. 添加工具类

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    package com.tensquare.ai.util;

    import org.deeplearning4j.iterator.CnnSentenceDataSetIterator;
    import org.deeplearning4j.iterator.LabeledSentenceProvider;
    import org.deeplearning4j.iterator.provider.FileLabeledSentenceProvider;
    import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
    import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
    import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
    import org.deeplearning4j.nn.conf.ConvolutionMode;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.graph.MergeVertex;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

    import java.io.File;
    import java.io.IOException;
    import java.util.*;

    /**
    * CNN工具类
    */
    public class CnnUtil {

    /**
    * 创建计算图(卷积神经网络)
    * @param cnnLayerFeatureMaps 卷积核的数量(=词向量维度)
    * @return 计算图
    */
    public static ComputationGraph createComputationGraph(int cnnLayerFeatureMaps){
    //训练模型
    int vectorSize = 300; //向量大小
    //int cnnLayerFeatureMaps = 100; ////每种大小卷积层的卷积核的数量=词向量维度
    ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
    .convolutionMode(ConvolutionMode.Same)// 设置卷积模式
    .graphBuilder()
    .addInputs("input")
    .addLayer("cnn1", new ConvolutionLayer.Builder()//卷积层
    .kernelSize(3,vectorSize)//卷积区域尺寸
    .stride(1,vectorSize)//卷积平移步幅
    .nIn(1)
    .nOut(cnnLayerFeatureMaps)
    .build(), "input")
    .addLayer("cnn2", new ConvolutionLayer.Builder()
    .kernelSize(4,vectorSize)
    .stride(1,vectorSize)
    .nIn(1)
    .nOut(cnnLayerFeatureMaps)
    .build(), "input")
    .addLayer("cnn3", new ConvolutionLayer.Builder()
    .kernelSize(5,vectorSize)
    .stride(1,vectorSize)
    .nIn(1)
    .nOut(cnnLayerFeatureMaps)
    .build(), "input")
    .addVertex("merge", new MergeVertex(), "cnn1", "cnn2", "cnn3")//全连接层
    .addLayer("globalPool", new GlobalPoolingLayer.Builder()//池化层
    .build(), "merge")
    .addLayer("out", new OutputLayer.Builder()//输出层
    .nIn(3*cnnLayerFeatureMaps)
    .nOut(3)
    .build(), "globalPool")
    .setOutputs("out")
    .build();
    ComputationGraph net = new ComputationGraph(config);
    net.init();
    return net;
    }

    /**
    * 获取训练数据集
    * @param path 分词语料库根目录
    * @param childPaths 分词语料库子文件夹
    * @param vecModel 词向量模型
    * @return
    */
    public static DataSetIterator getDataSetIterator(String path, String[] childPaths, String vecModel ){
    //加载词向量模型
    WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(vecModel));
    //词标记分类比标签
    Map<String,List<File>> reviewFilesMap = new HashMap<>();

    for( String childPath: childPaths){
    reviewFilesMap.put(childPath, Arrays.asList(new File(path+"/"+ childPath ).listFiles()));
    }
    //标记跟踪
    LabeledSentenceProvider sentenceProvider = new FileLabeledSentenceProvider(reviewFilesMap, new Random(12345));
    return new CnnSentenceDataSetIterator.Builder()
    .sentenceProvider(sentenceProvider)
    .wordVectors(wordVectors)
    .minibatchSize(32)
    .maxSentenceLength(256)
    .useNormalizedWordVectors(false)
    .build();

    }


    /**
    * 预言
    * @param vecModel 词向量模型
    * @param cnnModel 卷积神经网络模型
    * @param dataPath 分词语料库根目录
    * @param childPaths 分词语料库文件夹
    * @param content 预言的文本内容
    * @return
    * @throws IOException
    */
    public static Map<String, Double> predictions(String vecModel,String cnnModel,String dataPath,String[] childPaths,String content) throws IOException {
    Map<String, Double> map = new HashMap<>();
    //模型应用
    ComputationGraph model = ModelSerializer.restoreComputationGraph(cnnModel);//通过cnn模型获取计算图对象
    //加载数据集
    DataSetIterator dataSet = CnnUtil.getDataSetIterator(dataPath,childPaths, vecModel);
    //通过句子获取概率矩阵对象
    INDArray featuresFirstNegative = ((CnnSentenceDataSetIterator) dataSet).loadSingleSentence(content);
    INDArray predictionsFirstNegative =model.outputSingle(featuresFirstNegative);
    List<String> labels = dataSet.getLabels();

    for (int i = 0; i < labels.size(); i++) {
    map.put(labels.get(i) + "", predictionsFirstNegative.getDouble(i));
    }
    return map;
    }



    }
  3. 创建智能分类模型

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    package com.tensquare.ai.service;

    import com.tensquare.ai.util.CnnUtil;
    import lombok.extern.slf4j.Slf4j;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.util.ModelSerializer;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.stereotype.Service;

    import java.io.IOException;
    import java.nio.file.Paths;

    /**
    * 智能分类模型
    */
    @Slf4j
    @Service
    public class CnnService {

    @Value("${ai.cnnModel}")
    private String cnnModel;

    @Value("${ai.dataPath}")
    private String dataPath;

    @Value("${ai.vecModel}")
    private String vecModel;

    /**
    * 构建卷积模型
    */
    public void build() {
    try {
    // 创建计算图对象
    ComputationGraph computationGraph = CnnUtil.createComputationGraph(10);
    // 加载词向量,训练数据集
    String[] childPaths = {"ai", "db", "web"};
    DataSetIterator dataSetIterator = CnnUtil.getDataSetIterator(dataPath, childPaths, vecModel);
    // 训练
    computationGraph.fit();
    // 删除之前生成卷积模型
    Paths.get(cnnModel).toFile().delete();
    // 保存
    ModelSerializer.writeModel(computationGraph, cnnModel, true);
    } catch (IOException e) {
    log.error("卷积模型生成失败, e = ", e);
    }

    }


    }
  4. TrainTask#trainModel中新增逻辑

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    @Autowired
    private CnnService cnnService;

    @Scheduled(cron = "0 30 15 * * *")
    public void trainModel() {
    // 合并资料库
    word2VecService.mergeWord();
    // 构建词向量模型
    word2VecService.build();
    // 构建卷积模型
    cnnService.build();
    }

实现智能分类

  1. 修改CnnService新增方法

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    /**
    * 返回map集合 分类与百分比
    * @param content
    * @return
    */
    public Map textClassify(String content) {
    log.info("待分类数据, content = [{}]", content);
    Map result = null;
    // 分词
    try {
    content = IkUtil.split(content, " ");
    String[] childPaths = {"ai", "db", "web"};
    // 获取预测结果
    result = CnnUtil.predictions(vecModel, cnnModel, dataPath, childPaths, content);
    } catch (IOException e) {
    log.error("智能分类失败, e = ", e);
    }
    return result;
    }
  2. 创建AiController

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    package com.tensquare.ai.controller;

    import com.tensquare.ai.service.CnnService;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.web.bind.annotation.PostMapping;
    import org.springframework.web.bind.annotation.RequestBody;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;

    import java.util.Map;

    @RestController
    @RequestMapping("ai")
    public class AiController {

    @Autowired
    private CnnService cnnService;

    @PostMapping("textclassify")
    public Map textClassify(@RequestBody Map<String, String> content) {
    return cnnService.textClassify(content.get("content"));
    }

    }
  3. 测试

    训练数据越多,智能分类越准确,但是如果你输入的词不存在与分词数据库就会报错。

文章作者: imxushuai
文章链接: https://www.imxushuai.com/2002/01/02/12.十次方人工智能笔记二:人工智能/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 imxushuai
支付宝打赏
微信打赏