参数服务器

本项目使用非常少的代码编写了深度学习训练的全过程,有完整的结构,通过面向对象的封装,在算法上有一定扩展性,不仅支持单机模式还支持分布式模式

使用java实现的dnn训练框架,底层矩阵库使用Jblas(https://github.com/mikiobraun/jblas),参数服务器使用Grpc+protobuf,ui方面使用ploty.js+nanohttpd

支持单机多CPU训练

支持分布式训练,多worker,多ps自定义负载均衡

支持同步更新和异步更新

支持二分类和多分类

实现embdding+全链接模型

实现Wide And Deep模型

实现卷积+池化+全链接模型

支持训练数据,测试数据异步读取,自定义parser

UI Server可视化图表

例子

运行 CTR.java 点击率预估例子,test auc在0.71左右

运行 Mnist.java 手写输入例子,正确率在0.92左右。如果从网上下载全量的mnist数据,正确率在0.98左右

运行 CnnMnist.java 使用Cnn实现的手写输入识别,对比Mnist.java的0.92同样的数据量可以做到0.96

注意,根据运行的cpu core数量不同,结果略有差异,需要略微调整mini batch数量尽快收敛

Jblas linux需要安装libgfortran3

架构

单机多CPU是使用多个相同的模型,提交到线程池,各自做训练,然后再主线程等待所有线程执行完毕对梯度进行更新

多机与单机的不同在于KVStore是分布式实现,获取参数都是从PS获取,多机还有一点是需要调用参数服务器的barrier测是否需要阻塞

模型结构

以wide and deep模型为例子,展示了如何组合layer构造模型

代码结构

可视化

需要启动UiServer,访问 localhost:8888 图表会随着训练过程动态刷新

使用UiClient::plot() 方法可以在训练中进行打点,打点信息会异步的发送给UiServer,不会阻塞训练过程

包概括

TODO

数据读取,libsvm文件的标准化读取,异步数据队列,HDFS支持等

特征处理,参考sklearn中的Preprocessing

更多的Layer实现,dropout,bn等

参数传输的压缩,学习quantize相关paper

多分类的evalute,学习one-vs-all auc