from __future__ import annotations
import json
import logging
import os
from collections import defaultdict
from typing import Any, Dict, Optional
import torch
from .base import DEFAULT_EXPERIMENT_NAME, BaseTracker
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
[docs]
class DummyTracker(BaseTracker):
    """Tracker that accepts every call and keeps nothing.

    ``init()``, ``log()`` and ``finish()`` all succeed without side effects, so this
    can stand in wherever a tracker is required but no output is wanted.
    """

    def _init(self, config: Optional[dict] = None):
        # The run config is deliberately ignored.
        return None

    def _log(self, data: dict):
        # Logged records are deliberately discarded.
        return None

    def _finish(self, summary: Optional[dict] = None):
        # The run summary is deliberately ignored.
        return None
[docs]
class JSONTracker(BaseTracker):
    """Accumulate logged values in memory and write them to a JSON file on ``finish()``.

    The written object maps every logged key to the list of values logged under it,
    plus ``"config"`` (when ``init()`` received a config) and ``"run_summary"`` (when
    ``finish()`` received a summary). ``torch.Tensor`` values are converted with
    ``Tensor.tolist()`` at write time; any other non-JSON-serializable value makes
    ``finish()`` raise ``TypeError``.
    """

    def __init__(
        self,
        experiment_name: str = DEFAULT_EXPERIMENT_NAME,
        experiment_config: Optional[dict] = None,
        log_file_path: str = "./logs/{experiment_name}.json",
    ):
        """
        Args:
            experiment_name: Run name; substituted for ``{experiment_name}`` in
                ``log_file_path``.
            experiment_config: User-provided config forwarded to ``BaseTracker``.
            log_file_path: Destination path template. Its parent directory (if the
                path has one) is created immediately.
        """
        super().__init__(experiment_name, experiment_config)
        self.log_file_path = log_file_path.format(experiment_name=experiment_name)
        log_dir = os.path.dirname(self.log_file_path)
        # A bare filename ("run.json") has no directory part and os.makedirs("")
        # raises FileNotFoundError, so only create a directory when there is one.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Same container type that _init uses, so the attribute's type is stable.
        self._log_data: dict = defaultdict(list)

    def _init(self, config: Optional[dict] = None):
        """Start a fresh in-memory log, storing *config* under ``"config"`` if given."""
        self._log_data = defaultdict(list)
        if config:
            self._log_data["config"] = config

    def _log(self, data: dict):
        """Append each value in *data* to the list kept for its key."""
        # NOTE: _init stores the config dict itself under "config", so logging a key
        # named "config" after a config was given would call .append on a dict.
        for key, value in data.items():
            self._log_data[key].append(value)

    def _finish(self, summary: Optional[dict] = None):
        """Attach *summary* (if any) under ``"run_summary"`` and write the JSON file."""
        if summary:
            self._log_data["run_summary"] = summary
        # Serialize before opening the file: a serialization error then leaves any
        # existing file intact instead of a truncated, half-written one.
        payload = json.dumps(self._log_data, indent=4, default=self._to_jsonable)
        with open(self.log_file_path, "w", encoding="utf-8") as f:
            f.write(payload)

    @staticmethod
    def _to_jsonable(obj: Any) -> Any:
        """``json`` fallback: turn tensors into (nested) Python numbers, reject the rest."""
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
[docs]
class WandbTracker(BaseTracker):
    """Tracker backed by Weights & Biases.

    Constructing the tracker only records backend settings (project, entity, ...);
    no W&B run exists until ``init()`` is called, and ``finish()`` closes it.
    """

    def __init__(
        self,
        experiment_name: str = DEFAULT_EXPERIMENT_NAME,
        experiment_config: Optional[dict] = None,
        project_name: str = DEFAULT_EXPERIMENT_NAME,
        **wandb_kwargs,
    ):
        """
        Args:
            experiment_name: Name given to the W&B run.
            experiment_config: User-provided config merged with optimizer config on every run.
            project_name: W&B project the run is created in.
            **wandb_kwargs: Additional keyword arguments passed through to ``wandb.init()``.
        """
        super().__init__(experiment_name, experiment_config)
        self._wandb_kwargs = wandb_kwargs
        self.project_name = project_name

    def _init(self, config: Optional[dict] = None):
        """Open a W&B run using the stored project/run name and *config*."""
        import wandb

        run_settings = dict(
            project=self.project_name,
            name=self.experiment_name,
            config=config,
        )
        wandb.init(**run_settings, **self._wandb_kwargs)

    def _log(self, data: Dict[str, Any]):
        """Send *data* to W&B, unwrapping scalar tensors with ``.item()``."""
        import wandb

        payload: Dict[str, Any] = {}
        for key, value in data.items():
            if not isinstance(value, torch.Tensor):
                payload[key] = value
                continue
            try:
                payload[key] = value.item()
            except (ValueError, RuntimeError):
                # Tensors that cannot be reduced to a scalar are left out.
                pass
        wandb.log(payload)

    def _finish(self, summary: Optional[dict] = None):
        """Merge *summary* (if any) into the run summary, then close the run."""
        import wandb

        if summary:
            active_run = wandb.run
            active_run.summary.update(summary)
        wandb.finish()
[docs]
class DictTracker(BaseTracker):
    """Accumulates logged values in plain Python dicts.

    Attributes (created by ``init()``):
        records (list[dict]): One shallow copy of each ``log()`` payload, in call order.
        history (dict[str, list]): Per-key view — ``history[key]`` contains only values
            from records that included *key*. Convenient but records from different keys
            may not be index-aligned; use ``records`` when you need to join across keys.
        config (dict): Run config passed to ``init()``, or ``{}``.
        summary (dict): Run summary passed to ``finish()``, or ``{}``.
    """

    def _init(self, config: Optional[dict] = None):
        """Reset all accumulated state for a new run."""
        self.records: list[dict] = []
        self.history: Dict[str, list] = defaultdict(list)
        self.config: dict = config or {}
        self.summary: dict = {}

    def _log(self, data: dict):
        """Record *data* both as a whole record and per key."""
        # Keep a shallow copy: a caller that reuses or mutates one dict across log()
        # calls would otherwise retroactively rewrite every earlier record.
        self.records.append(dict(data))
        for key, value in data.items():
            self.history[key].append(value)

    def _finish(self, summary: Optional[dict] = None):
        """Store the run summary (``{}`` when none is given)."""
        self.summary = summary or {}
[docs]
class TrackioTracker(BaseTracker):
    """Logs to Hugging Face's Trackio.

    Construction stores backend parameters (project, space_id, etc.) without starting a run.
    The run is opened on ``init()`` and closed on ``finish()``.

    Trackio has no per-run summary object (unlike WandB); ``finish(summary=...)`` records the
    summary as a final ``trackio.log()`` entry whose keys are prefixed by ``"summary/"`` so it
    can be recovered from the run history. Both ``log()`` and the summary entry unwrap scalar
    tensors with ``.item()`` and drop tensors that cannot be reduced to a scalar.

    See https://huggingface.co/docs/trackio for backend details.
    """

    def __init__(
        self,
        experiment_name: str = DEFAULT_EXPERIMENT_NAME,
        experiment_config: Optional[dict] = None,
        project_name: str = DEFAULT_EXPERIMENT_NAME,
        space_id: Optional[str] = None,
        **trackio_kwargs,
    ):
        """
        Args:
            experiment_name: Run name in Trackio.
            experiment_config: User-provided config merged with optimizer config on every run.
            project_name: Trackio project name.
            space_id: Optional HuggingFace Space identifier (``"user/space_name"``) for hosted
                dashboards. If ``None``, Trackio persists locally to ``~/.trackio``.
            **trackio_kwargs: Extra kwargs forwarded to ``trackio.init()``.
        """
        super().__init__(experiment_name, experiment_config)
        self.project_name = project_name
        self.space_id = space_id
        self._trackio_kwargs = trackio_kwargs

    def _init(self, config: Optional[dict] = None):
        """Open a Trackio run with the stored project/run name, *config* and space."""
        import trackio

        trackio.init(
            project=self.project_name,
            name=self.experiment_name,
            config=config,
            space_id=self.space_id,
            **self._trackio_kwargs,
        )

    @staticmethod
    def _sanitize(data: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *data* with scalar tensors unwrapped via ``.item()``.

        Tensors for which ``.item()`` fails (more than one element) are dropped.
        """
        sanitized: Dict[str, Any] = {}
        for k, v in data.items():
            if isinstance(v, torch.Tensor):
                try:
                    v = v.item()
                except (ValueError, RuntimeError):
                    continue
            sanitized[k] = v
        return sanitized

    def _log(self, data: Dict[str, Any]):
        """Send one sanitized record to Trackio."""
        import trackio

        trackio.log(self._sanitize(data))

    def _finish(self, summary: Optional[dict] = None):
        """Log *summary* (if any) as a final ``summary/``-prefixed entry, then close the run."""
        import trackio

        if summary:
            # Sanitize like _log so tensor-valued summaries don't reach trackio raw.
            payload = self._sanitize({f"summary/{k}": v for k, v in summary.items()})
            if payload:
                trackio.log(payload)
        trackio.finish()
[docs]
class PrintTracker(BaseTracker):
    """Prints each optimisation step to stdout and accumulates history.

    Useful for Jupyter notebooks or any situation where you want live
    step-by-step loss/trigger output without a heavyweight logging backend.

    Attributes:
        history (dict): Accumulated values keyed by metric name. Created once in the
            constructor and not cleared by ``init()``, so it spans runs even though the
            printed step counter restarts at 0 on each ``init()``.
    """

    def __init__(
        self,
        experiment_name: str = DEFAULT_EXPERIMENT_NAME,
        print_keys: tuple = ("loss", "best_trigger_str"),
        experiment_config: Optional[dict] = None,
    ):
        """
        Args:
            experiment_name: Run name forwarded to ``BaseTracker``.
            print_keys: Keys to print from each logged record, in this order; keys
                absent from a record are skipped.
            experiment_config: Optional user config forwarded to ``BaseTracker``, as the
                other trackers in this module do. Defaults to ``None``.
        """
        super().__init__(experiment_name, experiment_config)
        self.print_keys = print_keys
        self.history: dict = defaultdict(list)
        self._step = 0

    @staticmethod
    def _format_item(key: str, val: Any) -> str:
        """Render ``key=value``: floats with 4 decimals, everything else via ``repr``.

        Single-element tensors are unwrapped with ``.item()`` first so they print as
        plain numbers rather than ``tensor(...)`` reprs.
        """
        if isinstance(val, torch.Tensor) and val.numel() == 1:
            val = val.item()
        if isinstance(val, float):
            return f"{key}={val:.4f}"
        return f"{key}={val!r}"

    def _init(self, config: Optional[dict] = None):
        """Restart the step counter and echo the run config, if any."""
        self._step = 0
        if config:
            print(f"=== Run config: {config} ===", flush=True)

    def _log(self, data: dict):
        """Store every value in ``history`` and print one line with the ``print_keys``."""
        self._step += 1
        for key, val in data.items():
            self.history[key].append(val)
        parts = [f"step={self._step:>4}"]
        parts.extend(
            self._format_item(key, data[key]) for key in self.print_keys if key in data
        )
        # flush so output appears live in notebooks / piped stdout.
        print(" | ".join(parts), flush=True)

    def _finish(self, summary: Optional[dict] = None):
        """Print the run summary, if any, on a single line."""
        if summary:
            parts = [self._format_item(k, v) for k, v in summary.items()]
            print(f"=== Run summary: {' | '.join(parts)} ===", flush=True)
[docs]
class LiveLossPlotTracker(BaseTracker):
    """Live-updating loss plot via ``livelossplot``.

    Only keys listed in ``focus_on_metrics`` are plotted, and only when their value is
    numeric: ``int``/``float``, or a single-element ``torch.Tensor`` (unwrapped with
    ``.item()``).
    """

    def __init__(
        self,
        experiment_name: str = DEFAULT_EXPERIMENT_NAME,
        focus_on_metrics: tuple = ("loss",),
    ):
        """
        Args:
            experiment_name: Run name forwarded to ``BaseTracker``.
            focus_on_metrics: Names of the logged keys to plot.
        """
        super().__init__(experiment_name)
        self.focus_on_metrics = focus_on_metrics
        # Created lazily in _init so livelossplot is only imported when used.
        self._plotlosses: Optional["livelossplot.PlotLosses"] = None

    def _init(self, config: Optional[dict] = None):
        """Create the ``PlotLosses`` instance with a single matplotlib output."""
        import livelossplot
        from livelossplot.outputs import MatplotlibPlot

        def _after_subplot(ax, group_name, x_label):
            # Override livelossplot's default axis labelling with a per-step x axis.
            ax.set_title(group_name)
            ax.set_xlabel("step")
            ax.legend(loc="center right")

        self._plotlosses = livelossplot.PlotLosses(
            outputs=[MatplotlibPlot(figsize=(7, 3), after_subplot=_after_subplot)],
        )

    @staticmethod
    def _as_number(value: Any) -> Any:
        """Unwrap single-element tensors to Python scalars; return anything else as-is."""
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            return value.item()
        return value

    def _log(self, data: Dict[str, Any]):
        """Feed the focused numeric metrics of *data* to livelossplot and redraw."""
        metrics: Dict[str, Any] = {}
        for key, value in data.items():
            if key not in self.focus_on_metrics:
                continue
            value = self._as_number(value)
            if isinstance(value, (int, float)):
                metrics[key] = value
        # update() is still called on every log(), exactly as before, so livelossplot
        # sees the same sequence of calls (presumably one internal step per call).
        self._plotlosses.update(metrics)
        # Redraw only when this record added something to plot: otherwise the figure
        # would be unchanged. NOTE(review): drawing before any metric has been
        # recorded may also fail inside MatplotlibPlot — confirm against livelossplot.
        if metrics:
            self._plotlosses.send()

    def _finish(self, summary: Optional[dict] = None):
        """Nothing to close; the last drawn plot is left as-is."""
        pass