API documentation

Checkpoints

Module for managing PyTorch model checkpoints.

Provides the CheckpointManager class to save and load model and optimizer states during training, track the best metric values, and optionally report checkpoint events.

class congrads.checkpoints.CheckpointManager(criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool], network: Module, optimizer: Optimizer, metric_manager: MetricManager, save_dir: str = 'checkpoints', create_dir: bool = False, report_save: bool = False)

Bases: object

Manage saving and loading checkpoints for PyTorch models and optimizers.

Handles checkpointing based on a criteria function, restores metric states, and optionally reports when a checkpoint is saved.

evaluate_criteria(epoch: int, metric_group: str = 'during_training')

Evaluate the criteria function to determine if a better model is found.

Aggregates the current metric values during training and applies the criteria function. If the criteria function indicates improvement, the best metric values are updated, a checkpoint is saved, and a message is optionally printed.

Parameters:
  • epoch (int) – The current epoch number.

  • metric_group (str, optional) – The metric group to evaluate. Defaults to ‘during_training’.
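The flow described above can be sketched in plain Python. This is a stand-in rather than the library's implementation: floats replace tensors, and the metric name "Loss/valid", the helper names, and the argument order (current metrics first, best metrics second) are assumptions made for illustration.

```python
def lower_loss_is_better(current, best, key="Loss/valid"):
    """Hypothetical criteria function: True when the current loss beats the best one.

    Assumes the first argument holds the current metrics and the second the
    best metrics so far; the documentation does not state the order.
    """
    if key not in best:  # nothing recorded yet, so any result counts as an improvement
        return True
    return current[key] < best[key]


def run_epochs(losses, criteria):
    """Stand-in for calling evaluate_criteria once per epoch."""
    best = {}
    saved_epochs = []
    for epoch, loss in enumerate(losses):
        current = {"Loss/valid": loss}
        if criteria(current, best):
            best = dict(current)        # update the best metric values
            saved_epochs.append(epoch)  # a real manager would save a checkpoint here
    return saved_epochs, best


saved_epochs, best = run_epochs([0.9, 1.1, 0.7, 0.8], lower_loss_is_better)
print(saved_epochs, best)
```

Only epochs 0 and 2 improve on the best loss seen so far, so only those two would trigger a checkpoint in this sketch.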

load(filename: str)

Load a checkpoint and restore the training state.

Loads the checkpoint from the specified file and restores the network weights, optimizer state, and best metric values.

Parameters:

filename (str) – Name of the checkpoint file.

Returns:

A dictionary containing the loaded checkpoint information, including epoch, loss, and other relevant training state.

Return type:

dict

resume(filename: str = 'checkpoint.pth', ignore_missing: bool = False) → int

Resumes training from a saved checkpoint file.

Parameters:
  • filename (str) – The name of the checkpoint file to load. Defaults to “checkpoint.pth”.

  • ignore_missing (bool) – If True, does not raise an error if the checkpoint file is missing and continues without loading, starting from epoch 0. Defaults to False.

Returns:

The epoch number from the loaded checkpoint, or 0 if ignore_missing is True and no checkpoint was found.

Return type:

int

Raises:
  • TypeError – If a provided attribute has an incompatible type.

  • FileNotFoundError – If the specified checkpoint file does not exist.
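The return and error behaviour documented for resume can be illustrated with a small stand-in. This sketch does not use the library: pickle replaces the PyTorch checkpoint format, and the function name resume_sketch and the "epoch" dictionary key are assumptions.

```python
import os
import pickle
import tempfile


def resume_sketch(path, ignore_missing=False):
    """Stand-in mirroring the documented return/raise behaviour of resume."""
    if not os.path.exists(path):
        if ignore_missing:
            return 0  # no checkpoint found: start from epoch 0
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    with open(path, "rb") as handle:
        checkpoint = pickle.load(handle)
    return checkpoint["epoch"]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")

    # Missing file with ignore_missing=True: training starts at epoch 0.
    start_missing = resume_sketch(path, ignore_missing=True)

    # Missing file with the default ignore_missing=False: an error is raised.
    try:
        resume_sketch(path)
        raised = False
    except FileNotFoundError:
        raised = True

    # Once a checkpoint exists, its stored epoch is returned.
    with open(path, "wb") as handle:
        pickle.dump({"epoch": 7}, handle)
    start_resumed = resume_sketch(path)
```

In a training script, the returned value would typically serve as the first epoch of the loop, so a fresh run and a resumed run can share the same code path.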

save(epoch: int, filename: str = 'checkpoint.pth')

Save a checkpoint.

Parameters:
  • epoch (int) – Current epoch number.

  • filename (str) – Name of the checkpoint file. Defaults to ‘checkpoint.pth’.

Constraints

Core

Datasets

Descriptor

This module defines the Descriptor class, which assigns tags to parts of the network.

A Descriptor manages the mapping between tags, their corresponding data dictionary keys and indices, and additional properties such as constant or variable status. Constraints can then be placed on parts of the network by referencing tags instead of indices.

Tags are registered together with their data dictionary key, an optional index, and optional attributes, such as whether the data is constant or variable.

class congrads.descriptor.Descriptor

Bases: object

A class to manage the mapping between tags and data locations.

It represents data locations in the data dictionary and holds the dictionary keys, indices, and additional properties (such as min/max values, output, and constant variables).

This class is designed to manage the relationships between the assigned tags and the data dictionary keys in a neural network model. It allows for the assignment of properties (like minimum and maximum values, and whether data is an output, constant, or variable) to each tag. The data is stored in dictionaries and sets for efficient lookups.

constant_keys

A set of keys that represent constant data in the data dictionary.

Type:

set

variable_keys

A set of keys that represent variable data in the data dictionary.

Type:

set

affects_loss_keys

A set of keys that represent data affecting the loss computation.

Type:

set

add(key: str, tag: str, index: int = None, constant: bool = False, affects_loss: bool = True)

Adds a tag to the descriptor with its associated key, index, and properties.

This method registers a tag name and associates it with a data dictionary key, its index, and optional properties, such as whether the data is constant and whether it affects the loss.

Parameters:
  • key (str) – The key under which the tagged data is located in the data dictionary.

  • tag (str) – The identifier of the tag.

  • index (int, optional) – The index where the data is present. Defaults to None.

  • constant (bool, optional) – Whether the data is constant and is not learned. Defaults to False.

  • affects_loss (bool, optional) – Whether the data affects the loss computation. Defaults to True.

Raises:
  • TypeError – If a provided attribute has an incompatible type.

  • ValueError – If a key or index is already assigned for a tag or a duplicate index is used within a key.

exists(tag: str) → bool

Check if a tag is registered in the descriptor.

Parameters:

tag (str) – The tag identifier to check.

Returns:

True if the tag is registered, False otherwise.

Return type:

bool

location(tag: str) → tuple[str, int | None]

Get the key and index for a given tag.

Looks up the mapping for a registered tag and returns the associated dictionary key and the index.

Parameters:

tag (str) – The tag identifier. Must be registered.

Returns:

A tuple containing:
  • The key in the data dictionary which holds the data (str).

  • The tensor index where the data is present or None (int | None).

Return type:

tuple[str, int | None]

Raises:

ValueError – If the tag is not registered in the descriptor.
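The registration and lookup behaviour of add, exists, and location can be sketched with a small stand-in class. DescriptorSketch is hypothetical and not the library's code; in particular, how keys are sorted into constant_keys, variable_keys, and affects_loss_keys is an assumption inferred from the attribute names.

```python
class DescriptorSketch:
    """Hypothetical stand-in for the tag registry; not the library's code."""

    def __init__(self):
        self._tags = {}  # tag -> (key, index)
        self.constant_keys = set()
        self.variable_keys = set()
        self.affects_loss_keys = set()

    def add(self, key, tag, index=None, constant=False, affects_loss=True):
        if not isinstance(key, str) or not isinstance(tag, str):
            raise TypeError("key and tag must be strings")
        if tag in self._tags:
            raise ValueError(f"Tag '{tag}' is already registered")
        if (key, index) in self._tags.values():
            raise ValueError(f"Index {index} is already used within key '{key}'")
        self._tags[tag] = (key, index)
        # Assumption: each key is filed as either constant or variable.
        (self.constant_keys if constant else self.variable_keys).add(key)
        if affects_loss:
            self.affects_loss_keys.add(key)

    def exists(self, tag):
        return tag in self._tags

    def location(self, tag):
        if tag not in self._tags:
            raise ValueError(f"Tag '{tag}' is not registered")
        return self._tags[tag]


descriptor = DescriptorSketch()
descriptor.add("output", "temperature", index=0)
descriptor.add("output", "pressure", index=1)
descriptor.add("context", "limits", constant=True, affects_loss=False)

print(descriptor.location("pressure"))
print(descriptor.exists("humidity"))
```

Registering a second tag on ("output", 1) would raise ValueError in this sketch, mirroring the documented rule against duplicate indices within a key.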

select(tag: str, data: dict[str, Tensor]) → Tensor

Extract prediction values for a specific tag.

Retrieves the key and index associated with a tag and selects the corresponding slice from the tensor stored under that key in data. Returns the full tensor if no index was specified when registering the tag.

Parameters:
  • tag (str) – The tag identifier. Must be registered.

  • data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

A tensor slice of shape (batch_size, 1) containing the predictions for the specified tag, or the full tensor if no index was specified when registering the tag.

Return type:

Tensor

Raises:

ValueError – If the tag is not registered in the descriptor.
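The selection rule can be illustrated without tensors. In this sketch, nested lists (one inner list per sample) stand in for a batch tensor, and select_sketch, the locations mapping, and the tag names are hypothetical.

```python
def select_sketch(locations, tag, data):
    """Stand-in for select using nested lists (rows = batch) instead of tensors."""
    if tag not in locations:
        raise ValueError(f"Tag '{tag}' is not registered")
    key, index = locations[tag]
    values = data[key]
    if index is None:
        return values  # no index registered: return the full data
    # Keep the column dimension so the result has shape (batch_size, 1).
    return [[row[index]] for row in values]


locations = {
    "temperature": ("output", 0),
    "pressure": ("output", 1),
    "all": ("output", None),
}
data = {"output": [[20.5, 1.0], [21.0, 1.2], [19.8, 0.9]]}

pressure = select_sketch(locations, "pressure", data)
everything = select_sketch(locations, "all", data)
print(pressure)
```

Keeping the trailing dimension of size 1 matches the documented (batch_size, 1) shape, so a selected column can be compared against another selected column without reshaping.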

Metrics

Module for managing metrics during training.

Provides the Metric and MetricManager classes for accumulating, aggregating, and resetting metrics over training batches. Supports grouping metrics and using custom accumulation functions.

class congrads.metrics.Metric(name: str, accumulator: Callable[[...], Tensor] = torch.nanmean)

Bases: object

Represents a single metric to be accumulated and aggregated.

Stores metric values over multiple batches and computes an aggregated result using a specified accumulation function.

accumulate(value: Tensor) → None

Accumulate a new value for the metric.

Parameters:

value (Tensor) – Metric values for the current batch.

aggregate() → Tensor

Compute the aggregated value of the metric.

Returns:

The aggregated metric value. Returns NaN if no values have been accumulated.

Return type:

Tensor

reset() → None

Reset the accumulated values and sample count for the metric.
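The accumulate, aggregate, and reset cycle can be sketched in plain Python. MetricSketch and the nanmean helper are hypothetical stand-ins: Python floats replace tensors, the helper imitates a NaN-ignoring mean, and storing every value in a list is an assumption about the internal bookkeeping.

```python
import math


def nanmean(values):
    """Mean that ignores NaN entries; NaN when nothing valid remains."""
    valid = [v for v in values if not math.isnan(v)]
    return sum(valid) / len(valid) if valid else math.nan


class MetricSketch:
    """Hypothetical stand-in for a single metric; not the library's code."""

    def __init__(self, name, accumulator=nanmean):
        self.name = name
        self.accumulator = accumulator
        self.values = []

    def accumulate(self, value):
        self.values.extend(value)  # value: per-sample metric values of one batch

    def aggregate(self):
        if not self.values:
            return math.nan  # nothing accumulated yet
        return self.accumulator(self.values)

    def reset(self):
        self.values = []


loss = MetricSketch("Loss/train")
empty_result = loss.aggregate()   # NaN: no batches seen yet
loss.accumulate([1.0, math.nan])  # first batch, one invalid sample
loss.accumulate([3.0])            # second batch
mean_loss = loss.aggregate()      # mean over the valid samples 1.0 and 3.0
loss.reset()
```

A NaN-ignoring accumulator keeps a single invalid sample from turning the whole epoch average into NaN, which is a plausible reason for such a default.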

class congrads.metrics.MetricManager

Bases: object

Manages multiple metrics and groups for training or evaluation.

Supports registering metrics, accumulating values by name, aggregating metrics by group, and resetting metrics by group.

accumulate(name: str, value: Tensor) → None

Accumulate a value for a specific metric by name.

Parameters:
  • name (str) – Name of the metric.

  • value (Tensor) – Metric values for the current batch.

aggregate(group: str = 'default') → dict[str, Tensor]

Aggregate all metrics in a specified group.

Parameters:

group (str, optional) – The group of metrics to aggregate. Defaults to “default”.

Returns:

Dictionary mapping metric names to their aggregated values.

Return type:

dict[str, Tensor]

register(name: str, group: str = 'default', accumulator: Callable[[...], Tensor] = torch.nanmean) → None

Register a new metric under a specified group.

Parameters:
  • name (str) – Name of the metric.

  • group (str, optional) – Group name for the metric. Defaults to “default”.

  • accumulator (Callable[..., Tensor], optional) – Function to aggregate accumulated values. Defaults to torch.nanmean.

reset(group: str = 'default') → None

Reset all metrics in a specified group.

Parameters:

group (str, optional) – The group of metrics to reset. Defaults to “default”.

reset_all() → None

Reset all metrics across all groups.
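Group-wise bookkeeping as described for MetricManager can be sketched as follows. MetricManagerSketch is a hypothetical stand-in: floats replace tensors, statistics.fmean replaces the documented default accumulator, and the group name "after_training" and the metric names are assumptions.

```python
from statistics import fmean


class MetricManagerSketch:
    """Hypothetical stand-in: group -> metric name -> accumulated values."""

    def __init__(self):
        self._groups = {}
        self._group_of = {}  # metric name -> group, for lookup by name

    def register(self, name, group="default", accumulator=fmean):
        self._groups.setdefault(group, {})[name] = {
            "values": [],
            "accumulator": accumulator,
        }
        self._group_of[name] = group

    def accumulate(self, name, value):
        self._groups[self._group_of[name]][name]["values"].extend(value)

    def aggregate(self, group="default"):
        return {
            name: metric["accumulator"](metric["values"])
            if metric["values"]
            else float("nan")
            for name, metric in self._groups.get(group, {}).items()
        }

    def reset(self, group="default"):
        for metric in self._groups.get(group, {}).values():
            metric["values"] = []

    def reset_all(self):
        for group in self._groups:
            self.reset(group)


manager = MetricManagerSketch()
manager.register("Loss/train", group="during_training")
manager.register("Loss/valid", group="during_training")
manager.register("Loss/test", group="after_training")

manager.accumulate("Loss/train", [0.5, 1.5])
manager.accumulate("Loss/valid", [2.0])
manager.accumulate("Loss/test", [4.0])

training = manager.aggregate("during_training")
manager.reset("during_training")
after_reset = manager.aggregate("during_training")
test_still_there = manager.aggregate("after_training")
```

Resetting one group leaves the others untouched, which lets per-epoch metrics be cleared every epoch while end-of-run metrics keep accumulating.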

Networks

Transformations

Utils