|
6 | 6 | use std::fmt::{self, Display, Formatter};
|
7 | 7 | use std::str::FromStr;
|
8 | 8 |
|
| 9 | +use crate::expand::typetree::TypeTree; |
9 | 10 | use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
10 | 11 | use crate::{Ty, TyKind};
|
11 | 12 |
|
@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
|
84 | 85 | /// The name of the function being generated
|
85 | 86 | pub target: String,
|
86 | 87 | pub attrs: AutoDiffAttrs,
|
| 88 | + pub inputs: Vec<TypeTree>, |
| 89 | + pub output: TypeTree, |
87 | 90 | }
|
88 | 91 |
|
89 | 92 | #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
|
275 | 278 | !matches!(self.mode, DiffMode::Error | DiffMode::Source)
|
276 | 279 | }
|
277 | 280 |
|
278 |
| - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { |
279 |
| - AutoDiffItem { source, target, attrs: self } |
| 281 | + pub fn into_item( |
| 282 | + self, |
| 283 | + source: String, |
| 284 | + target: String, |
| 285 | + inputs: Vec<TypeTree>, |
| 286 | + output: TypeTree, |
| 287 | + ) -> AutoDiffItem { |
| 288 | + AutoDiffItem { source, target, inputs, output, attrs: self } |
280 | 289 | }
|
281 | 290 | }
|
282 | 291 |
|
283 | 292 | impl fmt::Display for AutoDiffItem {
|
284 | 293 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
285 | 294 | write!(f, "Differentiating {} -> {}", self.source, self.target)?;
|
286 |
| - write!(f, " with attributes: {:?}", self.attrs) |
| 295 | + write!(f, " with attributes: {:?}", self.attrs)?; |
| 296 | + write!(f, " with inputs: {:?}", self.inputs)?; |
| 297 | + write!(f, " with output: {:?}", self.output) |
287 | 298 | }
|
288 | 299 | }
|
0 commit comments