优化算法(一)SGD算法实现

news/2024/6/29 4:02:41 标签: SGD, 深度学习, 神经网络

SGD随机梯度下降算法,和最常用的GD相比,GD每一次迭代都是所有样本都一起进行计算,而SGD是每一次迭代中每个样本分别进行计算,梯度算法的最终目标是减少cost值,训练出最优的参数值,GD每一次迭代都让所有样本去优化参数,一次迭代进行一次优化,而SGD一次只让一个样本去优化参数。

 

贴一下代码:

#SGD(w初始化使用 2/sqrt(l-1))
def SGD_model(X,Y,layer_dims,iter_times,alphs):
    costs = []
    m = X.shape[1]
    n = X.shape[0]
    np.random.seed(3)
    parameters = initialize_parameters(layer_dims)
    for i in range(0,iter_times):
        for j in range(0,m):
            A,caches=forward_propagation(X[:,j].reshape(n,1),parameters)
            cost=cpmpute_cost(A,Y[:,j].reshape(1,1))
            grads=back_propagation(Y[:,j].reshape(1,1),caches,parameters)
            parameters=update_parameters(parameters,grads,alphs)
            costs.append(cost)
        if i%100 == 0:
            print(cost)
    return costs,parameters

 

测试一下:

n=train_data_finalX.shape[0]
layer_dims=[n,20,7,5,1]
costs,parameters=SGD_model(train_data_finalX,train_data_finalY,layer_dims,500,0.0003)
 
y_pred_train=predict(train_data_finalX,parameters)
print('train acc is ',np.mean(y_pred_train == train_data_finalY)*100,'%')    
 
y_pred_test=predict(test_data_finalX,parameters)
print('test acc is ',np.mean(y_pred_test == test_data_finalY)*100,'%')
#可以看到cost减低的是很快的,这里打印出来的cost并不是每一次的cost值,只是挑着打印了几个
#这个cost也不是所有样本cost值加起来得到的,其实还是单个样本的cost值
#过拟合现象还是很严重的

0.6967667264512503
0.3580429544797275
0.1366511629971142
0.013014664339787691
0.005059855441099931
train acc is  100.0 %
test acc is  84.0 %

 

 

特别要注意的是,可能会遇到cost值不降低,或者降低到一定值以后就不变了,这种情况预测值得到的可能都是0,可以试试这几个解决办法:

  1. w参数初始化,不要选择在random以后乘以0.01,如果激活函数使用的是tanh函数,使用(1/sqrt(上一层单元数)),如果使用的是relu则试试(2/sqrt(上一层单元数)),参数初始化对结果有很大的影响力度。
  2. 学习因子设置的小一点
  3. 如果出现严重的过拟合现象,可以试试增加隐藏层,或者隐藏层单元数

http://www.niftyadmin.cn/n/705472.html

相关文章

PL/SQL查看表结构

SET LONG 99999;SET LINESIZE 140 PAGESIZE 1000;SELECT DBMS_METADATA.GET_DDL(&OBJECT_TYPE,&NAME,&SCHEMA) FROM DUAL;转载于:https://www.cnblogs.com/chenlaichao/p/8005318.html

如何将linux下的代码上传到github上

2019独角兽企业重金招聘Python工程师标准>>> 本文适用情景:linux系统,第一次上传,远端没有对应厂库。其它情景仅作参考!1.安装git 首先,你可以试着输入 git,看看系统有没有安装Git: …

Tensorflow(一)搭建环境可能遇到的问题

Anaconda 目前,python3.7是不支持tensorflow的,所以要学习tensorflow需要换成3.5或者3.6的版本 os是win10,以前已经安装了puthon3.6.7和jupyter,直接安装了tensorflow,后来想装一个anaconda3 5.2,anacond…

经典C程序例子解析

题目:一球从100米高度自由落下,每次落地后反跳回原高度的一半;再落下,求它在第10次落地时,共经过多少米?第10次反弹多高?根据本周所学知识可以很轻松的编写源代码为 main(){float sn100.0,hnsn/…

Tensorflow(二)MNIST数据集分类

1.获取数据集 有两种方式可以得到数据集,第一是直接通过mnist input_data.read_data_sets(MNIST_data,one_hot True)进行联网下载,但这个方法可能很慢或者连接不到服务器,所以推荐使用第二个,在MNIST 直接下载数据,…

工作流二次开发之邮箱提醒

2019独角兽企业重金招聘Python工程师标准>>> 为了考虑以后二次开发,和将来的代码增多。调用工作流的接口,大量代码写在自己新建项目中。 工作流接口: public boolean sendMail(Map lhm){ //设置HTTP连接的URL地址&#x…

Tensorflow(三)训练一个简单卷积神经网络

这是吴恩达老师第四课第一周的编程练习,题目是分析图片中手势得到手所表示的数字。 数据集我传到github上,可以下载https://github.com/penguin219/WU_Lesson4_week1 特别要注意的是,如果你使用的是新版本的tensorflow,很有可能…