@@ -3,7 +3,8 @@ use std::ptr;
3
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
4
4
use rustc_codegen_ssa:: common:: TypeKind ;
5
5
use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
6
- use rustc_middle:: bug;
6
+ use rustc_middle:: { bug, ty} ;
7
+ use rustc_middle:: ty:: { PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
7
8
use tracing:: debug;
8
9
9
10
use crate :: builder:: { Builder , PlaceRef , UNNAMED } ;
@@ -14,6 +15,82 @@ use crate::llvm::{Metadata, True, Type};
14
15
use crate :: value:: Value ;
15
16
use crate :: { attributes, llvm} ;
16
17
18
+ pub ( crate ) fn adjust_activity_to_abi < ' tcx > (
19
+ tcx : TyCtxt < ' tcx > ,
20
+ fn_ty : Ty < ' tcx > ,
21
+ da : & mut Vec < DiffActivity > ,
22
+ ) {
23
+ if !matches ! ( fn_ty. kind( ) , ty:: FnDef ( ..) ) {
24
+ bug ! ( "expected fn def for autodiff, got {:?}" , fn_ty) ;
25
+ }
26
+
27
+ // We don't actually pass the types back into the type system.
28
+ // All we do is decide how to handle the arguments.
29
+ let sig = fn_ty. fn_sig ( tcx) . skip_binder ( ) ;
30
+
31
+ let mut new_activities = vec ! [ ] ;
32
+ let mut new_positions = vec ! [ ] ;
33
+ for ( i, ty) in sig. inputs ( ) . iter ( ) . enumerate ( ) {
34
+ if let Some ( inner_ty) = ty. builtin_deref ( true ) {
35
+ if inner_ty. is_slice ( ) {
36
+ // Now we need to figure out the size of each slice element in memory to allow
37
+ // safety checks and usability improvements in the backend.
38
+ let sty = match inner_ty. builtin_index ( ) {
39
+ Some ( sty) => sty,
40
+ None => {
41
+ panic ! ( "slice element type unknown" ) ;
42
+ }
43
+ } ;
44
+ let pci = PseudoCanonicalInput {
45
+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
46
+ value : sty,
47
+ } ;
48
+
49
+ let layout = tcx. layout_of ( pci) ;
50
+ let elem_size = match layout {
51
+ Ok ( layout) => layout. size ,
52
+ Err ( _) => {
53
+ bug ! ( "autodiff failed to compute slice element size" ) ;
54
+ }
55
+ } ;
56
+ let elem_size: u32 = elem_size. bytes ( ) as u32 ;
57
+
58
+ // We know that the length will be passed as extra arg.
59
+ if !da. is_empty ( ) {
60
+ // We are looking at a slice. The length of that slice will become an
61
+ // extra integer on llvm level. Integers are always const.
62
+ // However, if the slice get's duplicated, we want to know to later check the
63
+ // size. So we mark the new size argument as FakeActivitySize.
64
+ // There is one FakeActivitySize per slice, so for convenience we store the
65
+ // slice element size in bytes in it. We will use the size in the backend.
66
+ let activity = match da[ i] {
67
+ DiffActivity :: DualOnly
68
+ | DiffActivity :: Dual
69
+ | DiffActivity :: Dualv
70
+ | DiffActivity :: DuplicatedOnly
71
+ | DiffActivity :: Duplicated => {
72
+ DiffActivity :: FakeActivitySize ( Some ( elem_size) )
73
+ }
74
+ DiffActivity :: Const => DiffActivity :: Const ,
75
+ _ => bug ! ( "unexpected activity for ptr/ref" ) ,
76
+ } ;
77
+ new_activities. push ( activity) ;
78
+ new_positions. push ( i + 1 ) ;
79
+ }
80
+
81
+ continue ;
82
+ }
83
+ }
84
+ }
85
+ // now add the extra activities coming from slices
86
+ // Reverse order to not invalidate the indices
87
+ for _ in 0 ..new_activities. len ( ) {
88
+ let pos = new_positions. pop ( ) . unwrap ( ) ;
89
+ let activity = new_activities. pop ( ) . unwrap ( ) ;
90
+ da. insert ( pos, activity) ;
91
+ }
92
+ }
93
+
17
94
// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
18
95
// original inputs, as well as metadata and the additional shadow arguments.
19
96
// This function matches the arguments from the outer function to the inner enzyme call.
0 commit comments