大多数情况下,mxnet都使用python接口进行机器学习程序的编写,方便快捷,但是有的时候,需要把机器学习训练和识别的程序部署到生产版的程序中去,比如游戏或者云服务,此时采用C++等高级语言去编写才能提高性能,本文介绍了如何在windows系统下从源码编译mxnet,安装python版的包,并使用C++原生接口创建示例程序。


目标

  • 编译出libmxnet.lib和libmxnet.dll的gpu版本
  • 从源码安装mxnet python包
  • 构建mxnet C++示例程序

环境

  • windows10
  • vs2015
  • cmake3.7.2
  • Miniconda2(python2.7.14)
  • CUDA8.0
  • mxnet1.2
  • opencv3.4.1
  • OpenBLAS-v0.2.19-Win64-int32
  • cudnn-8.0-windows10-x64-v7.1(如果编译cpu版本的mxnet,则此项不需要)

步骤

下载源码

最好用git下载,递归地下载所有依赖的子repo,源码的根目录为mxnet

git clone --recursive https://github.com/dmlc/mxnet

依赖库

在此之前确保cmake和python已经正常安装,并且添加到环境变量,然后再下载第三方依赖库

  • 下载安装cuda,确保机器是英伟达显卡,且支持cuda,地址:https://developer.nvidia.com/cuda-toolkit
  • 下载安装opencv预编译版,地址:https://sourceforge.net/projects/opencvlibrary/files/opencv-win/3.4.1/opencv-3.4.1-vc14_vc15.exe/download
  • 下载openblas预编译版,地址:https://sourceforge.net/projects/openblas/files/v0.2.19/
  • 下载cudnn预编译版,注意与cuda版本对应,地址:https://developer.nvidia.com/compute/machine-learning/cudnn/secure/v7.0.5/prod/8.0_20171129/cudnn-8.0-windows10-x64-v7

cmake配置

打开cmake-gui,配置源码目录和生成目录,编译器选择vs2015 win64



配置第三方依赖库





configure和generate



编译vs工程

打开mxnet.sln,配置成release x64模式,编译整个solution


编译完成后会在对应文件夹生成mxnet的lib和dll


此时整个过程成功了一半


安装mxnet的python包

有了libmxnet.dll就可以同源码安装python版的mxnet包了

不过,前提是需要集齐所有依赖到的其他dll,如图所示,将这些dll全部拷贝到mxnet/python/mxnet目录下


tip: 关于dll的来源

  • opencv,openblas,cudnn相关dll都是从这几个库的目录里拷过来的
  • libgcc_s_seh-1.dll和libwinpthread-1.dll是从mingw相关的库目录里拷过来的,git,qt等这些目录都有
  • libgfortran-3.dll和libquadmath_64-0.dll是从adda(https://github.com/adda-team/adda/releases)这个库里拷过来的,注意改名

然后,在mxnet/python目录下使用命令行安装mxnet的python包

python setup.py install


安装过程中,python会自动把对应的dll考到安装目录,正常安装完成后,在python中就可以 import mxnet 了

生成C++依赖头文件

为了能够使用C++原生接口,这一步是很关键的一步,目的是生成mxnet C++程序依赖的op.h文件

在mxnet/cpp-package/scripts目录,将所有依赖到的dll拷贝进来


在此目录运行命令行

python OpWrapperGenerator.py libmxnet.dll


正常情况下就可以在mxnet/cpp-package/include/mxnet-cpp目录下生成op.h了


如果这个过程中出现一些error,多半是dll文件缺失或者版本不对,很好解决

构建C++示例程序

建立cpp工程,这里使用经典的mnist手写数字识别训练示例(请提前下载好mnist数据,地址:mnist),启用GPU支持

选择release x64模式


配置include和lib目录以及附加依赖项



include目录包括:

  • D:\mxnet\include
  • D:\mxnet\dmlc-core\include
  • D:\mxnet\nnvm\include
  • D:\mxnet\cpp-package\include

lib目录:

  • D:\mxnet\build_x64\Release

附加依赖项:

  • libmxnet.lib

代码 main.cpp

#include <chrono>
#include "mxnet-cpp/MxNetCpp.h"

using namespace std;
using namespace mxnet::cpp;

Symbol mlp(const vector<int> &layers)
{
	auto x = Symbol::Variable("X");
	auto label = Symbol::Variable("label");

	vector<Symbol> weights(layers.size());
	vector<Symbol> biases(layers.size());
	vector<Symbol> outputs(layers.size());

	for (size_t i = 0; i < layers.size(); ++i)
	{
		weights[i] = Symbol::Variable("w" + to_string(i));
		biases[i] = Symbol::Variable("b" + to_string(i));
		Symbol fc = FullyConnected(
			i == 0 ? x : outputs[i - 1],// data
			weights[i],biases[i],layers[i]);
		outputs[i] = i == layers.size() - 1 ? fc : Activation(fc,ActivationActType::kRelu);
	}

	return softmaxOutput(outputs.back(),label);
}

int main(int argc,char** argv)
{
	const int image_size = 28;
	const vector<int> layers{128,64,10};
	const int batch_size = 100;
	const int max_epoch = 10;
	const float learning_rate = 0.1;
	const float weight_decay = 1e-2;

	auto train_iter = MXDataIter("MNISTIter")
		.SetParam("image","./mnist_data/train-images.idx3-ubyte")
		.SetParam("label","./mnist_data/train-labels.idx1-ubyte")
		.SetParam("batch_size",batch_size)
		.SetParam("flat",1)
		.CreateDataIter();
	auto val_iter = MXDataIter("MNISTIter")
		.SetParam("image","./mnist_data/t10k-images.idx3-ubyte")
		.SetParam("label","./mnist_data/t10k-labels.idx1-ubyte")
		.SetParam("batch_size",1)
		.CreateDataIter();

	auto net = mlp(layers);

	// start traning
	cout << "==== mlp training begin ====" << endl;

	auto start_time = chrono::system_clock::Now();

	Context ctx = Context::gpu();  // Use GPU for training

	std::map<string,ndarray> args;
	args["X"] = ndarray(Shape(batch_size,image_size*image_size),ctx);
	args["label"] = ndarray(Shape(batch_size),ctx);
	// Let MXNet infer shapes of other parameters such as weights
	net.InferArgsMap(ctx,&args,args);

	// Initialize all parameters with uniform distribution U(-0.01,0.01)
	auto initializer = Uniform(0.01);
	for (auto& arg : args)
	{
		// arg.first is parameter name,and arg.second is the value
		initializer(arg.first,&arg.second);
	}

	// Create sgd optimizer
	Optimizer* opt = OptimizerRegistry::Find("sgd");
	opt->SetParam("rescale_grad",1.0 / batch_size)
		->SetParam("lr",learning_rate)
		->SetParam("wd",weight_decay);
	std::unique_ptr<lrscheduler> lr_sch(new FactorScheduler(5000,0.1));
	opt->Setlrscheduler(std::move(lr_sch));

	// Create executor by binding parameters to the model
	auto *exec = net.SimpleBind(ctx,args);
	auto arg_names = net.ListArguments();

	// Create metrics
	Accuracy train_acc,val_acc;

	// Start training
	for (int iter = 0; iter < max_epoch; ++iter)
	{
		int samples = 0;
		train_iter.Reset();
		train_acc.Reset();

		auto tic = chrono::system_clock::Now();
		while (train_iter.Next())
		{
			samples += batch_size;
			auto data_batch = train_iter.GetDataBatch();
			// Data provided by DataIter are stored in memory,should be copied to GPU first.
			data_batch.data.copyTo(&args["X"]);
			data_batch.label.copyTo(&args["label"]);
			// copyTo is imperative,need to wait for it to complete.
			ndarray::WaitAll();

			// Compute gradients
			exec->Forward(true);
			exec->Backward();

			// Update parameters
			for (size_t i = 0; i < arg_names.size(); ++i)
			{
				if (arg_names[i] == "X" || arg_names[i] == "label") continue;
				opt->Update(i,exec->arg_arrays[i],exec->grad_arrays[i]);
			}
			// Update metric
			train_acc.Update(data_batch.label,exec->outputs[0]);
		}
		// one epoch of training is finished
		auto toc = chrono::system_clock::Now();
		float duration = chrono::duration_cast<chrono::milliseconds>(toc - tic).count() / 1000.0;
		LG << "Epoch[" << iter << "] " << samples / duration \
			<< " samples/sec " << "Train-Accuracy=" << train_acc.Get();;

		val_iter.Reset();
		val_acc.Reset();
		while (val_iter.Next())
		{
			auto data_batch = val_iter.GetDataBatch();
			data_batch.data.copyTo(&args["X"]);
			data_batch.label.copyTo(&args["label"]);
			ndarray::WaitAll();

			// Only forward pass is enough as no gradient is needed when evaluating
			exec->Forward(false);
			val_acc.Update(data_batch.label,exec->outputs[0]);
		}
		LG << "Epoch[" << iter << "] Val-Accuracy=" << val_acc.Get();
	}

	// end training
	auto end_time = chrono::system_clock::Now();
	float total_duration = chrono::duration_cast<chrono::milliseconds>(end_time - start_time).count() / 1000.0;
	cout << "total duration: " << total_duration << " s" << endl;

	cout << "==== mlp training end ====" << endl;

	//delete exec;
	MXNotifyShutdown();

	getchar(); // wait here
	return 0;
}

编译生成目录

  • 预先把mnist数据拷进去,维持相对目录结构
  • 在执行目录也要把所有依赖的dll拷贝进来



运行结果


在官方的example里面有mlp的cpu和gpu两个版本,有兴趣的话可以跑起来做一个对比

其实,在某些数据量小的情况下,gpu版本并不明显比cpu版本消耗的训练时间少

至此,大功告成

windows下编译mxnet并使用C++训练模型的更多相关文章

  1. HTML实现代码雨源码及效果示例

    这篇文章主要介绍了HTML实现代码雨源码及效果示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

  2. ios8 – iOS 8上的ptrace

    我试图在ptrace上调用一个像thisptrace一样的函数;但是当我尝试使用#include导入它时,Xcode会给我一个错误’sys/ptrace.h’文件找不到.我错过了什么,我是否需要导入一个库,或者这在iOS上根本不可用?

  3. XCode 3.2 Ruby和Python模板

    在xcode3.2下,我的ObjectiveCPython/Ruby项目仍然可以打开更新和编译,但是你无法创建新项目.鉴于xcode3.2中缺少ruby和python的所有痕迹(即创建项目并添加新的ruby/python文件),是否有一种简单的方法可以再次安装模板?我发现了一些关于将它们复制到某个文件夹的信息,但我似乎无法让它工作,我怀疑文件夹的位置已经改变为3.2.解决方法3.2中的应用程序模板

  4. .dylib在Debug中链接,在XCode中找不到适用于iPhone的版本

    所以我已经将libxml2.2.dylib库包含在我的iPhoneXCode项目中,以创建一些Xml和XPath解析实用程序.当我编译并运行在模拟器和设备的调试模式时,我没有问题,但是,当我切换到发布模式我得到…

  5. 在编译时编译Xcode中的C类错误:stl vector

    我有一个C类,用gcc和可视化工作室中的寡妇在linux上编译.boid.h:并在boid.cpp中:但是,当我在Xcode中编译此代码时,我收到以下错误:有任何想法吗?我以为你可以使用C/C++代码并在Xcode中编译没有问题?.m文件被视为具有Objective-C扩展名的.c文件..mm文件被视为具有Objective-C扩展名的.cpp文件,那么它被称为Objective-C只需将.m文件重命名为.mm,右键单击或按住Ctrl键并在Xcode中的文件中选择重命名.

  6. 在编译的iOS应用程序(IPA)中加密内容

    由于IPA结构只是一个压缩文件,包含编译代码媒体内容,如图像&音频,我如何保护内容免受别人的窃取?是否有加密可以添加到IPA?

  7. ios – Swift 4向后兼容性

    一起使用.有没有办法在两个版本的Xcode中使这个工作?Swift4是否应该向后兼容?

  8. 如何从Haxe创建iOS-和OSX-库并在本机应用程序中使用它?

    我有一个在Haxe上编写自己的协议,数据结构和逻辑的跨平台实现.如何在iOS和OSX的企业应用程序中构建和使用它?

  9. 源码推荐:简化Swift编写的iOS动画,iOS Material Design库

    本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请发送邮件至dio@foxmail.com举报,一经查实,本站将立刻删除。

  10. Swift与OC混合编译

    SWift调用OC新建swift文件此时系统自动生成-Bridging-Header.h文件并且TARGETS->BuildSettings->Objective-CBridgingHeader(搜索bridg)选项中会自动填入以上头文件的路径在-Bridging-Header.h中#import要调用的OC对象头文件OC调用Swift在OC文件中#import“

随机推荐

  1. static – 在页面之间共享数据的最佳实践

    我想知道在UWP的页面之间发送像’selectedItem’等变量的最佳做法是什么?创建一个每个页面都知道的静态全局变量类是一个好主意吗?

  2. .net – 为Windows窗体控件提供百分比宽度/高度

    WindowsForm开发的新手,但在Web开发方面经验丰富.有没有办法为Windows窗体控件指定百分比宽度/高度,以便在用户调整窗口大小时扩展/缩小?当窗口调整大小时,可以编写代码来改变控件的宽度/高度,但我希望有更好的方法,比如在HTML/CSS中.在那儿?

  3. 使用Windows Azure查询表存储数据

    我需要使用特定帐户吗?>将应用程序部署到Azure服务后,如何查询数据?GoogleAppEngine有一个数据查看器/查询工具,Azure有类似的东西吗?>您可以看到的sqlExpressintance仅在开发结构中,并且一旦您表示没有等效,所以请小心使用它.>您可以尝试使用Linqpad查询表格.看看JamieThomson的thispost.

  4. windows – SetupDiGetClassDevs是否与文档中的设备实例ID一起使用?

    有没有更好的方法可以使用DBT_DEVICEARRIVAL事件中的数据获取设备的更多信息?您似乎必须指定DIGCF_ALLCLASSES标志以查找与给定设备实例ID匹配的所有类,或者指定ClassGuid并使用DIGCF_DEFAULT标志.这对我有用:带输出:

  5. Windows Live ID是OpenID提供商吗?

    不,WindowsLiveID不是OpenID提供商.他们使用专有协议.自从他们的“测试版”期结束以来,他们从未宣布计划继续它.

  6. 如果我在代码中进行了更改,是否需要重新安装Windows服务?

    我写了一个Windows服务并安装它.现在我对代码进行了一些更改并重新构建了解决方案.我还应该重新安装服务吗?不,只需停止它,替换文件,然后重新启动它.

  7. 带有双引号的字符串回显使用Windows批处理输出文件

    我正在尝试使用Windows批处理文件重写配置文件.我循环遍历文件的行并查找我想要用指定的新行替换的行.我有一个’函数’将行写入文件问题是%Text%是一个嵌入双引号的字符串.然后失败了.可能还有其他角色也会导致失败.如何才能使用配置文件中的所有文本?尝试将所有“在文本中替换为^”.^是转义字符,因此“将被视为常规字符你可以尝试以下方法:其他可能导致错误的字符是:

  8. .net – 将控制台应用程序转换为服务?

    我正在寻找不同的优势/劣势,将我们长期使用的控制台应用程序转换为Windows服务.我们为ActiveMQ使用了一个叫做java服务包装器的东西,我相信人们告诉我你可以用它包装任何东西.这并不是说你应该用它包装任何东西;我们遇到了这个问题.控制台应用程序是一个.NET控制台应用程序,默认情况下会将大量信息记录到控制台,尽管这是可配置的.任何推荐?我们应该在VisualStudio中将其重建为服务吗?我使用“-install”/“-uninstall”开关执行此操作.例如,seehere.

  9. windows – 捕获外部程序的STDOUT和STDERR *同时*它正在执行(Ruby)

    哦,我在Windows上:-(实际上,它比我想象的要简单,这看起来很完美:…是的,它适用于Windows!

  10. windows – 当我试图批量打印变量时,为什么我得到“Echo is on”

    我想要执行一个简单的批处理文件脚本:当我在XP中运行时,它给了我预期的输出,但是当我在Vista或Windows7中运行它时,我在尝试打印值时得到“EchoisOn”.以下是程序的输出:摆脱集合表达式中的空格.等号(=)的两侧可以并且应该没有空格BTW:我通常在@echo关闭的情况下启动所有批处理文件,并以@echo结束它们,所以我可以避免将代码与批处理文件的输出混合.它只是使您的批处理文件输出更好,更清洁.

返回
顶部