Skip to content

Commit 04f64af

Browse files
committed
Get args from tuple using fnabi and minor fixes
1 parent f2f116a commit 04f64af

File tree

3 files changed

+56
-36
lines changed

3 files changed

+56
-36
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ mod llvm_enzyme {
1616
use rustc_ast::tokenstream::*;
1717
use rustc_ast::visit::AssocCtxt::*;
1818
use rustc_ast::{
19-
self as ast, AngleBracketedArg, AngleBracketedArgs, AssocItemKind, BindingMode, FnRetTy,
20-
FnSig, GenericArg, GenericArgs, Generics, ItemKind, MetaItemInner, PatKind, Path,
21-
PathSegment, TyKind, Visibility,
19+
self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
20+
FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
21+
MetaItemInner, PatKind, Path, PathSegment, TyKind, Visibility,
2222
};
2323
use rustc_expand::base::{Annotatable, ExtCtxt};
2424
use rustc_span::{Ident, Span, Symbol, kw, sym};
@@ -554,10 +554,18 @@ mod llvm_enzyme {
554554
let generic_args = generics
555555
.params
556556
.iter()
557-
.map(|p| {
558-
let path = ast::Path::from_ident(p.ident);
559-
let ty = ecx.ty_path(path);
560-
AngleBracketedArg::Arg(GenericArg::Type(ty))
557+
.filter_map(|p| match &p.kind {
558+
GenericParamKind::Type { .. } => {
559+
let path = ast::Path::from_ident(p.ident);
560+
let ty = ecx.ty_path(path);
561+
Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
562+
}
563+
GenericParamKind::Const { .. } => {
564+
let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
565+
let anon_const = AnonConst { id: ast::DUMMY_NODE_ID, value: expr };
566+
Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
567+
}
568+
GenericParamKind::Lifetime { .. } => None,
561569
})
562570
.collect::<ThinVec<_>>();
563571

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::common::TypeKind;
55
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
6-
use rustc_middle::{bug, ty};
76
use rustc_middle::ty::{PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
7+
use rustc_middle::{bug, ty};
88
use tracing::debug;
99

1010
use crate::builder::{Builder, PlaceRef, UNNAMED};

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 40 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
1717
use rustc_middle::{bug, span_bug};
1818
use rustc_span::{Span, Symbol, sym};
1919
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
20+
use rustc_target::callconv::PassMode;
2021
use rustc_target::spec::PanicStrategy;
2122
use tracing::debug;
2223

@@ -1136,8 +1137,6 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11361137
let ret_ty = sig.output();
11371138
let llret_ty = bx.layout_of(ret_ty).llvm_type(bx);
11381139

1139-
let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2]);
1140-
11411140
// Get source, diff, and attrs
11421141
let (source_id, source_args) = match fn_args.into_type_list(tcx)[0].kind() {
11431142
ty::FnDef(def_id, source_params) => (def_id, source_params),
@@ -1155,6 +1154,7 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11551154
};
11561155
let fn_diff =
11571156
Instance::try_resolve(tcx, bx.cx.typing_env(), *diff_id, diff_args).unwrap().unwrap();
1157+
let val_arr: Vec<&'ll Value> = get_args_from_tuple(bx, args[2], fn_diff);
11581158
let diff_symbol = symbol_name_for_instance_in_crate(tcx, fn_diff.clone(), LOCAL_CRATE);
11591159

11601160
let diff_attrs = autodiff_attrs(tcx, fn_diff.def_id());
@@ -1181,39 +1181,51 @@ fn codegen_enzyme_autodiff<'ll, 'tcx>(
11811181

11821182
fn get_args_from_tuple<'ll, 'tcx>(
11831183
bx: &mut Builder<'_, 'll, 'tcx>,
1184-
op: OperandRef<'tcx, &'ll Value>,
1184+
tuple_op: OperandRef<'tcx, &'ll Value>,
1185+
fn_instance: Instance<'tcx>,
11851186
) -> Vec<&'ll Value> {
1186-
match op.val {
1187-
OperandValue::Ref(ref place_value) => {
1188-
let mut ret_arr = vec![];
1189-
let tuple_place = PlaceRef { val: *place_value, layout: op.layout };
1190-
1191-
for i in 0..tuple_place.layout.layout.0.fields.count() {
1192-
let field_place = tuple_place.project_field(bx, i);
1193-
let field_layout = tuple_place.layout.field(bx, i);
1194-
let field_ty = field_layout.ty;
1195-
let llvm_ty = field_layout.llvm_type(bx.cx);
1196-
1197-
let field_val = bx.load(llvm_ty, field_place.val.llval, field_place.val.align);
1198-
1199-
match field_ty.kind() {
1200-
ty::Ref(_, inner_ty, _) if matches!(inner_ty.kind(), ty::Slice(_)) => {
1201-
let ptr = bx.extract_value(field_val, 0);
1202-
let len = bx.extract_value(field_val, 1);
1203-
ret_arr.push(ptr);
1204-
ret_arr.push(len);
1187+
let cx = bx.cx;
1188+
let fn_abi = cx.fn_abi_of_instance(fn_instance, ty::List::empty());
1189+
1190+
match tuple_op.val {
1191+
OperandValue::Immediate(val) => vec![val],
1192+
OperandValue::Pair(v1, v2) => vec![v1, v2],
1193+
OperandValue::Ref(ptr) => {
1194+
let tuple_place = PlaceRef { val: ptr, layout: tuple_op.layout };
1195+
1196+
let mut result = Vec::with_capacity(fn_abi.args.len());
1197+
let mut tuple_index = 0;
1198+
1199+
for arg in &fn_abi.args {
1200+
match arg.mode {
1201+
PassMode::Ignore => {}
1202+
PassMode::Direct(_) | PassMode::Cast { .. } => {
1203+
let field = tuple_place.project_field(bx, tuple_index);
1204+
let llvm_ty = field.layout.llvm_type(bx.cx);
1205+
let val = bx.load(llvm_ty, field.val.llval, field.val.align);
1206+
result.push(val);
1207+
tuple_index += 1;
12051208
}
1206-
_ => {
1207-
ret_arr.push(field_val);
1209+
PassMode::Pair(_, _) => {
1210+
let field = tuple_place.project_field(bx, tuple_index);
1211+
let llvm_ty = field.layout.llvm_type(bx.cx);
1212+
let pair_val = bx.load(llvm_ty, field.val.llval, field.val.align);
1213+
result.push(bx.extract_value(pair_val, 0));
1214+
result.push(bx.extract_value(pair_val, 1));
1215+
tuple_index += 1;
1216+
}
1217+
PassMode::Indirect { .. } => {
1218+
let field = tuple_place.project_field(bx, tuple_index);
1219+
result.push(field.val.llval);
1220+
tuple_index += 1;
12081221
}
12091222
}
12101223
}
12111224

1212-
ret_arr
1225+
result
12131226
}
1214-
OperandValue::Pair(v1, v2) => vec![v1, v2],
1215-
OperandValue::Immediate(v) => vec![v],
1216-
OperandValue::ZeroSized => bug!("unexpected `ZeroSized` arg"),
1227+
1228+
OperandValue::ZeroSized => bug!("unexpected ZeroSized argument in get_args_from_tuple"),
12171229
}
12181230
}
12191231

0 commit comments

Comments
 (0)