[Observability 2/7] Structured system logging and gantt chart generation#2602

Open
felipemello1 wants to merge 1 commit into gh/felipemello1/18/base from gh/felipemello1/18/head

Conversation


@felipemello1 felipemello1 commented Mar 16, 2026

Stack from ghstack (oldest at bottom):

How to review

Prioritize the files under /observability and check how they are used in the toy example.
No need to nitpick the toy example. The real trainer implementation is in PRs 6 and 7.

Summary

This PR adds the system logging layer: every process writes structured JSONL to disk via Python's logging module.

Two APIs:

  • record_span(description, event_type) — context manager that writes START/END events with wall-clock duration. Used for timing training phases (forward/backward, optimizer, data loading, etc).
  • record_event({"key": value, ...}) — point-in-time scalar snapshots for per-rank diagnostics.

Example

from torchtitan.observability import (
    init_observability, set_step, add_step_tag,
    record_span, record_event, EventType,
)
from torchtitan.tools.logging import init_logger

# Console logging (stdout with [titan] format)
init_logger()

# JSONL file handlers for structured logging
init_observability(source="trainer", output_dir="./outputs")

for step in range(steps):
    # Stamp all subsequent JSONL entries with step=step
    set_step(step)

    if should_garbage_collect:
        add_step_tag("gc")
        with record_span("trainer_time/gc_s", EventType.GC_COLLECT):
            run_gc()

    with record_span("trainer_time/forward_backward_s", EventType.FWD_BWD):
        reduced_loss = model.fwd_bwd(batch)
        record_event({"loss": reduced_loss.item()})
# dump_dir/system_logs/trainer_rank_0_system.jsonl
{"int": {"step": 42, "rank": 0, "time_us": 1708200121724000},
 "normal": {"log_type_name": "fwd_bwd_end", "source": "trainer",
            "message": "[step 42] trainer_time/forward_backward_s fwd_bwd_end took 123.45 ms",
            "caller": "torchtitan/trainer.py:730:train_step"},
 "double": {"value": 123.45},
 "normvector": {"step_tags": ["gc"]}}

Output folder (from toy_spmd on this PR)

outputs/toy_spmd/
├── analysis/
│   └── system_metrics_gantt.json      ← generate_gantt_trace output
└── system_logs/
    ├── trainer_rank_0_system.jsonl
    ├── trainer_rank_1_system.jsonl
    ├── trainer_rank_2_system.jsonl
    └── trainer_rank_3_system.jsonl

Gantt chart

toy_rl: (gantt chart screenshot: gantt_rl)

toy_spmd: (gantt chart screenshot: ganttspmd)

Others

init_observability(source, output_dir) sets up the JSONL file handlers. Each rank gets its own file: {output_dir}/system_logs/{source}_rank_{rank}_system.jsonl.

The JSONL format uses four typed columns (int, normal, double, normvector) for easy ingestion into Grafana, DuckDB, etc.
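As a rough illustration of that scheme (a hypothetical helper, not the actual torchtitan code), a flat log dict can be split into the four typed columns by value type:

```python
import json
from typing import Any


def to_four_columns(log_dict: dict[str, Any]) -> str:
    """Split a flat dict into int/normal/double/normvector columns by value type."""
    cols: dict[str, dict[str, Any]] = {
        "int": {}, "normal": {}, "double": {}, "normvector": {}
    }
    for key, value in log_dict.items():
        if isinstance(value, bool):
            cols["int"][key] = int(value)  # bools land in the int column
        elif isinstance(value, int):
            cols["int"][key] = value
        elif isinstance(value, float):
            cols["double"][key] = value
        elif isinstance(value, (list, tuple)):
            cols["normvector"][key] = [str(v) for v in value]
        else:
            cols["normal"][key] = str(value)  # everything else is a string
    return json.dumps(cols)
```

For example, `to_four_columns({"step": 42, "loss": 1.5, "source": "trainer", "step_tags": ["gc"]})` sorts each key into its typed column, matching the shape of the sample row above.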

generate_gantt_trace(log_dir, output_path) reads all system JSONL files and produces a Chrome Trace JSON. Open in Perfetto to see a gantt chart of every record_span across all ranks.

Also includes:

  • EventType enum for categorizing spans,
  • set_step/add_step_tag for step context,
  • InflightEventTrackingHandler for crash forensics

Test plan

Run toy_spmd: python -m torch.distributed.run --nproc_per_node=4 -m torchtitan.experiments.observability.toy_spmd
Run toy_rl: python -m torchtitan.experiments.observability.toy_rl

  • 44 new unit tests for record_span, record_event, init_observability, gantt generation, step state
  • Integration: toy_spmd produces system JSONL with record_span start/end events
  • Integration: toy_rl produces system JSONL from 4 actor types
  • Integration: generate_gantt_trace produces Chrome Trace JSON from system JSONL

Console output (toy_spmd)

step: 1  loss: 3.65555  grad_norm: 0.49419
step: 5  loss: 3.05283  grad_norm: 0.40876
step: 10  loss: 2.53639  grad_norm: 0.31771
  val loss: 2.4611
step: 20  loss: 2.02626  grad_norm: 0.22093
  val loss: 1.9949
Chrome Trace: outputs/toy_spmd/analysis/system_metrics_gantt.json
  576 events from 4 sources

Console output (toy_rl)

step: 1  loss: 3.65156  grad_norm: 0.49804
step: 1  reward_mean: 1.00000
step: 3  loss: 3.32657  grad_norm: 0.44708
step: 3  reward_mean: 1.00000
step: 6  loss: 2.93530  grad_norm: 0.36988
step: 6  reward_mean: 1.00000
Chrome Trace: outputs/toy_rl/analysis/system_metrics_gantt.json
Done in 13.8s.

Sample of generated logs

{"int": {"tid": 1056731, "thread_time_ns": 3325664547, "rank": 0, "pid": 1056731, "time": 1773610642, "time_ms": 1773610642167, "time_us": 1773610642167603, "seq_id": 0}, "normal": {"source": "trainer", "host_name": "devgpu011.ldc3.facebook.com", "log_type": "event", "log_type_name": "build_model_start", "event_name": "setup/model_build", "caller": "torchtitan/experiments/observability/toy_spmd.py:168:__init__", "log_file": "toy_spmd.py", "log_function": "__init__", "log_level": "INFO", "logger_name": "torchtitan.observability.system", "message": "[step N/A] setup/model_build build_model_start"}, "double": {"delta_ms": 0.00051083043217659}, "normvector": {}}
{"int": {"tid": 1056731, "thread_time_ns": 3998473694, "rank": 0, "pid": 1056731, "time": 1773610642, "time_ms": 1773610642858, "time_us": 1773610642858849, "seq_id": 1}, "normal": {"source": "trainer", "host_name": "devgpu011.ldc3.facebook.com", "log_type": "event", "log_type_name": "build_model_end", "event_name": "setup/model_build", "caller": "torchtitan/experiments/observability/toy_spmd.py:168:__init__", "log_file": "toy_spmd.py", "log_function": "__init__", "log_level": "INFO", "logger_name": "torchtitan.observability.system", "message": "[step N/A] setup/model_build build_model_end took 691.24 ms"}, "double": {"delta_ms": 691.2433146499097, "value": 691.238438244909}, "normvector": {}}

Comment on lines +13 to +18
class MetricsProcessor:
"""Step context manager for the toy trainer.

Mirrors the method order of components/metrics.py MetricsProcessor
so the toy and production versions are easy to compare.
"""
Author

This class is incrementally built over PRs 1-5 (toy examples), so it's ready for PRs 6-7 (integration with the trainer).

_system_logger = logging.getLogger(SYSTEM_LOGGER_NAME)


def init_observability(source: str, output_dir: str, rank: int | None = None) -> None:
Author

This function (and maybe a couple of others) should probably not be here, since it's NOT specific to system metrics (it's also shared by experiment metrics). I can move it in a follow-up.

self.step = step
tokens, labels, loss_mask = next(data_iterator)
self.train_step(tokens, labels, loss_mask)
self.metrics_processor.set_step(step)
Contributor

When should we call set_step directly, vs. call it from MetricsProcessor?

Author
@felipemello1 felipemello1 Mar 18, 2026

MetricsProcessor calls set_step (and a few other things, like clear_tags). It's just a convenience.

Ideally, you call set_step at the start of every step. Take a look at the toy_rl example.

Member

It would be clearer if this functionality was pushed outside of Metrics Processor - otherwise, like @tianyu-l, I'm not sure when I would use set_step directly or use MetricProcessor.set_step.

dist.all_reduce(
loss_scalar, op=dist.ReduceOp.SUM, group=self.dp_mesh.get_group()
)
record_event(
Contributor

Shouldn't it be using record_metrics? Although I haven't internalized the difference between record_metrics and record_event.

Author

record_metric goes to metric logs (something you want to put in wandb). record_event goes to system logs (something that goes to Scuba). In your gantt, they are the arrows below. They can also act like counters, for example to check whether all ranks have reached a given spot, or anything else you would like to query later.

(screenshot: gantt chart with record_event markers)

Author

'shouldn't it be using record_metrics'

Added in the next PR. This PR is about record_span and record_event.

Member

IIUC, you actually wouldn't be logging the metrics here - you're just doing it to show it's possible but it would actually be handled by record_metric?

Otherwise, the conflation of the two is a little confusing.

Author

Yeah, I agree it's confusing. In later PRs Tianyu also questioned it and I gave an example of how sixlib does it. I think we can iron it out a bit.

return self.value


class EventType(StrEnum):
Contributor

How necessary is this? My worry is that we'd constantly need to add things to this enum. In the extreme case, it's hard to check whether something I want is already in the enum, so I would blindly add a new entry.

Why don't we just stick with whatever users give as strings? If users want, they can define Enums over the string values they have.

Author
@felipemello1 felipemello1 Mar 18, 2026

I was also debating this, and my resolution was to make it optional in record_span. So record_span("my_string") works on its own.

The reason I decided to keep the EventType is that if we want to compare runs or experiments, and there are N engineers working on different branches, there is no guarantee they will use the same event names. So imagine you build some postprocessing code to check that "FWD_BWD" meets some specification. Without the EventType, that algorithm is not reliable.

Do you think this is a strong enough argument to keep it, or would you remove it entirely?

Member

If anything, I would strip this down to the bare minimum for shared Enums (fwd_bwd, step, optim). The rest could be project specific. So when I go to setup my RL workflow, I can choose to include some common RL event types I want to track. Or I can ignore them entirely. The point is the granularity is up to me.



def to_structured_json(log_dict: dict[str, Any]) -> str:
"""Convert a log dict to 4-column JSON (int/normal/double/normvector)."""
Contributor

Could you share a specification of the naming? Seems there's some convention going on.

Author

Do you mean 'int/normal/double/normvector'? That's for Scuba.

Member

'That's for scuba'

Can you say more? Is this just a Scuba thing, or is it the standard for other visualization/logging tools as well?

Author

Yeah, Scuba expects these 4 columns. I do NOT know if it's a standard for other logging tools. I think we would have to learn that over time from user requests, or have a follow-up trying to integrate DuckDB or something like that. Also, I will probably need to enable handlers using env args, e.g.:

torchrun .... --enable_handler_scuba

so we can have Meta-specific handlers.

# ---------------------------------------------------------------------------


class EventsOnlyFilter(logging.Filter):
Contributor

Why do we need this? If everything is filtered, we could have required LOG_TYPE_NAME to be given when calling record_span.

Author
@felipemello1 felipemello1 Mar 18, 2026

It protects against user error. If someone, for some reason, does:

logger = logging.getLogger("torchtitan.observability.system") # logger dedicated to structured logger
logger.info("hello")

because of this filter, it would correctly NOT be added to the jsonl.
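A minimal sketch of such a guard (assuming record_span marks its records through logging's `extra=` mechanism; the marker field name here is illustrative, not the PR's actual one):

```python
import logging

MARKER = "log_type"  # illustrative: assumed to be set by record_span via extra={...}


class EventsOnlyFilter(logging.Filter):
    """Pass only records carrying the structured-event marker.

    A stray logger.info("hello") on the system logger lacks the marker,
    so it is dropped and cannot corrupt the JSONL file.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        return hasattr(record, MARKER)


logger = logging.getLogger("demo.observability.system")
logger.setLevel(logging.INFO)
logger.propagate = False
logger.addFilter(EventsOnlyFilter())
```

With this in place, `logger.info("hello")` is silently rejected, while `logger.info("evt", extra={"log_type": "event"})` reaches the handlers, because `extra` keys become attributes on the LogRecord.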

Member

If I'm a user in the above scenario, I would be very confused. Should include a warning saying my message X will not be logged and how to properly log something to the system.

Author
@felipemello1 felipemello1 Mar 18, 2026

Should include a warning saying my message X will not be logged and how to properly log something to the system.

You do NOT want to use logging.getLogger("torchtitan.observability.system") for normal logger.info calls. This logger is internal to record_span. This would only ever happen if a user really tried to make a mistake, and if that happened, it would be safe: the logger would behave as a normal logger, printing to console, and it would not damage the output jsonl. To log to the jsonl, the user has to pass the "extras" field, which record_span does.

To initialize the regular logger in titan, we call init_logger().

rank = int(os.environ.get("RANK", 0))

# --- System handler ---
sys_logger = logging.getLogger(SYSTEM_LOGGER_NAME)
Contributor

It would be helpful if you could also improve logging module naming across all torchtitan components.

Author

What do you have in mind? Do you want the logger to show the callsite info?

Comment on lines +331 to +333
Adds per-rank system and experiment JSONL handlers for structured
logging. Does NOT set up console logging — call ``init_logger()``
from ``torchtitan.tools.logging`` for that.
Contributor

do you think we can put both of them into torchtitan/components/observability?

Author

Yes, I wanted to minimize the amount of changes that weren't necessary, since the PRs were already so long.

)


class record_span(ContextDecorator): # noqa: N801
Contributor

noob question: what's the difference between context decorator and context manager?

Author
@felipemello1 felipemello1 Mar 18, 2026

TIL

A ContextDecorator is a class-based utility that enables a ContextManager to also be used as a function decorator

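A tiny self-contained sketch of why that matters for record_span (the class here is a hypothetical stand-in, not the PR's implementation):

```python
import time
from contextlib import ContextDecorator


class span_sketch(ContextDecorator):
    """Times a block. ContextDecorator makes both forms below work."""

    def __init__(self, name: str):
        self.name = name

    def __enter__(self):
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.delta_ms = (time.perf_counter() - self._t0) * 1000
        return False  # do not swallow exceptions


# 1) as a plain context manager
with span_sketch("fwd_bwd") as s:
    sum(range(1000))


# 2) as a function decorator -- the extra ability ContextDecorator adds
@span_sketch("optimizer_step")
def step():
    return 42
```

A plain context manager only supports form 1; inheriting from ContextDecorator means every call to the decorated function is automatically wrapped in `__enter__`/`__exit__`.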

)
handler = StructuredLoggingHandler(filepath=sys_path)
handler.setFormatter(StructuredJSONFormatter(rank=rank, source=source))
sys_logger.addHandler(handler)
Contributor

Oh, IIUC this is using the registration mechanism provided by Python's logging.getLogger; sounds neat.


message = record.getMessage()
if message is not None:
if len(message) <= MAX_MESSAGE_SIZE:
Member

Why is there a MAX_MESSAGE_SIZE?

Author

To avoid spamming the logs.
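In other words, oversized messages get clamped before they hit the JSONL. A sketch of the idea (the constant's actual value and the truncation behavior in the PR may differ):

```python
MAX_MESSAGE_SIZE = 4096  # illustrative; the PR's actual limit may differ

SUFFIX = "...[truncated]"


def clamp_message(message: str) -> str:
    """Keep one runaway log line from bloating the JSONL file."""
    if len(message) <= MAX_MESSAGE_SIZE:
        return message
    return message[:MAX_MESSAGE_SIZE] + SUFFIX
```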

TEXT = "text"


class ExtraFields(StrEnum):
Member

Sorry, I just don't understand this one. Any more you can share?



# Simulate GC on every 5th step (mirrors gc_handler.run)
if step % 5 == 0:
add_step_tag("gc")
Member

q: Why wouldn't I use record_span to get a timing on how long gc takes?

Author

We should. We do it in PR 7. I don't know if I forgot to do it in the toy example.

add_step_tag("gc") is just a tag that will be added to every other log moving forward. It's a way to say: "hey, this step may be slower than normal, so if you want, you can filter out all steps with the 'gc' tag".


if step % EVAL_FREQ == 0:
self.validate(tokens, labels, loss_mask)
add_step_tag("eval")
Member

q: Why does "eval" get a step tag and a record_span? Do these go to different things? If I record a span, shouldn't that also tell me where I am in the step?

Author

Yes, record_span saves the time. add_step_tag adds a tag to every subsequent log, so we can know that the log was generated in a step that also ran eval. Does this make sense?


if rank == 0:
print(f"Done. Output: {OUTPUT_DIR}")
sys_logs = os.path.join(OUTPUT_DIR, "system_logs")
Member

These would be nice as convenience methods - not sure if added later in the stack


Labels: ciflow/8gpu, CLA Signed