@@ -37,13 +37,13 @@ interface ScopeState {
3737}
3838
3939/**
40- * @docalias (...inputs : Tensor[] ) => {
40+ * @docalias (a: Tensor, b: Tensor, ...) => {
4141 * value: Tensor,
42- * gradFunc: (dy: Tensor) => Tensor[]
42+ * gradFunc: (dy: Tensor) => Tensor|Tensor[]
4343 * }
4444 */
4545export type CustomGradientFunc < T extends Tensor > = ( ...args : Tensor [ ] ) => {
46- value : T , gradFunc : ( dy : T ) => Tensor [ ] ;
46+ value : T , gradFunc : ( dy : T ) => Tensor | Tensor [ ] ;
4747} ;
4848
4949export interface TensorManager {
@@ -287,23 +287,27 @@ export class Engine implements TensorManager {
287287 }
288288
289289 /**
290- * Returns gradients of `f` w.r.t. each of the `xs`. The gradients returned
291- * are of the same length as `xs`, but some might be null if `f` was not
292- * a function of that `x`. It also takes optional dy to multiply the gradient,
293- * which defaults to `1`.
290+ * Returns gradients of `f` with respect to each of the `xs`. The gradients
291+ * returned are of the same length as `xs`, but some might be null if `f` was
292+ * not a function of that `x`. It also takes optional dy to multiply the
293+ * gradient, which defaults to `1`.
294294 */
295- gradients < T extends Tensor > ( f : ( ) => T , xs : Tensor [ ] , dy ?: T ) :
296- { value : T , grads : Tensor [ ] } {
295+ gradients < T extends Tensor > (
296+ f : ( ) => T , xs : Tensor [ ] , dy ?: T ,
297+ allowNoGradients = false ) : { value : T , grads : Tensor [ ] } {
297298 return tidy ( 'gradients' , ( ) => {
298299 const y = f ( ) ;
300+ util . assert (
301+ y instanceof Tensor ,
302+ 'The result y returned by f() must be a tensor.' ) ;
299303 // Filter out the nodes that don't connect x => y.
300304 const filteredTape =
301305 tape_util . getFilteredNodesXToY ( this . activeTape , xs , y ) ;
302- if ( filteredTape . length === 0 && xs . length > 0 ) {
306+ if ( ! allowNoGradients && filteredTape . length === 0 && xs . length > 0 ) {
303307 throw new Error (
304- ` Cannot compute gradient: y is not a function of \`x\`s. ` +
305- `Make sure the xs you are computing gradients with respect ` +
306- ` to are used inside the gradient function.` ) ;
308+ ' Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
309+ 'that the f you passed encloses all operations that lead from x ' +
310+ ' to y.' ) ;
307311 }
308312
309313 const accumulatedGradientMap : { [ tensorId : number ] : Tensor } = { } ;
@@ -319,21 +323,50 @@ export class Engine implements TensorManager {
319323
320324 customGrad < T extends Tensor > ( f : CustomGradientFunc < T > ) :
321325 ( ...args : Tensor [ ] ) => T {
322- this . customGradientDepth ++ ;
323-
326+ util . assert (
327+ util . isFunction ( f ) ,
328+ 'The f passed in customGrad(f) must be a function.' ) ;
324329 return ( ...inputs : Tensor [ ] ) : T => {
325- let gradientsFunc : ( dy : T ) => Tensor [ ] ;
330+ util . assert (
331+ inputs . every ( t => t instanceof Tensor ) ,
332+ 'The args passed in customGrad(f)(x1, x2,...) must all be tensors' ) ;
333+ this . customGradientDepth ++ ;
334+
335+ let gradientsFunc : ( dy : T ) => Tensor | Tensor [ ] ;
326336 const gradientsMode = true ;
327337 const result = tidy ( f . name , ( ) => {
328338 const { value, gradFunc} = f ( ...inputs ) ;
339+ util . assert (
340+ value instanceof Tensor ,
341+ 'The function f passed in customGrad(f) must return an object ' +
342+ 'where `obj.value` is a tensor' ) ;
343+ util . assert (
344+ util . isFunction ( gradFunc ) ,
345+ 'The function f passed in customGrad(f) must return an object ' +
346+ 'where `obj.gradFunc` is a function.' ) ;
329347 gradientsFunc = gradFunc ;
330348 return value ;
331349 } , gradientsMode ) ;
332350
333351 this . customGradientDepth -- ;
334352
335353 if ( this . shouldRecord ( ) ) {
336- this . addTapeNode ( inputs , result , gradientsFunc ) ;
354+ const gradFunc = ( dy : T ) : Tensor [ ] => {
355+ const res = gradientsFunc ( dy ) ;
356+ const grads : Tensor [ ] = Array . isArray ( res ) ? res : [ res ] ;
357+ util . assert (
358+ grads . length === inputs . length ,
359+ 'The function f passed in customGrad(f) must return an object ' +
360+ 'where `obj.gradFunc` is a function that returns the same ' +
361+ 'number of tensors as inputs passed to f(...).' ) ;
362+ util . assert (
363+ grads . every ( t => t instanceof Tensor ) ,
364+ 'The function f passed in customGrad(f) must return an object ' +
365+ 'where `obj.gradFunc` is a function that returns a list of ' +
366+ 'only tensors.' ) ;
367+ return grads ;
368+ } ;
369+ this . addTapeNode ( inputs , result , gradFunc ) ;
337370 }
338371 return result ;
339372 } ;
0 commit comments