Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion src/mk_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! This module provides functionality to generate graph visualizations
//! of Rust's MIR in various formats (DOT, D2).

use std::collections::{HashMap, HashSet, VecDeque};
use std::fs::File;
use std::io::{self, Write};

Expand All @@ -12,7 +13,11 @@ use rustc_middle::ty::TyCtxt;
extern crate rustc_session;
use rustc_session::config::{OutFileName, OutputType};

use crate::printer::collect_smir;
extern crate stable_mir;
use stable_mir::mir::{ConstOperand, Operand, TerminatorKind};

use crate::printer::{collect_smir, Item};
use crate::MonoItemKind;

// Sub-modules
pub mod context;
Expand All @@ -25,6 +30,115 @@ pub use context::GraphContext;
pub use index::{AllocEntry, AllocIndex, AllocKind, TypeIndex};
pub use util::GraphLabelString;

// =============================================================================
// Lang Start Filtering
// =============================================================================

/// Report whether lang-start filtering was requested via the environment.
///
/// Returns `true` when the `SKIP_LANG_START` environment variable is set to
/// any value, including an empty or non-Unicode one. The environment is
/// consulted once, on the first call; later calls return the cached answer.
pub(crate) fn skip_lang_start() -> bool {
    use std::sync::OnceLock;
    static SKIP: OnceLock<bool> = OnceLock::new();
    // `var_os` is a pure presence check. `var(..).is_ok()` would report a
    // variable holding non-Unicode bytes as if it were unset.
    *SKIP.get_or_init(|| std::env::var_os("SKIP_LANG_START").is_some())
}

/// Compute the set of symbol names to exclude from graph rendering.
///
/// Excludes `std::rt::lang_start` items and items uniquely downstream
/// of them (i.e., only reachable through `lang_start` in the call graph).
///
/// The algorithm:
/// 1. Build a call graph from `Call` terminators whose callee is a constant
///    that resolves through `ctx.functions`
/// 2. Identify `std::rt::lang_start` seed items (via demangled name of MonoItemFn)
/// 3. Find entry-point items (not called by any *other* item; a self-call
///    does not stop an item from being an entry point)
/// 4. BFS from non-seed entry points, not entering seed nodes
/// 5. Everything not reachable gets excluded
///
/// The returned names are drawn from the items' symbol names and from the
/// values of `ctx.functions`.
///
/// NOTE(review): a group of mutually recursive items with no caller outside
/// the group still has no entry point and is therefore excluded.
pub(crate) fn compute_lang_start_exclusions(items: &[Item], ctx: &GraphContext) -> HashSet<String> {
    // Forward call graph: symbol name of a function item with a body -> the
    // names of the functions it calls directly.
    let mut call_graph: HashMap<&str, Vec<&str>> = HashMap::new();
    for item in items {
        let MonoItemKind::MonoItemFn {
            body: Some(body), ..
        } = &item.mono_item_kind
        else {
            continue;
        };
        let callees: Vec<&str> = body
            .blocks
            .iter()
            .filter_map(|block| match &block.terminator.kind {
                TerminatorKind::Call {
                    func: Operand::Constant(ConstOperand { const_, .. }),
                    ..
                } => ctx.functions.get(&const_.ty()).map(|s| s.as_str()),
                // Non-call terminators and calls through non-constant
                // operands contribute no edge.
                _ => None,
            })
            .collect();
        call_graph.insert(&item.symbol_name, callees);
    }

    // Seed items: those whose MonoItemFn name contains "std::rt::lang_start".
    let seed_names: HashSet<&str> = items
        .iter()
        .filter(|item| is_std_rt_lang_start(&item.mono_item_kind))
        .map(|item| item.symbol_name.as_str())
        .collect();

    // Names called by some *other* item. Self-edges are ignored so that a
    // directly recursive root is still recognised as an entry point instead
    // of being excluded together with everything it reaches.
    let has_callers: HashSet<&str> = call_graph
        .iter()
        .flat_map(|(&caller, callees)| {
            callees
                .iter()
                .copied()
                .filter(move |&callee| callee != caller)
        })
        .collect();

    // BFS from the non-seed entry points (items with no external callers).
    let mut reachable: HashSet<&str> = HashSet::new();
    let mut queue: VecDeque<&str> = VecDeque::new();

    for item in items {
        let name = item.symbol_name.as_str();
        let is_entry = !has_callers.contains(name);
        // `insert` returns false for a name that is already queued.
        if is_entry && !seed_names.contains(name) && reachable.insert(name) {
            queue.push_back(name);
        }
    }

    while let Some(current) = queue.pop_front() {
        let Some(callees) = call_graph.get(current) else {
            continue;
        };
        for &callee in callees {
            // Never step into a seed; visit each remaining node once.
            if !seed_names.contains(callee) && reachable.insert(callee) {
                queue.push_back(callee);
            }
        }
    }

    // Everything NOT reachable is excluded, including names known only
    // through `ctx.functions` (functions without an item of their own).
    items
        .iter()
        .map(|item| item.symbol_name.as_str())
        .chain(ctx.functions.values().map(|s| s.as_str()))
        .filter(|name| !reachable.contains(name))
        .map(str::to_owned)
        .collect()
}

/// Tell whether a mono item is a function whose name mentions
/// `std::rt::lang_start`.
///
/// Because this is a substring test on the `MonoItemFn` name, it matches:
/// - `std::rt::lang_start::<()>` (the runtime entry point)
/// - `std::rt::lang_start::<()>::{closure#0}` (its closure)
/// - `<{closure@std::rt::lang_start<()>::{closure#0}} as ...>::call_once` (trait impls referencing it)
/// - `std::ptr::drop_in_place::<{closure@std::rt::lang_start<()>::{closure#0}}>` (drop glue)
///
/// It does not match a user-defined `lang_start` such as
/// `crate1::something::lang_start`. Every non-function item yields `false`.
fn is_std_rt_lang_start(kind: &MonoItemKind) -> bool {
    matches!(
        kind,
        MonoItemKind::MonoItemFn { name, .. } if name.contains("std::rt::lang_start")
    )
}

// =============================================================================
// Entry Points
// =============================================================================
Expand Down
23 changes: 19 additions & 4 deletions src/mk_graph/output/d2.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! D2 diagram format output for MIR graphs.

use std::collections::HashSet;

extern crate stable_mir;
use stable_mir::mir::TerminatorKind;

Expand All @@ -10,11 +12,22 @@ use crate::mk_graph::context::GraphContext;
use crate::mk_graph::util::{
escape_d2, is_unqualified, name_lines, short_name, terminator_targets,
};
use crate::mk_graph::{compute_lang_start_exclusions, skip_lang_start};

impl SmirJson<'_> {
/// Convert the MIR to D2 diagram format
pub fn to_d2_file(self) -> String {
pub fn to_d2_file(mut self) -> String {
let ctx = GraphContext::from_smir(&self);

// Optionally filter out lang_start items and their unique descendants
let excluded: HashSet<String> = if skip_lang_start() {
let excluded = compute_lang_start_exclusions(&self.items, &ctx);
self.items.retain(|i| !excluded.contains(&i.symbol_name));
excluded
} else {
HashSet::new()
};
Comment on lines +23 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a useful filter op. I recommend moving it; could be a free fn in the mk_graph mod.


let mut output = String::new();

output.push_str("direction: right\n\n");
Expand All @@ -23,7 +36,7 @@ impl SmirJson<'_> {
for item in self.items {
match item.mono_item_kind {
MonoItemKind::MonoItemFn { name, body, .. } => {
render_d2_function(&name, body.as_ref(), &ctx, &mut output);
render_d2_function(&name, body.as_ref(), &ctx, &excluded, &mut output);
}
MonoItemKind::MonoItemGlobalAsm { asm } => {
render_d2_asm(&asm, &mut output);
Expand Down Expand Up @@ -61,6 +74,7 @@ fn render_d2_function(
name: &str,
body: Option<&stable_mir::mir::Body>,
ctx: &GraphContext,
excluded: &HashSet<String>,
out: &mut String,
) {
let fn_id = short_name(name);
Expand All @@ -80,7 +94,7 @@ fn render_d2_function(

// Call edges (must be outside the container)
if let Some(body) = body {
render_d2_call_edges(&fn_id, body, ctx, out);
render_d2_call_edges(&fn_id, body, ctx, excluded, out);
}
}

Expand Down Expand Up @@ -115,6 +129,7 @@ fn render_d2_call_edges(
fn_id: &str,
body: &stable_mir::mir::Body,
ctx: &GraphContext,
excluded: &HashSet<String>,
out: &mut String,
) {
for (idx, block) in body.blocks.iter().enumerate() {
Expand All @@ -124,7 +139,7 @@ fn render_d2_call_edges(
let Some(callee_name) = ctx.resolve_call_target(func) else {
continue;
};
if !is_unqualified(&callee_name) {
if !is_unqualified(&callee_name) || excluded.contains(&callee_name) {
continue;
}
Comment on lines 139 to 144
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Want to point out that we have to thread excluded into this function to filter a target. Seems like a smell -- not asking for a change :)

Right now the flow is:

  1. Build ctx: populates ctx.functions with everything
  2. Compute excluded
  3. self.items.retain(...) to remove excluded items
  4. Thread excluded through every render function to check at call-edge time

If instead, step 3 also did:

ctx.functions.retain(|_, name| !excluded.contains(name));

then ctx.resolve_call_target() would return None for excluded functions, and the existing early-return in D2 handles it already:

let Some(callee_name) = ctx.resolve_call_target(func) else {
  continue;  // already skips when resolve returns None
};
// no need for: || excluded.contains(&callee_name)
if !is_unqualified(&callee_name) {
  continue;
}

That would also eliminate the need to compute (locally) and thread &excluded entirely. NOTE: This would likely need the dot renderer to be factored.


Expand Down
28 changes: 26 additions & 2 deletions src/mk_graph/output/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,25 @@ use crate::MonoItemKind;

use crate::mk_graph::context::GraphContext;
use crate::mk_graph::util::{block_name, is_unqualified, name_lines, short_name, GraphLabelString};
use crate::mk_graph::{compute_lang_start_exclusions, skip_lang_start};

impl SmirJson<'_> {
/// Convert the MIR to DOT (Graphviz) format
pub fn to_dot_file(self) -> String {
pub fn to_dot_file(mut self) -> String {
let mut bytes = Vec::new();

// Build context BEFORE consuming self
let ctx = GraphContext::from_smir(&self);

// Optionally filter out lang_start items and their unique descendants
let excluded: HashSet<String> = if skip_lang_start() {
let excluded = compute_lang_start_exclusions(&self.items, &ctx);
self.items.retain(|i| !excluded.contains(&i.symbol_name));
excluded
} else {
HashSet::new()
};
Comment on lines +26 to +32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha! Good ole copy-and-paste inheritance (or composition in this case). This supports the filter refactoring.


{
let mut writer = DotWriter::from(&mut bytes);

Expand Down Expand Up @@ -57,7 +67,7 @@ impl SmirJson<'_> {

// first create all nodes for functions not in the items list
for f in ctx.functions.values() {
if !item_names.contains(f) {
if !item_names.contains(f) && !excluded.contains(f) {
graph
.node_named(block_name(f, 0))
.set_label(&name_lines(f))
Expand Down Expand Up @@ -245,6 +255,20 @@ impl SmirJson<'_> {

match &b.terminator.kind {
TerminatorKind::Call { func, args, .. } => {
// Skip call edges to excluded nodes
if let Operand::Constant(ConstOperand {
const_, ..
}) = func
{
if ctx
.functions
.get(&const_.ty())
.is_some_and(|c| excluded.contains(c))
{
continue;
}
}

let e = match func {
Operand::Constant(ConstOperand {
const_, ..
Expand Down