|
6 | 6 | use std::fmt::{self, Display, Formatter};
|
7 | 7 | use std::str::FromStr;
|
8 | 8 |
|
| 9 | +use crate::expand::typetree::TypeTree; |
9 | 10 | use crate::expand::{Decodable, Encodable, HashStable_Generic};
|
10 | 11 | use crate::ptr::P;
|
11 | 12 | use crate::{Ty, TyKind};
|
@@ -85,6 +86,9 @@ pub struct AutoDiffItem {
|
85 | 86 | /// The name of the function being generated
|
86 | 87 | pub target: String,
|
87 | 88 | pub attrs: AutoDiffAttrs,
|
| 89 | + // Type Tree support |
| 90 | + pub inputs: Vec<TypeTree>, |
| 91 | + pub output: TypeTree, |
88 | 92 | }
|
89 | 93 |
|
90 | 94 | #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
|
@@ -276,14 +280,23 @@ impl AutoDiffAttrs {
|
276 | 280 | !matches!(self.mode, DiffMode::Error | DiffMode::Source)
|
277 | 281 | }
|
278 | 282 |
|
279 |
| - pub fn into_item(self, source: String, target: String) -> AutoDiffItem { |
280 |
| - AutoDiffItem { source, target, attrs: self } |
| 283 | + pub fn into_item( |
| 284 | + self, |
| 285 | + source: String, |
| 286 | + target: String, |
| 287 | + inputs: Vec<TypeTree>, |
| 288 | + output: TypeTree, |
| 289 | + ) -> AutoDiffItem { |
| 290 | + AutoDiffItem { source, target, inputs, output, attrs: self } |
281 | 291 | }
|
282 | 292 | }
|
283 | 293 |
|
284 | 294 | impl fmt::Display for AutoDiffItem {
|
285 | 295 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
286 | 296 | write!(f, "Differentiating {} -> {}", self.source, self.target)?;
|
287 |
| - write!(f, " with attributes: {:?}", self.attrs) |
| 297 | + write!(f, " with attributes: {:?}", self.attrs)?; |
| 298 | + write!(f, " with attributes: {:?}", self.attrs)?; |
| 299 | + write!(f, " with inputs: {:?}", self.inputs)?; |
| 300 | + write!(f, " with output: {:?}", self.output) |
288 | 301 | }
|
289 | 302 | }
|
0 commit comments