|
| 1 | +use std::os::raw::{c_char, c_uint}; |
1 | 2 | use std::ptr;
|
2 | 3 |
|
3 | 4 | use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
|
| 5 | +use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree}; |
4 | 6 | use rustc_codegen_ssa::ModuleCodegen;
|
5 | 7 | use rustc_codegen_ssa::common::TypeKind;
|
6 | 8 | use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
|
@@ -512,3 +514,128 @@ pub(crate) fn differentiate<'ll>(
|
512 | 514 |
|
513 | 515 | Ok(())
|
514 | 516 | }
|
| 517 | + |
| 518 | +/// Converts a Rust TypeTree to Enzyme's internal TypeTree format |
| 519 | +/// |
| 520 | +/// This function takes a Rust-side TypeTree (from rustc_ast::expand::typetree) |
| 521 | +/// and converts it to Enzyme's internal C++ TypeTree representation that |
| 522 | +/// Enzyme can understand during differentiation analysis. |
| 523 | +fn to_enzyme_typetree( |
| 524 | + rust_typetree: RustTypeTree, |
| 525 | + data_layout: &str, |
| 526 | + llcx: &llvm::Context, |
| 527 | +) -> llvm::TypeTree { |
| 528 | + // Start with an empty TypeTree |
| 529 | + let mut enzyme_tt = llvm::TypeTree::new(); |
| 530 | + |
| 531 | + // Convert each Type in the Rust TypeTree to Enzyme format |
| 532 | + for rust_type in rust_typetree.0 { |
| 533 | + let concrete_type = match rust_type.kind { |
| 534 | + rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything, |
| 535 | + rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer, |
| 536 | + rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer, |
| 537 | + rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half, |
| 538 | + rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float, |
| 539 | + rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double, |
| 540 | + rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown, |
| 541 | + }; |
| 542 | + |
| 543 | + // Create a TypeTree for this specific type |
| 544 | + let type_tt = llvm::TypeTree::from_type(concrete_type, llcx); |
| 545 | + |
| 546 | + // Apply offset if specified |
| 547 | + let type_tt = if rust_type.offset == -1 { |
| 548 | + type_tt // -1 means everywhere/no specific offset |
| 549 | + } else { |
| 550 | + // Apply specific offset positioning |
| 551 | + type_tt.shift(data_layout, rust_type.offset, rust_type.size as isize, 0) |
| 552 | + }; |
| 553 | + |
| 554 | + // Merge this type into the main TypeTree |
| 555 | + enzyme_tt = enzyme_tt.merge(type_tt); |
| 556 | + } |
| 557 | + |
| 558 | + enzyme_tt |
| 559 | +} |
| 560 | + |
| 561 | +/// Attaches TypeTree information to LLVM function as enzyme_type attributes. |
| 562 | +/// |
| 563 | +/// This function converts Rust TypeTrees to Enzyme format and attaches them as |
| 564 | +/// LLVM string attributes. Enzyme reads these attributes during autodiff analysis |
| 565 | +/// to understand the memory layout and generate correct derivative code. |
| 566 | +/// |
| 567 | +/// # Arguments |
| 568 | +/// * `llmod` - LLVM module containing the function |
| 569 | +/// * `llcx` - LLVM context for creating attributes |
| 570 | +/// * `fn_def` - LLVM function to attach TypeTrees to |
| 571 | +/// * `tt` - Function TypeTree containing input and return type information |
| 572 | +pub(crate) fn add_tt<'ll>( |
| 573 | + llmod: &'ll llvm::Module, |
| 574 | + llcx: &'ll llvm::Context, |
| 575 | + fn_def: &'ll Value, |
| 576 | + tt: FncTree, |
| 577 | +) { |
| 578 | + let inputs = tt.args; |
| 579 | + let ret_tt: RustTypeTree = tt.ret; |
| 580 | + |
| 581 | + // Get LLVM data layout string for TypeTree conversion |
| 582 | + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; |
| 583 | + let llvm_data_layout = |
| 584 | + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) |
| 585 | + .expect("got a non-UTF8 data-layout from LLVM"); |
| 586 | + |
| 587 | + // Attribute name that Enzyme recognizes for TypeTree information |
| 588 | + let attr_name = "enzyme_type"; |
| 589 | + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); |
| 590 | + |
| 591 | + // Attach TypeTree attributes to each input parameter |
| 592 | + // Enzyme uses these to understand parameter memory layouts during differentiation |
| 593 | + for (i, input) in inputs.iter().enumerate() { |
| 594 | + unsafe { |
| 595 | + // Convert Rust TypeTree to Enzyme's internal format |
| 596 | + let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); |
| 597 | + |
| 598 | + // Serialize TypeTree to string format that Enzyme can parse |
| 599 | + let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); |
| 600 | + let c_str = std::ffi::CStr::from_ptr(c_str); |
| 601 | + |
| 602 | + // Create LLVM string attribute with TypeTree information |
| 603 | + let attr = llvm::LLVMCreateStringAttribute( |
| 604 | + llcx, |
| 605 | + c_attr_name.as_ptr(), |
| 606 | + c_attr_name.as_bytes().len() as c_uint, |
| 607 | + c_str.as_ptr(), |
| 608 | + c_str.to_bytes().len() as c_uint, |
| 609 | + ); |
| 610 | + |
| 611 | + // Attach attribute to the specific function parameter |
| 612 | + // Note: ArgumentPlace uses 0-based indexing, but LLVM uses 1-based for arguments |
| 613 | + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]); |
| 614 | + |
| 615 | + // Free the C string to prevent memory leaks |
| 616 | + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); |
| 617 | + } |
| 618 | + } |
| 619 | + |
| 620 | + // Attach TypeTree attribute to the return type |
| 621 | + // Enzyme needs this to understand how to handle return value derivatives |
| 622 | + unsafe { |
| 623 | + let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); |
| 624 | + let c_str = llvm::EnzymeTypeTreeToString(enzyme_tt.inner); |
| 625 | + let c_str = std::ffi::CStr::from_ptr(c_str); |
| 626 | + |
| 627 | + let ret_attr = llvm::LLVMCreateStringAttribute( |
| 628 | + llcx, |
| 629 | + c_attr_name.as_ptr(), |
| 630 | + c_attr_name.as_bytes().len() as c_uint, |
| 631 | + c_str.as_ptr(), |
| 632 | + c_str.to_bytes().len() as c_uint, |
| 633 | + ); |
| 634 | + |
| 635 | + // Attach to function return type |
| 636 | + attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]); |
| 637 | + |
| 638 | + // Free the C string |
| 639 | + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); |
| 640 | + } |
| 641 | +} |
0 commit comments