@@ -29,6 +29,8 @@ pub enum DiffMode {
2929 Forward ,
3030 /// The target function, to be created using reverse mode AD.
3131 Reverse ,
32+ /// The target function, to be created using batching.
33+ Batch ,
3234}
3335
3436/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
@@ -69,6 +71,12 @@ pub enum DiffActivity {
6971 /// length of a slice/vec. This is used for safety checks on slices.
7072 /// The integer (if given) specifies the size of the slice element in bytes.
7173 FakeActivitySize ( Option < u32 > ) ,
74+ /// Batching mode A
75+ Vector ,
76+ /// Batching mode B, missing implementation (only available as part of autodiff through dupv)
77+ // Leaf,
78+ /// Batching mode C, scalar.
79+ Scalar ,
7280}
7381
7482impl DiffActivity {
@@ -130,6 +138,7 @@ impl Display for DiffMode {
130138 DiffMode :: Source => write ! ( f, "Source" ) ,
131139 DiffMode :: Forward => write ! ( f, "Forward" ) ,
132140 DiffMode :: Reverse => write ! ( f, "Reverse" ) ,
141+ DiffMode :: Batch => write ! ( f, "Batch" ) ,
133142 }
134143 }
135144}
@@ -153,6 +162,13 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
153162 || activity == DiffActivity :: Active
154163 || activity == DiffActivity :: ActiveOnly
155164 }
165+ DiffMode :: Batch => {
166+ // Batching is a special case, since we don't compute derivatives wrt. the return value.
167+ // We just compute derivatives wrt. the inputs, so we can ignore the return value.
168+ activity == DiffActivity :: Const
169+ || activity == DiffActivity :: Vector
170+ || activity == DiffActivity :: Scalar
171+ }
156172 }
157173}
158174
@@ -186,6 +202,11 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
186202 DiffMode :: Reverse => {
187203 matches ! ( activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const )
188204 }
205+ DiffMode :: Batch => {
206+ // Batching is a special case, since we don't compute derivatives wrt. the return value.
207+ // We just compute derivatives wrt. the inputs, so we can ignore the return value.
208+ matches ! ( activity, Const | Vector )
209+ }
189210 } ;
190211}
191212
@@ -203,6 +224,8 @@ impl Display for DiffActivity {
203224 DiffActivity :: Duplicated => write ! ( f, "Duplicated" ) ,
204225 DiffActivity :: DuplicatedOnly => write ! ( f, "DuplicatedOnly" ) ,
205226 DiffActivity :: FakeActivitySize ( s) => write ! ( f, "FakeActivitySize({:?})" , s) ,
227+ DiffActivity :: Vector => write ! ( f, "Vector" ) ,
228+ DiffActivity :: Scalar => write ! ( f, "Scalar" ) ,
206229 }
207230 }
208231}
@@ -216,6 +239,7 @@ impl FromStr for DiffMode {
216239 "Source" => Ok ( DiffMode :: Source ) ,
217240 "Forward" => Ok ( DiffMode :: Forward ) ,
218241 "Reverse" => Ok ( DiffMode :: Reverse ) ,
242+ "Batch" => Ok ( DiffMode :: Batch ) ,
219243 _ => Err ( ( ) ) ,
220244 }
221245 }
@@ -235,6 +259,8 @@ impl FromStr for DiffActivity {
235259 "DualvOnly" => Ok ( DiffActivity :: DualvOnly ) ,
236260 "Duplicated" => Ok ( DiffActivity :: Duplicated ) ,
237261 "DuplicatedOnly" => Ok ( DiffActivity :: DuplicatedOnly ) ,
262+ "Scalar" => Ok ( DiffActivity :: Scalar ) ,
263+ "Vector" => Ok ( DiffActivity :: Vector ) ,
238264 _ => Err ( ( ) ) ,
239265 }
240266 }
0 commit comments