@@ -192,7 +192,6 @@ mod llvm_enzyme {
192
192
/// which becomes expanded to:
193
193
/// ```
194
194
/// #[rustc_autodiff]
195
- /// #[inline(never)]
196
195
/// fn sin(x: &Box<f32>) -> f32 {
197
196
/// f32::sin(**x)
198
197
/// }
@@ -371,7 +370,7 @@ mod llvm_enzyme {
371
370
let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
372
371
let inline_never = outer_normal_attr ( & inline_never_attr, new_id, span) ;
373
372
374
- // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never) ]`.
373
+ // We're avoid duplicating the attribute `#[rustc_autodiff]`.
375
374
fn same_attribute ( attr : & ast:: AttrKind , item : & ast:: AttrKind ) -> bool {
376
375
match ( attr, item) {
377
376
( ast:: AttrKind :: Normal ( a) , ast:: AttrKind :: Normal ( b) ) => {
@@ -384,23 +383,25 @@ mod llvm_enzyme {
384
383
}
385
384
}
386
385
386
+ let mut has_inline_never = false ;
387
+
387
388
// Don't add it multiple times:
388
389
let orig_annotatable: Annotatable = match item {
389
390
Annotatable :: Item ( ref mut iitem) => {
390
391
if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
391
392
iitem. attrs . push ( attr) ;
392
393
}
393
- if ! iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) ) {
394
- iitem . attrs . push ( inline_never . clone ( ) ) ;
394
+ if iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) ) {
395
+ has_inline_never = true ;
395
396
}
396
397
Annotatable :: Item ( iitem. clone ( ) )
397
398
}
398
399
Annotatable :: AssocItem ( ref mut assoc_item, i @ Impl { .. } ) => {
399
400
if !assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
400
401
assoc_item. attrs . push ( attr) ;
401
402
}
402
- if ! assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) ) {
403
- assoc_item . attrs . push ( inline_never . clone ( ) ) ;
403
+ if assoc_item. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) ) {
404
+ has_inline_never = true ;
404
405
}
405
406
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
406
407
}
@@ -410,9 +411,8 @@ mod llvm_enzyme {
410
411
if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
411
412
iitem. attrs . push ( attr) ;
412
413
}
413
- if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
414
- {
415
- iitem. attrs . push ( inline_never. clone ( ) ) ;
414
+ if iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) ) {
415
+ has_inline_never = true ;
416
416
}
417
417
}
418
418
_ => unreachable ! ( "stmt kind checked previously" ) ,
@@ -433,11 +433,19 @@ mod llvm_enzyme {
433
433
434
434
let new_id = ecx. sess . psess . attr_id_generator . mk_attr_id ( ) ;
435
435
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
436
+
437
+ // If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
438
+ let mut d_attrs = thin_vec ! [ d_attr] ;
439
+
440
+ if has_inline_never {
441
+ d_attrs. push ( inline_never) ;
442
+ }
443
+
436
444
let d_annotatable = match & item {
437
445
Annotatable :: AssocItem ( _, _) => {
438
446
let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( d_fn) ;
439
447
let d_fn = Box :: new ( ast:: AssocItem {
440
- attrs : thin_vec ! [ d_attr ] ,
448
+ attrs : d_attrs ,
441
449
id : ast:: DUMMY_NODE_ID ,
442
450
span,
443
451
vis,
@@ -447,13 +455,13 @@ mod llvm_enzyme {
447
455
Annotatable :: AssocItem ( d_fn, Impl { of_trait : false } )
448
456
}
449
457
Annotatable :: Item ( _) => {
450
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr ] , ItemKind :: Fn ( d_fn) ) ;
458
+ let mut d_fn = ecx. item ( span, d_attrs , ItemKind :: Fn ( d_fn) ) ;
451
459
d_fn. vis = vis;
452
460
453
461
Annotatable :: Item ( d_fn)
454
462
}
455
463
Annotatable :: Stmt ( _) => {
456
- let mut d_fn = ecx. item ( span, thin_vec ! [ d_attr ] , ItemKind :: Fn ( d_fn) ) ;
464
+ let mut d_fn = ecx. item ( span, d_attrs , ItemKind :: Fn ( d_fn) ) ;
457
465
d_fn. vis = vis;
458
466
459
467
Annotatable :: Stmt ( Box :: new ( ast:: Stmt {
0 commit comments