博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
输出最大值MXNet实现
阅读量:5318 次
发布时间:2019-06-14

本文共 2224 字,大约阅读时间需要 7 分钟。

网络结构,输入为2个数,先经过10个节点的全连接层,再经过10个节点的ReLu,再经过10个节点的全连接层,再经过1个节点的全连接层,最后输出。

#-*-coding:utf-8-*- import loggingimport mathimport randomimport mxnet as mx # 导入 MXNet 库import numpy as np # 导入 NumPy 库,这是 Python 常用的科学计算库logging.getLogger().setLevel(logging.DEBUG) # 打开调试信息的显示'''设置超参数'''n_sample = 10000 # 训练用的数据点个数batch_size = 10 # 批大小learning_rate = 0.1 # 学习速率n_epoch = 10 # 训练 epoch 数'''生成训练数据'''# 每个数据点是在 (0,1) 之间的 2 个随机数train_in = [[ random.uniform(0, 1) for c in range(2)] for n in range(n_sample)] train_out = [0 for n in range(n_sample)] # 期望输出,先初始化为 0for i in range(n_sample):    # 每个数据点的期望输出是 2 个输入数中的大者    train_out[i] = max(train_in[i][0], train_in[i][1])'''定义train_iter为训练数据的迭代器,data为输入数据,label为标签对应train_out,shuffle代表每个epoch会随机打乱数据'''train_iter = mx.io.NDArrayIter(data = np.array(train_in), label = {'reg_label':np.array(train_out)}, batch_size = batch_size, shuffle = True)'''定义网络结构,src为输入层,fc1,fc2,fc3是全连接层,act1,act2是ReLu层,num_hidden代表神经元个数,data是输入数据,name是输出'''src = mx.sym.Variable('data') # 输入层fc1  = mx.sym.FullyConnected(data = src, num_hidden = 10, name = 'fc1') # 全连接层act1 = mx.sym.Activation(data = fc1, act_type = "relu", name = 'act1') # ReLU层fc2  = mx.sym.FullyConnected(data = act1, num_hidden = 10, name = 'fc2') # 全连接层act2 = mx.sym.Activation(data = fc2, act_type = "relu", name = 'act2') # ReLU层fc3  = mx.sym.FullyConnected(data = act2, num_hidden = 1, name = 'fc3') # 全连接层'''定义net为输出层,采用线性回归输出,MXNet会自动使用MSE作为损失函数,输入数据为fc3,输出层命名为reg'''net = mx.sym.LinearRegressionOutput(data = fc3, name = 'reg') # 输出层'''定义变量module需训练的网络模组,网络的输出symbol为net,期望标签名label_names为reg_label'''module = mx.mod.Module(symbol = net, label_names = (['reg_label']))'''定义module.fit进行训练'''module.fit(    train_iter, # 训练数据的迭代器    eval_data = None, # 在此只训练,不使用测试数据    eval_metric = mx.metric.create('mse'), # 输出 MSE 损失信息    #将权重和偏置初始化为在[-0.5,0.5]间均匀的随机数    initializer=mx.initializer.Uniform(0.5),    optimizer = 'sgd', # 梯度下降算法为 SGD    # 设置学习速率    optimizer_params = {'learning_rate': learning_rate},     num_epoch = n_epoch, # 训练 epoch 数    # 每经过 100 个 batch 输出训练速度     batch_end_callback = None,     epoch_end_callback = None, )#输出最终参数for k in module.get_params():    print(k)

转载于:https://www.cnblogs.com/cold-city/p/10460392.html

你可能感兴趣的文章
10.17动手动脑
查看>>
WPF中Image显示本地图片
查看>>
Windows Phone 7你不知道的8件事
查看>>
实用拜占庭容错算法PBFT
查看>>
java的二叉树树一层层输出,Java构造二叉树、树形结构先序遍历、中序遍历、后序遍历...
查看>>
php仿阿里巴巴,php实现的仿阿里巴巴实现同类产品翻页
查看>>
Node 中异常收集与监控
查看>>
Excel-基本操作
查看>>
面对问题,如何去分析?(分析套路)
查看>>
Excel-逻辑函数
查看>>
面对问题,如何去分析?(日报问题)
查看>>
nodejs vs python
查看>>
poj-1410 Intersection
查看>>
Java多线程基础(一)
查看>>
TCP粘包拆包问题
查看>>
SQL Server中利用正则表达式替换字符串
查看>>
POJ 1015 Jury Compromise(双塔dp)
查看>>
论三星输入法的好坏
查看>>
Linux 终端连接工具 XShell v6.0.01 企业便携版
查看>>
JS写一个简单日历
查看>>