 # -*- coding: utf-8 -*-
 
 from collections.abc import Iterable
-from tensorlayerx.nn.core.common import _save_weights, _load_weights, _save_standard_weights_dict, _load_standard_weights_dict
+from tensorlayerx.nn.core.common import _save_weights, _load_weights, \
+    _save_standard_weights_dict, _load_standard_weights_dict
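+# training wrappers (WithLoss, GradWrap, WithGrad*, TrainOneStepWith*) now live in .utils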
+from .utils import WithLoss, WithGradPD, WithGradMS, WithGradTF, TrainOneStepWithPD, \
+    TrainOneStepWithMS, TrainOneStepWithTH, TrainOneStepWithTF, GradWrap
 import tensorlayerx as tlx
 from tensorlayerx.nn import Module
 import numpy as np
 if tlx.BACKEND == 'tensorflow':
     import tensorflow as tf
 if tlx.BACKEND == 'mindspore':
-    from mindspore.ops import composite
     from mindspore.ops import operations as P
-    from mindspore.common import ParameterTuple
 if tlx.BACKEND == 'paddle':
     import paddle as pd
 if tlx.BACKEND == 'torch':
@@ -108,6 +109,12 @@ def train(self, n_epoch, train_dataset=None, test_dataset=False, print_train_batch
                 train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
                 print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
             )
+        elif tlx.BACKEND == 'torch':
+            self.th_train(
+                n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
+                train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
+                print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
+            )
 
     def eval(self, test_dataset):
         self.network.set_eval()
@@ -436,10 +443,9 @@ def th_train(
 
                 train_loss += loss
                 if metrics:
-                    pass
-                    # metrics.update(output, y_batch)
-                    # train_acc += metrics.result()
-                    # metrics.reset()
+                    metrics.update(output, y_batch)
+                    train_acc += metrics.result()
+                    metrics.reset()
                 else:
                     train_acc += (output.argmax(1) == y_batch).type(torch.float).sum().item()
                 n_iter += 1
@@ -454,180 +460,23 @@ def th_train(
                 print(" train loss: {}".format(train_loss / n_iter))
                 print(" train acc: {}".format(train_acc / n_iter))
 
-
-class WithLoss(Module):
-    """
-    High-Level API for Training or Testing.
-
-    Wraps the network with loss function. This Module accepts data and label as inputs and
-    the computed loss will be returned.
-
-    Parameters
-    ----------
-    backbone : tensorlayer model
-        The tensorlayer network.
-    loss_fn : function
-        Objective function
-
-    Methods
-    ---------
-    forward()
-        Model inference.
-
-    Examples
-    --------
-    >>> import tensorlayerx as tlx
-    >>> net = vgg16()
-    >>> loss_fn = tlx.losses.softmax_cross_entropy_with_logits
-    >>> net_with_loss = tlx.model.WithLoss(net, loss_fn)
-
-    """
-
-    def __init__(self, backbone, loss_fn):
-        super(WithLoss, self).__init__()
-        self._backbone = backbone
-        self._loss_fn = loss_fn
-
-    def forward(self, data, label):
-        out = self._backbone(data)
-        return self._loss_fn(out, label)
-
-    @property
-    def backbone_network(self):
-        return self._backbone
-
-
-class GradWrap(Module):
-    """ GradWrap definition """
-
-    def __init__(self, network, trainable_weights):
-        super(GradWrap, self).__init__(auto_prefix=False)
-        self.network = network
-        self.weights = ParameterTuple(trainable_weights)
-
-    def forward(self, x, label):
-        return composite.GradOperation(get_by_list=True)(self.network, self.weights)(x, label)
-
-
-class WithGradMS(Module):
-    """Module that returns the gradients."""
-
-    def __init__(self, network, loss_fn=None, sens=None, optimizer=None):
-        super(WithGradMS, self).__init__()
-        self.network = network
-        self.loss_fn = loss_fn
-        self.weights = ParameterTuple(network.trainable_weights)
-        self.grad = composite.GradOperation(get_by_list=True, sens_param=(sens is not None))
-        self.sens = sens
-        self.optimizer = optimizer
-        if self.loss_fn is None:
-            self.network_with_loss = network
-        else:
-            self.network_with_loss = WithLoss(self.network, self.loss_fn)
-        self.network.set_train()
-
-    def forward(self, inputs, label):
-        grads = self.grad(self.network_with_loss, self.weights)(inputs, label)
-        return grads
-
-
-class WithGradTF(object):
-
-    def __init__(self, network, loss_fn=None, optimizer=None):
-        self.network = network
-        self.loss_fn = loss_fn
-        self.train_weights = self.network.trainable_weights
-        self.optimizer = optimizer
-        if loss_fn is None:
-            self.network_with_loss = network
-        else:
-            self.network_with_loss = WithLoss(self.network, self.loss_fn)
-        self.network.set_train()
-
-    def __call__(self, inputs, label):
-        with tf.GradientTape() as tape:
-            loss = self.network_with_loss(inputs, label)
-        grads = tape.gradient(loss, self.train_weights)
-        return grads
-
-
-class WithGradPD(object):
-
-    def __init__(self, network, loss_fn=None, optimizer=None):
-        self.network = network
-        self.loss_fn = loss_fn
-        self.train_weights = self.network.trainable_weights
-        self.optimizer = optimizer
-        if loss_fn is None:
-            self.network_with_loss = network
-        else:
-            self.network_with_loss = WithLoss(self.network, self.loss_fn)
-        self.network.set_train()
-
-    def __call__(self, inputs, label):
-        loss = self.network_with_loss(inputs, label)
-        grads = self.optimizer.gradient(loss, self.train_weights)
-        return grads
-
-
-class TrainOneStepWithTF(object):
-
-    def __init__(self, net_with_loss, optimizer, train_weights):
-        self.net_with_loss = net_with_loss
-        self.optimzer = optimizer
-        self.train_weights = train_weights
-
-    def __call__(self, data, label):
-        with tf.GradientTape() as tape:
-            loss = self.net_with_loss(data, label)
-        grad = tape.gradient(loss, self.train_weights)
-        self.optimzer.apply_gradients(zip(grad, self.train_weights))
-        return loss
-
-
-class TrainOneStepWithMS(object):
-
-    def __init__(self, net_with_loss, optimizer, train_weights):
-        self.net_with_loss = net_with_loss
-        self.optimizer = optimizer
-        self.train_weights = train_weights
-        self.net_with_loss = net_with_loss
-        self.train_network = GradWrap(net_with_loss, train_weights)
-
-    def __call__(self, data, label):
-        loss = self.net_with_loss(data, label)
-        grads = self.train_network(data, label)
-        self.optimizer.apply_gradients(zip(grads, self.train_weights))
-        loss = loss.asnumpy()
-        return loss
-
-
-class TrainOneStepWithPD(object):
-
-    def __init__(self, net_with_loss, optimizer, train_weights):
-        self.net_with_loss = net_with_loss
-        self.optimizer = optimizer
-        self.train_weights = train_weights
-
-    def __call__(self, data, label):
-        loss = self.net_with_loss(data, label)
-        grads = self.optimizer.gradient(loss, self.train_weights)
-        self.optimizer.apply_gradients(zip(grads, self.train_weights))
-        return loss.numpy()
-
-
-class TrainOneStepWithTH(object):
-
-    def __init__(self, net_with_loss, optimizer, train_weights):
-        self.net_with_loss = net_with_loss
-        self.optimizer = optimizer
-        self.train_weights = train_weights
-
-    def __call__(self, data, label):
-        loss = self.net_with_loss(data, label)
-        grads = self.optimizer.gradient(loss, self.train_weights)
-        self.optimizer.apply_gradients(zip(grads, self.train_weights))
-        return loss
+            if test_dataset:
+                # evaluate the model on the test set every print_freq epochs
+                if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
+                    network.set_eval()
+                    val_loss, val_acc, n_iter = 0, 0, 0
+                    for X_batch, y_batch in test_dataset:
+                        _logits = network(X_batch)  # is_train=False, disable dropout
+                        val_loss += loss_fn(_logits, y_batch, name='eval_loss')
+                        if metrics:
+                            metrics.update(_logits, y_batch)
+                            val_acc += metrics.result()
+                            metrics.reset()
+                        else:
+                            val_acc += (_logits.argmax(1) == y_batch).type(torch.float).sum().item()
+                        n_iter += 1
+                    print(" val loss: {}".format(val_loss / n_iter))
+                    print(" val acc: {}".format(val_acc / n_iter))
 
 
 class WithGrad(object):
@@ -713,3 +562,11 @@ def __init__(self, net_with_loss, optimizer, train_weights):
     def __call__(self, data, label):
         loss = self.net_with_train(data, label)
         return loss
+
+
+class TrainOneStepWithGradientClipping(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, data, label):
+        pass
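
Note on the new TrainOneStepWithGradientClipping stub: it is committed empty. Below is a minimal sketch, not part of this commit, of one way it could be filled in, assuming the torch backend with a raw torch.optim optimizer; the max_norm parameter is hypothetical:

    import torch

    class TrainOneStepWithGradientClipping(object):
        """Sketch: one training step with global-norm gradient clipping."""

        def __init__(self, net_with_loss, optimizer, train_weights, max_norm=1.0):
            self.net_with_loss = net_with_loss
            self.optimizer = optimizer          # assumed: a torch.optim optimizer
            self.train_weights = train_weights
            self.max_norm = max_norm            # hypothetical clipping threshold

        def __call__(self, data, label):
            loss = self.net_with_loss(data, label)
            self.optimizer.zero_grad()          # clear gradients from the previous step
            loss.backward()                     # populate .grad on train_weights
            # clip the global gradient norm before applying the update
            torch.nn.utils.clip_grad_norm_(self.train_weights, self.max_norm)
            self.optimizer.step()
            return loss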