Skip to content

Commit dd9dee4

Browse files
committed
Add adamax
1 parent ff29db4 commit dd9dee4

File tree

2 files changed

+204
-96
lines changed

2 files changed

+204
-96
lines changed

Deep Learning from Scratch in Python/.ipynb_checkpoints/Gradient Descent Optimization Algorithms-checkpoint.ipynb

Lines changed: 102 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"- [AdaDelta](https://youtu.be/6gvh0IySNMs)\n",
1414
"- [Adam](https://youtu.be/6nqV58NA_Ew)\n",
1515
"- [Nesterov](https://youtu.be/6FrBXv9OcqE)\n",
16+
"- Adamax\n",
1617
"\n",
1718
"## Tests\n",
1819
"In order to demonstrate the algorithms' capabilities to optimize a function we used this simple test setup:\n",
@@ -21,7 +22,7 @@
2122
},
2223
{
2324
"cell_type": "code",
24-
"execution_count": 15,
25+
"execution_count": 7,
2526
"metadata": {},
2627
"outputs": [],
2728
"source": [
@@ -383,40 +384,80 @@
383384
" \n",
384385
" if i % 100 == 0:\n",
385386
" print(f\"Iteration {i}\")\n",
387+
" print(model)\n",
388+
" \n",
389+
" \n",
390+
def adamax(model, xs, ys, learning_rate = 0.1, b1 = 0.9, b2 = 0.999, max_iteration = 1000):
    """
    Adamax: optimizer that builds upon Adam by replacing the L2-norm second
    moment with an exponentially weighted infinity (L_inf) norm
    (Kingma & Ba 2015, sec. 7.1).
    model: the model we want to optimize the parameters of
    xs: the features of the dataset
    ys: the continuous target values
    learning_rate: base step size (default 0.1; the effective step is rescaled by the optimizer)
    b1: exponential decay rate for the first-moment estimate (default 0.9, as proposed for deep learning)
    b2: exponential decay rate for the infinity-norm estimate (default 0.999, as proposed for deep learning)
    max_iteration: the number of stochastic update rounds to run before stopping
    """

    # Variable initialization: one first-moment (m) and one infinity-norm (v)
    # accumulator per model parameter. The gradient list is produced fresh by
    # model.derivate each round, so it needs no pre-allocation.
    num_param = len(model.weights)
    m = [0 for _ in range(num_param)]  # first-moment estimate, one per parameter
    v = [0 for _ in range(num_param)]  # infinity-norm estimate, one per parameter

    # t starts at 1 because the bias correction divides by 1 - b1**t, which is
    # zero at t = 0. The upper bound is max_iteration + 1 so exactly
    # max_iteration rounds run (the previous range(1, max_iteration) ran one
    # round short).
    for t in range(1, max_iteration + 1):

        # Draw one random sample (stochastic gradient descent)
        x, y = stochastic_sample(xs, ys)

        # Get the partial derivatives of the loss w.r.t. each parameter
        g = model.derivate(x, y)

        # Update the biased first moment and the infinity-norm accumulator.
        # v uses a max with |g| instead of Adam's squared-gradient average.
        m = [b1*m_i + (1 - b1)*g_i for m_i, g_i in zip(m, g)]
        v = [np.maximum(b2*v_i, np.absolute(g_i)) for v_i, g_i in zip(v, g)]

        # Bias correction is needed for m only: the max-based v update is not
        # biased towards zero the way Adam's second moment is.
        m_cor = [m_i / (1 - (b1**t)) for m_i in m]

        # Parameter update. Per the AdaMax algorithm the denominator is v
        # itself, NOT sqrt(v) (the previous np.sqrt(v_i) was Adam's L2 form,
        # not Adamax). The small epsilon guards against division by zero when
        # the first sampled gradient component is exactly 0.
        model.weights = [weight - (learning_rate / (v_i + 1e-8))*m_cor_i
                         for weight, v_i, m_cor_i in zip(model.weights, v, m_cor)]

        if t % 100 == 0:
            print(f"Iteration {t}")
            print(model)
387430
]
388431
},
389432
{
390433
"cell_type": "code",
391-
"execution_count": 16,
434+
"execution_count": 10,
392435
"metadata": {},
393436
"outputs": [
394437
{
395438
"name": "stdout",
396439
"output_type": "stream",
397440
"text": [
398-
"Nesterov Accelerated Gradient\n",
399-
"Iteration 0\n",
400-
"y = [0.89010029] + [0.35356173]*x\n",
441+
"Adamax\n",
401442
"Iteration 100\n",
402-
"y = [0.3723535] + [0.9655646]*x\n",
443+
"y = [-0.00777475] + [1.00552222]*x\n",
403444
"Iteration 200\n",
404-
"y = [0.01159183] + [0.9884176]*x\n",
445+
"y = [0.00039666] + [1.00003985]*x\n",
405446
"Iteration 300\n",
406-
"y = [0.00013029] + [1.00053685]*x\n",
447+
"y = [9.6917012e-06] + [0.99999651]*x\n",
407448
"Iteration 400\n",
408-
"y = [0.00110482] + [1.00002591]*x\n",
449+
"y = [2.99389693e-07] + [0.99999994]*x\n",
409450
"Iteration 500\n",
410-
"y = [1.20330159e-05] + [0.99999153]*x\n",
451+
"y = [-3.59743715e-09] + [1.]*x\n",
411452
"Iteration 600\n",
412-
"y = [4.34325924e-05] + [1.00000075]*x\n",
453+
"y = [2.77363348e-11] + [1.]*x\n",
413454
"Iteration 700\n",
414-
"y = [-1.87460485e-05] + [1.00003432]*x\n",
455+
"y = [-6.49186838e-14] + [1.]*x\n",
415456
"Iteration 800\n",
416-
"y = [1.26114336e-05] + [0.99998661]*x\n",
457+
"y = [-3.5671997e-15] + [1.]*x\n",
417458
"Iteration 900\n",
418-
"y = [-1.84626241e-06] + [1.0000026]*x\n",
419-
"y = [4.67870626e-06] + [0.99999889]*x\n"
459+
"y = [6.75723728e-17] + [1.]*x\n",
460+
"y = [4.19651576e-17] + [1.]*x\n"
420461
]
421462
}
422463
],
@@ -463,57 +504,60 @@
463504
"adadelta(model, xs, ys)\n",
464505
"print(model)\n",
465506
"\n",
466-
"\n",
467507
"# Adam\n",
468508
"model = Line()\n",
469509
"print(\"Adam\")\n",
470510
"adam(model, xs, ys)\n",
471511
"print(model)\n",
472512
"\n",
473-
"'''\n",
474-
"\n",
475513
"# Nesterov Accelerated Gradient\n",
476514
"model = Line()\n",
477515
"print(\"Nesterov Accelerated Gradient\")\n",
478516
"nesterov(model, xs, ys)\n",
517+
"print(model)\n",
518+
"'''\n",
519+
"\n",
520+
"# Adamax\n",
521+
"model = Line()\n",
522+
"print(\"Adamax\")\n",
523+
"adamax(model, xs, ys)\n",
479524
"print(model)"
480525
]
481526
},
482527
{
483528
"cell_type": "code",
484-
"execution_count": 17,
529+
"execution_count": 11,
485530
"metadata": {},
486531
"outputs": [
487532
{
488533
"name": "stdout",
489534
"output_type": "stream",
490535
"text": [
491-
"Nesterov Accelerated Gradient\n",
492-
"Iteration 0\n",
493-
"y = [0.66878322] + [0.90014342]*x\n",
536+
"Adamax\n",
494537
"Iteration 100\n",
495-
"y = [-0.17903726] + [2.40553571]*x\n",
538+
"y = [0.4369314] + [1.89897329]*x\n",
496539
"Iteration 200\n",
497-
"y = [0.05899865] + [2.04887923]*x\n",
540+
"y = [0.13822549] + [1.96756579]*x\n",
498541
"Iteration 300\n",
499-
"y = [0.00894693] + [1.97536479]*x\n",
542+
"y = [0.03852642] + [1.99071418]*x\n",
500543
"Iteration 400\n",
501-
"y = [0.01374569] + [2.05300722]*x\n",
544+
"y = [0.01035516] + [1.99740938]*x\n",
502545
"Iteration 500\n",
503-
"y = [0.10548793] + [1.91493233]*x\n",
546+
"y = [0.0021898] + [1.99963751]*x\n",
504547
"Iteration 600\n",
505-
"y = [0.00385495] + [2.01632264]*x\n",
548+
"y = [0.00051143] + [1.99991838]*x\n",
506549
"Iteration 700\n",
507-
"y = [0.05427682] + [2.13589417]*x\n",
550+
"y = [9.66789006e-05] + [1.99998553]*x\n",
508551
"Iteration 800\n",
509-
"y = [0.0113579] + [2.00570216]*x\n",
552+
"y = [1.43786098e-05] + [1.99999787]*x\n",
510553
"Iteration 900\n",
511-
"y = [-0.02275247] + [1.96719449]*x\n",
512-
"y = [-0.00900316] + [2.00493823]*x\n"
554+
"y = [2.96576946e-06] + [1.99999894]*x\n",
555+
"y = [5.99822337e-07] + [1.99999991]*x\n"
513556
]
514557
}
515558
],
516559
"source": [
560+
"\n",
517561
"# Here we have a simple line with intercept = 0 and slope = 2\n",
518562
"xs = [1,2,3,4,5,6,7]\n",
519563
"ys = [2,4,6,8,10,12,14]\n",
@@ -560,46 +604,50 @@
560604
"print(\"Adam\")\n",
561605
"adam(model, xs, ys)\n",
562606
"print(model)\n",
563-
"'''\n",
564607
"\n",
565608
"# Nesterov Accelerated Gradient\n",
566609
"model = Line()\n",
567610
"print(\"Nesterov Accelerated Gradient\")\n",
568611
"nesterov(model, xs, ys)\n",
612+
"print(model)\n",
613+
"'''\n",
614+
"\n",
615+
"# Adamax\n",
616+
"model = Line()\n",
617+
"print(\"Adamax\")\n",
618+
"adamax(model, xs, ys)\n",
569619
"print(model)"
570620
]
571621
},
572622
{
573623
"cell_type": "code",
574-
"execution_count": 20,
624+
"execution_count": 12,
575625
"metadata": {},
576626
"outputs": [
577627
{
578628
"name": "stdout",
579629
"output_type": "stream",
580630
"text": [
581-
"Nesterov Accelerated Gradient\n",
582-
"Iteration 0\n",
583-
"y = [0.30475578] + [0.9422567]*x\n",
631+
"Adamax\n",
584632
"Iteration 100\n",
585-
"y = [-0.39695759] + [2.51820116]*x\n",
633+
"y = [1.11731293] + [1.9747599]*x\n",
586634
"Iteration 200\n",
587-
"y = [-1.48096144] + [3.46703259]*x\n",
635+
"y = [1.05362733] + [1.99324401]*x\n",
588636
"Iteration 300\n",
589-
"y = [1.90685378] + [2.11784826]*x\n",
637+
"y = [1.02088689] + [1.99784093]*x\n",
590638
"Iteration 400\n",
591-
"y = [1.03772983] + [1.98962059]*x\n",
639+
"y = [1.00584065] + [1.99878137]*x\n",
592640
"Iteration 500\n",
593-
"y = [1.00457709] + [2.05050135]*x\n",
641+
"y = [1.00157963] + [1.99971604]*x\n",
594642
"Iteration 600\n",
595-
"y = [0.99465739] + [1.99190331]*x\n",
643+
"y = [1.00050744] + [1.99986608]*x\n",
596644
"Iteration 700\n",
597-
"y = [1.00337418] + [1.99869932]*x\n",
645+
"y = [1.00011891] + [1.99998385]*x\n",
598646
"Iteration 800\n",
599-
"y = [0.99671896] + [1.98843029]*x\n",
647+
"y = [1.00003134] + [1.99999406]*x\n",
600648
"Iteration 900\n",
601-
"y = [1.00465637] + [1.99875298]*x\n",
602-
"y = [0.99462305] + [2.00330215]*x\n"
649+
"y = [1.00000663] + [1.99999862]*x\n",
650+
"y = [1.00000132] + [1.99999974]*x\n"
603651
]
604652
}
605653
],
@@ -649,12 +697,18 @@
649697
"print(\"Adam\")\n",
650698
"adam(model, xs, ys)\n",
651699
"print(model)\n",
652-
"'''\n",
653700
"\n",
654701
"# Nesterov Accelerated Gradient\n",
655702
"model = Line()\n",
656703
"print(\"Nesterov Accelerated Gradient\")\n",
657704
"nesterov(model, xs, ys)\n",
705+
"print(model)\n",
706+
"'''\n",
707+
"\n",
708+
"# Adamax\n",
709+
"model = Line()\n",
710+
"print(\"Adamax\")\n",
711+
"adamax(model, xs, ys)\n",
658712
"print(model)"
659713
]
660714
},

0 commit comments

Comments
 (0)