API documentation

Checkpoints

Module for managing PyTorch model checkpoints.

Provides the CheckpointManager class to save and load model and optimizer states during training, track the best metric values, and optionally report checkpoint events.

class congrads.checkpoints.CheckpointManager(criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool], network: Module, optimizer: Optimizer, metric_manager: MetricManager, save_dir: str = 'checkpoints', create_dir: bool = False, report_save: bool = False)

Bases: object

Manage saving and loading checkpoints for PyTorch models and optimizers.

Handles checkpointing based on a criteria function, restores metric states, and optionally reports when a checkpoint is saved.

evaluate_criteria(epoch: int, metric_group: str = 'during_training')

Evaluate the criteria function to determine if a better model is found.

Aggregates the current metric values during training and applies the criteria function. If the criteria function indicates improvement, the best metric values are updated, a checkpoint is saved, and a message is optionally printed.

Parameters:
  • epoch (int) – The current epoch number.

  • metric_group (str, optional) – The metric group to evaluate. Defaults to ‘during_training’.
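The criteria_function passed to CheckpointManager receives two metric dictionaries and returns a bool. The sketch below assumes the first argument holds the current aggregated metrics and the second the best values recorded so far, and it uses a hypothetical metric name "Loss/valid"; check which names your MetricManager actually registers. Plain floats stand in for the scalar tensors used in practice.

```python
def lower_valid_loss(current: dict, best: dict) -> bool:
    """Hypothetical criteria function: report improvement when the validation loss drops.

    Assumes `current` holds this epoch's aggregated metrics and `best` the best
    values seen so far; "Loss/valid" is an illustrative metric name.
    """
    key = "Loss/valid"
    if key not in best:
        # Nothing recorded yet, so the first evaluation counts as an improvement.
        return True
    # bool() also works when the values are 0-d tensors.
    return bool(current[key] < best[key])
```

Under these assumptions it would be passed positionally, for example CheckpointManager(lower_valid_loss, network, optimizer, metric_manager).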

load(filename: str)

Load a checkpoint and restore the training state.

Loads the checkpoint from the specified file and restores the network weights, optimizer state, and best metric values.

Parameters:

filename (str) – Name of the checkpoint file.

Returns:

A dictionary containing the loaded checkpoint information, including epoch, loss, and other relevant training state.

Return type:

dict

resume(filename: str = 'checkpoint.pth', ignore_missing: bool = False) int

Resumes training from a saved checkpoint file.

Parameters:
  • filename (str) – The name of the checkpoint file to load. Defaults to “checkpoint.pth”.

  • ignore_missing (bool) – If True, does not raise an error if the checkpoint file is missing and continues without loading, starting from epoch 0. Defaults to False.

Returns:

The epoch number from the loaded checkpoint, or 0 if ignore_missing is True and no checkpoint was found.

Return type:

int

Raises:
  • TypeError – If a provided attribute has an incompatible type.

  • FileNotFoundError – If the specified checkpoint file does not exist.
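The documented return value and error behaviour of resume can be summarised in a few lines. This is a pure-Python sketch of that contract only: the helper name resume_epoch is hypothetical, and a JSON file stands in for the PyTorch checkpoint file that the library actually reads.

```python
import json
import os


def resume_epoch(path: str, ignore_missing: bool = False) -> int:
    """Hypothetical helper mirroring the resume() contract described above."""
    if not os.path.exists(path):
        if ignore_missing:
            # No checkpoint available: start training from epoch 0.
            return 0
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    # JSON stands in for the real (torch-serialized) checkpoint format.
    with open(path, encoding="utf-8") as f:
        checkpoint = json.load(f)
    return checkpoint["epoch"]
```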

save(epoch: int, filename: str = 'checkpoint.pth')

Save a checkpoint containing the current epoch, network weights, optimizer state, and best metric values.

Parameters:
  • epoch (int) – Current epoch number.

  • filename (str) – Name of the checkpoint file. Defaults to ‘checkpoint.pth’.

Constraints

Module providing constraint classes for guiding neural network training.

This module defines constraints that enforce specific conditions on network outputs to steer learning. Available constraint types include:

  • Constraint: Base class for all constraint types, defining the interface and core behavior.

  • ImplicationConstraint: Enforces one condition only if another condition is met, useful for modeling implications between outputs.

  • ScalarConstraint: Enforces scalar-based comparisons on a network’s output.

  • BinaryConstraint: Enforces a binary comparison between two tags using a comparison function (e.g., less than, greater than).

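  • SumConstraint: Enforces a comparison (e.g., less than or equal to) between the weighted sums of the outputs of two groups of tags.

  • ANDConstraint and ORConstraint: Combine several constraints with an elementwise logical AND or OR.

  • MonotonicityConstraint: Enforces a monotonically ascending or descending relationship between a prediction tag and a reference tag. GroupedMonotonicityConstraint is a variant that additionally takes a group identifier tag.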

These constraints can steer the learning process by applying logical implications or numerical bounds.

Usage:
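  1. Use one of the built-in constraint classes, or define a custom constraint by inheriting from Constraint and implementing check_constraint and calculate_direction.

  2. Pass your constraints to CongradsCore (via its constraints argument) so they are applied during training.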

  3. Use helper classes like IdentityTransformation for transformations and comparisons in constraints.

class congrads.constraints.ANDConstraint(*constraints: Constraint, name: str = None, monitor_only: bool = False, rescale_factor: Number = 1.5)

Bases: Constraint

A composite constraint that enforces the logical AND of multiple constraints.

This class combines multiple sub-constraints and evaluates them jointly:

  • The satisfaction of the AND constraint is True only if all sub-constraints are satisfied (elementwise logical AND).

  • The corrective direction is computed by weighting each sub-constraint’s direction with its satisfaction mask and summing across all sub-constraints.

calculate_direction(data: dict[str, Tensor])

Compute the corrective direction by aggregating sub-constraint directions.

Each sub-constraint contributes its corrective direction, weighted by its satisfaction mask. The directions are summed across constraints for each affected layer.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A mapping from layer identifiers to correction tensors. Each entry represents the aggregated correction to apply to that layer, based on the satisfaction-weighted sum of sub-constraint directions.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor])

Evaluate whether all sub-constraints are satisfied.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A tuple (total_satisfaction, mask) where:
  • total_satisfaction: A boolean or numeric tensor indicating elementwise whether all constraints are satisfied (logical AND).

  • mask: A tensor of ones with the same shape as total_satisfaction, typically used as a weighting mask in downstream processing.

Return type:

tuple[Tensor, Tensor]

class congrads.constraints.BinaryConstraint(operand_left: str | Transformation, comparator: Callable[[Tensor, Number], Tensor], operand_right: str | Transformation, name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: Constraint

A constraint that enforces a binary comparison between two tags.

This class ensures that the output of one tag satisfies a comparison operation with the output of another tag (e.g., less than, greater than, etc.). It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

Parameters:
  • operand_left (Union[str, Transformation]) – Name of the left tag or a transformation to apply.

  • comparator (Callable[[Tensor, Number], Tensor]) – A comparison function (e.g., torch.ge, torch.lt).

  • operand_right (Union[str, Transformation]) – Name of the right tag or a transformation to apply.

  • name (str, optional) – A unique name for the constraint. If not provided, a name is auto-generated in the format “<operand_left> <comparator> <operand_right>”.

  • enforce (bool, optional) – If False, only monitor the constraint without adjusting the loss. Defaults to True.

  • rescale_factor (Number, optional) – Factor to scale the constraint-adjusted loss. Defaults to 1.5.

Raises:

TypeError – If a provided attribute has an incompatible type.

Notes

  • The tags must be defined in the descriptor mapping.

  • The constraint name is composed using the left tag, comparator, and right tag.
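In plain Python, the per-sample check that BinaryConstraint performs reduces to applying the comparator elementwise to the outputs of the two tags. The helper below is a hypothetical stand-in that uses lists and the operator module instead of tensors and torch comparators such as torch.le.

```python
import operator
from typing import Callable


def check_binary(
    left: list[float],
    right: list[float],
    comparator: Callable[[float, float], bool],
) -> tuple[list[float], list[float]]:
    # 1.0 where the comparison holds for a sample, 0.0 where it is violated.
    result = [1.0 if comparator(a, b) else 0.0 for a, b in zip(left, right)]
    # Every sample is relevant, so the mask is all ones.
    mask = [1.0] * len(result)
    return result, mask


# Example: require output "a" <= output "b" for each of three samples.
result, mask = check_binary([0.2, 0.9, 0.5], [0.5, 0.5, 0.5], operator.le)
```

With the library itself, the equivalent constraint would be declared as BinaryConstraint("a", torch.le, "b"), where "a" and "b" are illustrative tag names that must exist in the descriptor.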

calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Compute adjustment directions for the tags involved in the binary constraint.

The returned directions indicate how to adjust each tag’s output to satisfy the constraint. Currently only supported for dense layers.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

A mapping from layer names to tensors specifying the normalized adjustment directions for each tag involved in the constraint.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Evaluate whether the binary constraint is satisfied for the current predictions.

The constraint compares the outputs of two tags using the specified comparator function. A result of 1 indicates the constraint is satisfied for a sample, and 0 indicates it is violated.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result (Tensor): Binary tensor indicating constraint satisfaction (1 for satisfied, 0 for violated) for each sample.

  • mask (Tensor): Tensor of ones with the same shape as result, used for constraint aggregation.

Return type:

tuple[Tensor, Tensor]

class congrads.constraints.Constraint(tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: ABC

Abstract base class for defining constraints applied to neural networks.

A Constraint specifies conditions that the neural network outputs should satisfy. It supports monitoring constraint satisfaction during training and can adjust loss to enforce constraints. Subclasses must implement the check_constraint and calculate_direction methods.

Parameters:
  • tags (set[str]) – Tags referencing the parts of the network to which this constraint applies.

  • name (str, optional) – A unique name for the constraint. If not provided, a name is generated based on the class name and a random suffix.

  • enforce (bool, optional) – If False, only monitor the constraint without adjusting the loss. Defaults to True.

  • rescale_factor (Number, optional) – Factor to scale the constraint-adjusted loss. Defaults to 1.5. Should be greater than 1 to give weight to the constraint.

Raises:
  • TypeError – If a provided attribute has an incompatible type.

  • ValueError – If any tag in tags is not defined in the descriptor.

Note

  • If rescale_factor <= 1, a warning is issued.

  • If name is not provided, a name is auto-generated, and a warning is logged.

abstract calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Compute adjustment directions to better satisfy the constraint.

Given the model predictions, input batch, and context, this method calculates the direction in which the predictions referenced by a tag should be adjusted to satisfy the constraint.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Dictionary mapping network layers to tensors that specify the adjustment direction for each tag.

Return type:

dict[str, Tensor]

Raises:

NotImplementedError – Must be implemented by subclasses.

abstract check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Evaluate whether the given model predictions satisfy the constraint.

A value of 1 indicates that the constraint is satisfied; 0 indicates that it is not.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

A tuple where the first element is a tensor of floats indicating whether the constraint is satisfied (1.0 for satisfaction, 0.0 for non-satisfaction), and the second element is a tensor mask that indicates the relevance of each sample (True for relevant samples, False for irrelevant ones).

Return type:

tuple[Tensor, Tensor]

Raises:

NotImplementedError – If not implemented in a subclass.

descriptor: Descriptor = None
device = None
class congrads.constraints.GroupedMonotonicityConstraint(tag_prediction: str, tag_reference: str, tag_group_identifier: str, rescale_factor_lower: float = 1.5, rescale_factor_upper: float = 1.75, stable: bool = True, direction: Literal['ascending', 'descending'] = 'ascending', name: str = None, enforce: bool = True)

Bases: MonotonicityConstraint

Constraint that enforces a monotonic relationship between two tags within groups of samples.

This constraint ensures that the activations of a prediction tag (tag_prediction) are monotonically ascending or descending with respect to a reference tag (tag_reference), with samples grouped by the tag given in tag_group_identifier.

calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Calculates ranking adjustments for monotonicity enforcement.

check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Evaluate whether the monotonicity constraint is satisfied.

class congrads.constraints.ImplicationConstraint(head: Constraint, body: Constraint, name: str = None)

Bases: Constraint

Represents an implication constraint between two constraints (head and body).

The implication constraint ensures that the body constraint only applies when the head constraint is satisfied. If the head constraint is not satisfied, the body constraint does not apply.
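Read as logic, the implication is violated only where the head holds and the body does not. The following plain-Python sketch (a hypothetical helper, with lists in place of tensors) computes the (result, head_satisfaction) pair that check_constraint is documented to return.

```python
def check_implication(
    head: list[float], body: list[float]
) -> tuple[list[float], list[float]]:
    # Where the head is not satisfied the body does not apply,
    # so the sample counts as satisfied.
    result = [1.0 if h == 0.0 or b == 1.0 else 0.0 for h, b in zip(head, body)]
    # The head satisfaction is returned as the second element.
    return result, list(head)


result, head_mask = check_implication([1.0, 1.0, 0.0, 0.0], [1.0, 0.0, 1.0, 0.0])
# result    == [1.0, 0.0, 1.0, 1.0]
# head_mask == [1.0, 1.0, 0.0, 0.0]
```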

calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Compute adjustment directions for tags to satisfy the constraint.

Uses the body constraint directions as the update vector. Only applies updates if the head constraint is satisfied. Currently, this method only works for dense layers due to tag-to-index translation limitations.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Dictionary mapping tags to tensors specifying the adjustment direction for each tag.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Check whether the implication constraint is satisfied.

Evaluates the head and body constraints. The body constraint is enforced only if the head constraint is satisfied. If the head constraint is not satisfied, the body constraint does not affect the result.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result: Tensor indicating satisfaction of the implication constraint (1 if satisfied, 0 otherwise).

  • head_satisfaction: Tensor indicating satisfaction of the head constraint alone.

Return type:

tuple[Tensor, Tensor]

class congrads.constraints.MonotonicityConstraint(tag_prediction: str, tag_reference: str, rescale_factor_lower: float = 1.5, rescale_factor_upper: float = 1.75, stable: bool = True, direction: Literal['ascending', 'descending'] = 'ascending', name: str = None, enforce: bool = True)

Bases: Constraint

Constraint that enforces a monotonic relationship between two tags.

This constraint ensures that the activations of a prediction tag (tag_prediction) are monotonically ascending or descending with respect to a target tag (tag_reference).
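The property this constraint targets can be checked pairwise: for any two samples whose reference values are strictly ordered, the predictions must be non-decreasing in the same order (ascending) or non-increasing (descending). The sketch below only tests that property on plain lists; it does not reproduce the library's ranking adjustments or its stable option. The optional groups argument is an assumption about how GroupedMonotonicityConstraint restricts comparisons to samples that share the same tag_group_identifier value.

```python
from typing import Optional


def is_monotonic(
    pred: list[float],
    ref: list[float],
    direction: str = "ascending",
    groups: Optional[list] = None,
) -> bool:
    if direction not in ("ascending", "descending"):
        raise ValueError("direction must be 'ascending' or 'descending'")
    # Without groups, every sample is compared with every other sample.
    if groups is None:
        groups = [None] * len(pred)
    n = len(pred)
    for i in range(n):
        for j in range(n):
            # Only compare samples in the same group with strictly ordered references.
            if groups[i] != groups[j] or not ref[i] < ref[j]:
                continue
            if direction == "ascending" and pred[i] > pred[j]:
                return False
            if direction == "descending" and pred[i] < pred[j]:
                return False
    return True
```

The double loop is quadratic in the batch size, which is fine for illustration but not for large batches.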

calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Calculates ranking adjustments for monotonicity enforcement.

check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Evaluate whether the monotonicity constraint is satisfied.

class congrads.constraints.ORConstraint(*constraints: Constraint, name: str = None, monitor_only: bool = False, rescale_factor: Number = 1.5)

Bases: Constraint

A composite constraint that enforces the logical OR of multiple constraints.

This class combines multiple sub-constraints and evaluates them jointly:

  • The satisfaction of the OR constraint is True if at least one sub-constraint is satisfied (elementwise logical OR).

  • The corrective direction is computed by weighting each sub-constraint’s direction with its satisfaction mask and summing across all sub-constraints.

calculate_direction(data: dict[str, Tensor])

Compute the corrective direction by aggregating sub-constraint directions.

Each sub-constraint contributes its corrective direction, weighted by its satisfaction mask. The directions are summed across constraints for each affected layer.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A mapping from layer identifiers to correction tensors. Each entry represents the aggregated correction to apply to that layer, based on the satisfaction-weighted sum of sub-constraint directions.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor])

Evaluate whether any sub-constraints are satisfied.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A tuple (total_satisfaction, mask) where:
  • total_satisfaction: A boolean or numeric tensor indicating elementwise whether any constraints are satisfied (logical OR).

  • mask: A tensor of ones with the same shape as total_satisfaction, typically used as a weighting mask in downstream processing.

Return type:

tuple[Tensor, Tensor]
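The satisfaction rules of ANDConstraint and ORConstraint are elementwise reductions across sub-constraints. The sketch below shows them on plain lists, where results[k][i] is 1.0 if sub-constraint k holds for sample i; the helper names are hypothetical, and the corrective-direction aggregation is not reproduced.

```python
def and_satisfaction(results: list[list[float]]) -> list[float]:
    # A sample is satisfied only if every sub-constraint holds for it.
    return [1.0 if all(column) else 0.0 for column in zip(*results)]


def or_satisfaction(results: list[list[float]]) -> list[float]:
    # A sample is satisfied if at least one sub-constraint holds for it.
    return [1.0 if any(column) else 0.0 for column in zip(*results)]


# Two sub-constraints evaluated on four samples.
results = [[1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
```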

class congrads.constraints.ScalarConstraint(operand: str | Transformation, comparator: Callable[[Tensor, Number], Tensor], scalar: Number, name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: Constraint

A constraint that enforces scalar-based comparisons on a specific tag.

This class ensures that the output of a specified tag satisfies a scalar comparison operation (e.g., less than, greater than, etc.). It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

Parameters:
  • operand (Union[str, Transformation]) – Name of the tag or a transformation to apply.

  • comparator (Callable[[Tensor, Number], Tensor]) – A comparison function (e.g., torch.ge, torch.lt).

  • scalar (Number) – The scalar value to compare against.

  • name (str, optional) – A unique name for the constraint. If not provided, a name is auto-generated in the format “<tag> <comparator> <scalar>”.

  • enforce (bool, optional) – If False, only monitor the constraint without adjusting the loss. Defaults to True.

  • rescale_factor (Number, optional) – Factor to scale the constraint-adjusted loss. Defaults to 1.5.

Raises:

TypeError – If a provided attribute has an incompatible type.

Notes

  • The tag must be defined in the descriptor mapping.

  • The constraint name is composed using the tag, comparator, and scalar value.

calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Compute adjustment directions to satisfy the scalar constraint.

Currently only supported for dense layers, due to limitations in tag-to-index translation.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Dictionary mapping layers to tensors specifying the adjustment direction for each tag.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Check if the scalar constraint is satisfied for a given tag.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result: Tensor indicating whether the tag satisfies the constraint.

  • ones_like(result): Tensor of ones with same shape as result.

Return type:

tuple[Tensor, Tensor]

class congrads.constraints.SumConstraint(operands_left: list[str | Transformation], comparator: Callable[[Tensor, Number], Tensor], operands_right: list[str | Transformation], weights_left: list[Number] = None, weights_right: list[Number] = None, name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: Constraint

A constraint that enforces a weighted summation comparison between two groups of tags.

This class evaluates whether the weighted sum of outputs from one set of tags satisfies a comparison operation with the weighted sum of outputs from another set of tags.
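The weighted-sum check can be sketched per sample in plain Python. Each entry of left and right holds the per-sample outputs of one tag; the helper name is hypothetical, and treating missing weights as 1.0 for every tag is an assumption about what weights_left=None and weights_right=None mean.

```python
import operator
from typing import Callable, Optional


def check_weighted_sum(
    left: list[list[float]],
    comparator: Callable[[float, float], bool],
    right: list[list[float]],
    weights_left: Optional[list[float]] = None,
    weights_right: Optional[list[float]] = None,
) -> tuple[list[float], list[float]]:
    # Assumption: missing weights default to 1.0 for every tag.
    if weights_left is None:
        weights_left = [1.0] * len(left)
    if weights_right is None:
        weights_right = [1.0] * len(right)
    n_samples = len(left[0])

    def weighted_sum(tags, weights):
        # Sum of weight * output over the tags, computed per sample.
        return [sum(w * tag[i] for w, tag in zip(weights, tags)) for i in range(n_samples)]

    sum_left = weighted_sum(left, weights_left)
    sum_right = weighted_sum(right, weights_right)
    result = [1.0 if comparator(a, b) else 0.0 for a, b in zip(sum_left, sum_right)]
    return result, [1.0] * n_samples


# Two samples, constraint a + b <= c: sums are [3, 7] versus [4, 6].
result, mask = check_weighted_sum([[1, 3], [2, 4]], operator.le, [[4, 6]])
```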

calculate_direction(data: dict[str, Tensor]) dict[str, Tensor]

Compute adjustment directions for tags involved in the weighted sum constraint.

The directions indicate how to adjust each tag’s output to satisfy the constraint. Only dense layers are currently supported.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Mapping from layer names to normalized tensors specifying adjustment directions for each tag involved in the constraint.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) tuple[Tensor, Tensor]

Evaluate whether the weighted sum constraint is satisfied.

Computes the weighted sum of outputs from the left and right tags, applies the specified comparator function, and returns a binary result for each sample.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result (Tensor): Binary tensor indicating whether the constraint

is satisfied (1) or violated (0) for each sample. - mask (Tensor): Tensor of ones, used for constraint aggregation.

Return type:

tuple[Tensor, Tensor]

Core

This module provides the core CongradsCore class for the main training functionality.

It is designed to integrate constraint-guided optimization into neural network training. It extends traditional training processes by enforcing specific constraints on the model’s outputs, ensuring that the network satisfies domain-specific requirements during both training and evaluation.

The CongradsCore class serves as the central engine for managing the training, validation, and testing phases of a neural network model, incorporating constraints that influence the loss function and model updates. The model is trained with standard loss functions while also incorporating constraint-based adjustments, which are tracked and logged throughout the process.

Key features:

  • Support for various constraints that can influence the training process.

  • Integration with PyTorch’s DataLoader for efficient batch processing.

  • Metric management for tracking loss and constraint satisfaction.

  • Checkpoint management for saving and evaluating model states.

The fit method of CongradsCore accepts callback functions at different stages of the training process to customize behavior for specific needs. These include callbacks at the start and end of each batch (globally or per training, validation, and test phase), after each forward pass, at the start and end of each epoch, and at the start and end of training and testing.

class congrads.core.CongradsCore(descriptor: Descriptor, constraints: list[Constraint], loaders: tuple[DataLoader, DataLoader, DataLoader], network: Module, criterion: _Loss, optimizer: Optimizer, metric_manager: MetricManager, device: torch.device, network_uses_grad: bool = False, checkpoint_manager: CheckpointManager = None, epsilon: float = 1e-06, constraint_aggregator: Callable[[...], Tensor] = torch.sum, disable_progress_bar_epoch: bool = False, disable_progress_bar_batch: bool = False, enforce_all: bool = True)

Bases: object

The CongradsCore class is the central training engine for constraint-guided optimization.

It integrates standard neural network training with additional constraint-driven adjustments to the loss function, ensuring that the network satisfies domain-specific constraints during training.

fit(start_epoch: int = 0, max_epochs: int = 100, test_model: bool = True, on_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_train_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_train_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_valid_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_valid_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_test_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_test_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_epoch_start: list[Callable[[int], None]] | None = None, on_epoch_end: list[Callable[[int], None]] | None = None, on_train_start: list[Callable[[int], None]] | None = None, on_train_end: list[Callable[[int], None]] | None = None, on_train_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_val_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_test_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None, on_test_start: list[Callable[[int], None]] | None = None, on_test_end: list[Callable[[int], None]] | None = None) None

Train the model over multiple epochs with optional validation and testing.

This method manages the full training loop, including:

  • Executing epoch-level and batch-level callbacks.

  • Training and validating the model each epoch.

  • Adjusting losses according to constraints.

  • Logging metrics via the metric manager.

  • Optional evaluation on the test set.

  • Checkpointing the model during and after training.

Parameters:
  • start_epoch (int, optional) – Epoch number to start training from. Defaults to 0.

  • max_epochs (int, optional) – Total number of epochs to train. Defaults to 100.

  • test_model (bool, optional) – If True, evaluate the model on the test set after training. Defaults to True.

  • on_batch_start (list[Callable], optional) – Callbacks executed at the start of every batch. Defaults to None.

  • on_batch_end (list[Callable], optional) – Callbacks executed at the end of every batch. Defaults to None.

  • on_train_batch_start (list[Callable], optional) – Callbacks executed at the start of each training batch. Defaults to on_batch_start if not provided.

  • on_train_batch_end (list[Callable], optional) – Callbacks executed at the end of each training batch. Defaults to on_batch_end if not provided.

  • on_valid_batch_start (list[Callable], optional) – Callbacks executed at the start of each validation batch. Defaults to on_batch_start if not provided.

  • on_valid_batch_end (list[Callable], optional) – Callbacks executed at the end of each validation batch. Defaults to on_batch_end if not provided.

  • on_test_batch_start (list[Callable], optional) – Callbacks executed at the start of each test batch. Defaults to on_batch_start if not provided.

  • on_test_batch_end (list[Callable], optional) – Callbacks executed at the end of each test batch. Defaults to on_batch_end if not provided.

  • on_epoch_start (list[Callable], optional) – Callbacks executed at the start of each epoch. Defaults to None.

  • on_epoch_end (list[Callable], optional) – Callbacks executed at the end of each epoch. Defaults to None.

  • on_train_start (list[Callable], optional) – Callbacks executed before training starts. Defaults to None.

  • on_train_end (list[Callable], optional) – Callbacks executed after training ends. Defaults to None.

  • on_train_completion_forward_pass (list[Callable], optional) – Callbacks executed after the forward pass during training. Defaults to None.

  • on_val_completion_forward_pass (list[Callable], optional) – Callbacks executed after the forward pass during validation. Defaults to None.

  • on_test_completion_forward_pass (list[Callable], optional) – Callbacks executed after the forward pass during testing. Defaults to None.

  • on_test_start (list[Callable], optional) – Callbacks executed before testing starts. Defaults to None.

  • on_test_end (list[Callable], optional) – Callbacks executed after testing ends. Defaults to None.

Notes

  • If phase-specific callbacks (train/valid/test) are not provided, the global on_batch_start and on_batch_end are used.

  • Training metrics, loss adjustments, and constraint satisfaction ratios are automatically logged via the metric manager.

  • The final model checkpoint is saved if a checkpoint manager is configured.
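The fallback rule for batch callbacks, together with a typical batch callback, can be sketched as follows. Two points are assumptions rather than documented behaviour: that "not provided" means None (so an explicit empty list disables the global callbacks for that phase), and that each callback receives the dictionary returned by the previous one. The "input" key and all helper names are hypothetical.

```python
from typing import Callable, Optional

BatchCallback = Callable[[dict], dict]


def resolve_callbacks(
    phase: Optional[list[BatchCallback]],
    fallback: Optional[list[BatchCallback]],
) -> list[BatchCallback]:
    # A phase-specific list (even an empty one) takes precedence over the global list.
    chosen = phase if phase is not None else fallback
    return chosen or []


def run_batch_callbacks(data: dict, callbacks: list[BatchCallback]) -> dict:
    # Each callback receives the dictionary returned by the previous one.
    for callback in callbacks:
        data = callback(data)
    return data


def normalize_input(data: dict) -> dict:
    # Hypothetical callback; "input" is an illustrative key, not a documented one.
    return {**data, "input": [x / 255 for x in data["input"]]}
```

With the library, such a function would be passed inside a list, for example fit(on_train_batch_start=[normalize_input]), matching the documented Callable[[dict[str, Tensor]], dict[str, Tensor]] type.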

static test_step(data: dict[str, Tensor], loss: Tensor, constraints: list[Constraint], metric_manager: MetricManager) Tensor

Evaluate constraints during testing and log constraint satisfaction metrics.

This method checks whether each constraint is satisfied for the given data, computes the constraint satisfaction ratio (CSR), and logs it using the metric manager. The base loss is not modified.

Parameters:
  • data (dict[str, Tensor]) – Dictionary containing the batch data, predictions and additional data.

  • loss (Tensor) – The base loss computed by the criterion.

  • constraints (list[Constraint]) – List of constraints to evaluate.

  • metric_manager (MetricManager) – Metric manager for logging CSR and per-constraint metrics.

Returns:

The original, unchanged base loss.

Return type:

Tensor
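The exact CSR formula is not spelled out here. One natural reading, given that check_constraint returns a satisfaction tensor and a relevance mask, is the fraction of relevant samples that satisfy the constraint. The sketch below computes that under this assumption, using a hypothetical helper name and plain lists.

```python
import math


def constraint_satisfaction_ratio(result: list[float], mask: list[float]) -> float:
    # Assumption: only samples with a nonzero mask count towards the ratio.
    relevant = sum(mask)
    if relevant == 0:
        # No relevant samples in this batch: the ratio is undefined.
        return math.nan
    return sum(r * m for r, m in zip(result, mask)) / relevant


# Three relevant samples, two of which satisfy the constraint.
csr = constraint_satisfaction_ratio([1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0])
```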

static train_step(data: dict[str, Tensor], loss: Tensor, constraints: list[Constraint], descriptor: Descriptor, metric_manager: MetricManager, device: torch.device, constraint_aggregator: Callable = torch.sum, epsilon: float = 1e-06, enforce_all: bool = True) Tensor

Adjust the training loss based on constraints and compute the combined loss.

This method calculates the directions in which the network outputs should be adjusted to satisfy constraints, scales these adjustments according to the constraint’s rescale factor and gradient norms, and adds the result to the base loss. It also logs the constraint satisfaction ratio (CSR) for monitoring.

Parameters:
  • data (dict[str, Tensor]) – Dictionary containing the batch data, predictions and additional data.

  • loss (Tensor) – The base loss computed by the criterion.

  • constraints (list[Constraint]) – List of constraints to enforce during training.

  • descriptor (Descriptor) – Descriptor containing layer metadata and variable/loss layer info.

  • metric_manager (MetricManager) – Metric manager for logging loss and CSR.

  • device (torch.device) – Device on which computations are performed.

  • constraint_aggregator (Callable, optional) – Function to aggregate per-layer rescaled losses. Defaults to torch.sum.

  • epsilon (float, optional) – Small value to prevent division by zero in gradient normalization. Defaults to 1e-6.

  • enforce_all (bool, optional) – If False, constraints are only monitored and do not influence the loss. Defaults to True.

Returns:

The combined loss including the original loss and constraint-based adjustments.

Return type:

Tensor

static valid_step(data: dict[str, Tensor], loss: Tensor, constraints: list[Constraint], metric_manager: MetricManager) Tensor

Evaluate constraints during validation and log constraint satisfaction metrics.

This method checks whether each constraint is satisfied for the given data, computes the constraint satisfaction ratio (CSR), and logs it using the metric manager. The base loss is not modified.

Parameters:
  • data (dict[str, Tensor]) – Dictionary containing the batch data, predictions and additional data.

  • loss (Tensor) – The base loss computed by the criterion.

  • constraints (list[Constraint]) – List of constraints to evaluate.

  • metric_manager (MetricManager) – Metric manager for logging CSR and per-constraint metrics.

Returns:

The original, unchanged base loss.

Return type:

Tensor

Datasets

This module defines several PyTorch dataset classes for loading and working with various datasets.

Each dataset class extends the torch.utils.data.Dataset class and provides functionality for downloading, loading, and transforming specific datasets where applicable.

Classes:

  • SyntheticClusterDataset: A dataset class for generating synthetic clustered 2D data with labels.

  • BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.

  • FamilyIncome: A dataset class for the Family Income and Expenditure dataset.

Each dataset class provides methods for downloading the data (unless it is already available or generated synthetically), checking the integrity of the dataset, loading the data from CSV files or generating synthetic data, and applying transformations to the data.

class congrads.datasets.BiasCorrection(root: str | Path, transform: Callable, download: bool = False)

Bases: Dataset

A dataset class for accessing the Bias Correction dataset.

This class extends the Dataset class and provides functionality for downloading, loading, and transforming the Bias Correction dataset. The dataset is focused on temperature forecast data and is made available for use with PyTorch. If download is set to True, the dataset will be downloaded if it is not already available. The data is then loaded, and a transformation function is applied to it.

Parameters:
  • root (Union[str, Path]) – The root directory where the dataset will be stored or loaded from.

  • transform (Callable) – A function to transform the dataset (e.g., preprocessing).

  • download (bool, optional) – Whether to download the dataset if it’s not already present. Defaults to False.

Raises:

RuntimeError – If the dataset is not found and download is False, or if all mirrors fail to provide the dataset.

property data_folder: str

Returns the path to the folder where the dataset is stored.

Returns:

The path to the dataset folder.

Return type:

str

download() → None

Downloads and extracts the dataset.

This method downloads the dataset and extracts it into the appropriate folder. The mirrors are tried in sequence: if downloading from one mirror fails, the next one is attempted.

Raises:

RuntimeError – If all mirrors fail to provide the dataset.
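
The mirror fallback described above can be sketched as follows. This is not the library's implementation: the helper names md5_of and download_from_mirrors and the fetch(mirror, filename, target) callback are hypothetical, the second element of each resources entry is assumed to be an MD5 checksum (it has the 32-hex-digit form of one), and archive extraction is left out.

```python
from __future__ import annotations

import hashlib
from pathlib import Path
from typing import Callable


def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_from_mirrors(
    mirrors: list[str],
    resources: list[tuple[str, str]],
    folder: str | Path,
    fetch: Callable[[str, str, Path], None],
) -> None:
    """Fetch every resource, trying each mirror in order until one succeeds.

    `fetch(mirror, filename, target)` is a hypothetical callback that writes the
    file to `target`; how a URL is built from mirror and filename is left to it.
    """
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    for filename, expected_md5 in resources:
        target = folder / filename
        errors = []
        for mirror in mirrors:
            try:
                fetch(mirror, filename, target)
                if md5_of(target) != expected_md5:
                    raise RuntimeError("checksum mismatch")
                break  # this resource is done, move on to the next one
            except Exception as exc:  # any failure: fall through to the next mirror
                errors.append(f"{mirror}: {exc}")
        else:
            # The loop finished without `break`: every mirror failed.
            raise RuntimeError(f"All mirrors failed for {filename}: " + "; ".join(errors))
```

Injecting the network call as fetch keeps the retry logic testable without network access; a real implementation would build the download URL from the mirror and filename and then extract the archive.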

mirrors = ['https://archive.ics.uci.edu/static/public/514/']
resources = [('bias+correction+of+numerical+prediction+model+temperature+forecast.zip', '3deee56d461a2686887c4ae38fe3ccf3')]

class congrads.datasets.FamilyIncome(root: str | Path, transform: Callable, download: bool = False)

Bases: Dataset

A dataset class for accessing the Family Income and Expenditure dataset.

This class extends the Dataset class and provides functionality for downloading, loading, and transforming the Family Income and Expenditure dataset for use with PyTorch. If download is set to True, the dataset will be downloaded if it is not already available. The data is then loaded, and a user-defined transformation function is applied to it.

Parameters:
  • root (Union[str, Path]) – The root directory where the dataset will be stored or loaded from.

  • transform (Callable) – A function to transform the dataset (e.g., preprocessing).

  • download (bool, optional) – Whether to download the dataset if it’s not already present. Defaults to False.

Raises:

RuntimeError – If the dataset is not found and download is False, or if all mirrors fail to provide the dataset.

property data_folder: str

Returns the path to the folder where the dataset is stored.

Returns:

The path to the dataset folder.

Return type:

str

download() → None

Downloads and extracts the dataset.

This method downloads the dataset and extracts it into the appropriate folder. The mirrors are tried in sequence: if downloading from one mirror fails, the next one is attempted.

Raises:

RuntimeError – If all mirrors fail to provide the dataset.

mirrors = ['https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure']
resources = [('archive.zip', '7d74bc7facc3d7c07c4df1c1c6ac563e')]

class congrads.datasets.SectionedGaussians(sections: list[dict], n_samples: int = 1000, n_runs: int = 1, seed: int | None = None, device='cpu', ground_truth_steepness: float = 0.0, blend_k: float = 10.0)

Bases: Dataset

Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.

Each section defines a subrange of x-values with its own Gaussian distribution (mean and standard deviation). Instead of abrupt transitions, the parameters are blended smoothly between sections using a sigmoid function.

The resulting signal can represent a continuous process where statistical properties gradually evolve over time or position.

Features:
  • Input: Gaussian signal samples (y-values)

  • Context: Concatenation of time (x) and normalized energy feature

  • Target: Exponential decay ground truth from 1 at x_min to 0 at x_max

sections

List of section definitions.

Type:

list[dict]

n_samples

Total number of samples across all sections.

Type:

int

n_runs

Number of random waveforms generated from the base configuration.

Type:

int

time

Sampled x-values, shape [n_samples, 1].

Type:

torch.Tensor

signal

Generated Gaussian signal values, shape [n_samples, 1].

Type:

torch.Tensor

energies

Normalized energy feature, shape [n_samples, 1].

Type:

torch.Tensor

context

Concatenation of time and energy, shape [n_samples, 2].

Type:

torch.Tensor

x_min

Minimum x-value across all sections.

Type:

float

x_max

Maximum x-value across all sections.

Type:

float

ground_truth_steepness

Steepness of the exponential decay used for the target output.

Type:

float

blend_k

Sharpness parameter controlling how rapidly means and standard deviations transition between sections.

Type:

float
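
The sigmoid blending controlled by blend_k can be illustrated with a small sketch. It assumes sections are separated by ascending boundary positions and that each boundary contributes a sigmoid-weighted step from one section's (mean, std) to the next; the function names and this exact parameterization are hypothetical, since the format of the section dictionaries is not documented here.

```python
import math
import random


def stable_sigmoid(z: float) -> float:
    """Logistic function written via tanh, which does not overflow for large |z|."""
    return 0.5 * (1.0 + math.tanh(0.5 * z))


def blended_params(x, boundaries, means, stds, blend_k=10.0):
    """Return the (mean, std) at position x, blended across section boundaries.

    Section i lies before boundaries[i]; each boundary adds a sigmoid-weighted
    step from the previous section's parameters to the next section's.
    """
    assert len(means) == len(stds) == len(boundaries) + 1
    mean, std = means[0], stds[0]
    for i, boundary in enumerate(boundaries):
        weight = stable_sigmoid(blend_k * (x - boundary))
        mean += weight * (means[i + 1] - means[i])
        std += weight * (stds[i + 1] - stds[i])
    return mean, std


def sample_signal(xs, boundaries, means, stds, blend_k=10.0, seed=None):
    """Draw one Gaussian sample per x using the blended parameters."""
    rng = random.Random(seed)
    return [rng.gauss(*blended_params(x, boundaries, means, stds, blend_k)) for x in xs]
```

With ascending boundaries the weights decrease from left to right, so the blended standard deviation is a convex combination of the section values and stays positive. Larger blend_k values make the transitions sharper.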

class congrads.datasets.SyntheticClusters(cluster_centers, cluster_sizes, cluster_std, cluster_labels)

Bases: Dataset

PyTorch dataset for generating synthetic clustered 2D data with labels.

Each cluster is defined by its center, size, spread (standard deviation), and label. The dataset samples points from a Gaussian distribution centered at the cluster mean.

Parameters:
  • cluster_centers (list[tuple[float, float]]) – Coordinates of each cluster center, e.g. [(x1, y1), (x2, y2), …].

  • cluster_sizes (list[int]) – Number of points to generate in each cluster.

  • cluster_std (list[float]) – Standard deviation (spread) of each cluster.

  • cluster_labels (list[int]) – Class label for each cluster (e.g., 0 or 1).

Raises:

AssertionError – If the input lists do not all have the same length.

data

A concatenated tensor of all generated points with shape (N, 2).

Type:

torch.Tensor

labels

A concatenated tensor of class labels with shape (N,), where N is the total number of generated points.

Type:

torch.Tensor
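
A minimal sketch of the generation procedure described above, written as a hypothetical ClusterSketch class rather than the library's SyntheticClusters. The optional generator argument and the (point, label) return format of __getitem__ are assumptions added for illustration.

```python
import torch
from torch.utils.data import Dataset


class ClusterSketch(Dataset):
    """Hypothetical re-implementation of the documented behaviour, not the library code."""

    def __init__(self, cluster_centers, cluster_sizes, cluster_std, cluster_labels, generator=None):
        # All four lists describe the same clusters, so they must have equal length.
        assert len(cluster_centers) == len(cluster_sizes) == len(cluster_std) == len(cluster_labels)
        points, labels = [], []
        for center, size, std, label in zip(cluster_centers, cluster_sizes, cluster_std, cluster_labels):
            center = torch.tensor(center, dtype=torch.float32)
            # Isotropic Gaussian noise with the cluster's standard deviation around its center.
            points.append(center + std * torch.randn(size, 2, generator=generator))
            labels.append(torch.full((size,), label, dtype=torch.long))
        self.data = torch.cat(points)    # shape (N, 2)
        self.labels = torch.cat(labels)  # shape (N,)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
```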

class congrads.datasets.SyntheticMonotonicity(n_samples=200, x_range=(0.0, 5.0), noise_base=0.05, noise_scale=0.15, noise_sharpness=4.0, noise_center=2.5, osc_amplitude=0.08, osc_frequency=6.0, osc_prob=0.5, seed=None)

Bases: Dataset

Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.

True function:

y_true(x) = log(1 + x)

Observed:

y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation

Parameters:
  • n_samples (int) – number of data points (default 200)

  • x_range (tuple) – range of x values (default (0.0, 5.0))

  • noise_base (float) – baseline noise level (default 0.05)

  • noise_scale (float) – scale of heteroscedastic noise (default 0.15)

  • noise_sharpness (float) – steepness of heteroscedastic transition (default 4.0)

  • noise_center (float) – center point of heteroscedastic increase (default 2.5)

  • osc_amplitude (float) – amplitude of oscillatory perturbation (default 0.08)

  • osc_frequency (float) – frequency of oscillation (default 6.0)

  • osc_prob (float) – probability each sample receives oscillation (default 0.5)

  • seed (int or None) – random seed
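
The documented model can be sketched in plain Python. Only y_true(x) = log(1 + x) is given above; the logistic shape of the noise standard deviation, the uniform sorted sampling of x, and the sinusoidal form of the oscillation are assumptions inferred from the parameter names, and sample_monotonicity is a hypothetical helper.

```python
import math
import random


def sample_monotonicity(n_samples=200, x_range=(0.0, 5.0), noise_base=0.05,
                        noise_scale=0.15, noise_sharpness=4.0, noise_center=2.5,
                        osc_amplitude=0.08, osc_frequency=6.0, osc_prob=0.5, seed=None):
    rng = random.Random(seed)
    # Assumption: x is drawn uniformly from x_range and then sorted.
    xs = sorted(rng.uniform(*x_range) for _ in range(n_samples))
    ys = []
    for x in xs:
        y_true = math.log1p(x)  # log(1 + x), monotone increasing
        # Assumption: the noise std rises logistically from noise_base to
        # noise_base + noise_scale around noise_center (tanh form avoids overflow).
        step = 0.5 * (1.0 + math.tanh(0.5 * noise_sharpness * (x - noise_center)))
        sigma = noise_base + noise_scale * step
        y = y_true + rng.gauss(0.0, sigma)
        # Assumption: each sample gets a sinusoidal perturbation with probability osc_prob.
        if rng.random() < osc_prob:
            y += osc_amplitude * math.sin(osc_frequency * x)
        ys.append(y)
    return xs, ys
```

Setting both noise terms and the oscillation amplitude to zero recovers the monotone ground truth exactly, which is a convenient sanity check for a monotonicity constraint.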

Descriptor

This module defines the Descriptor class, which allows assigning tags to parts of the network.

It manages the mapping between tags, their corresponding data dictionary keys and indices, and additional properties such as constant or variable status. This makes it easy to place constraints on parts of your neural network by referencing tags instead of raw indices.

class congrads.descriptor.Descriptor

Bases: object

A class to manage the mapping between tags and data locations.

Each tag represents a location in the data dictionary; the descriptor holds the dictionary keys, indices, and additional properties (such as min/max values, output, and constant variables).

This class is designed to manage the relationships between the assigned tags and the data dictionary keys in a neural network model. It allows for the assignment of properties (like minimum and maximum values, and whether data is an output, constant, or variable) to each tag. The data is stored in dictionaries and sets for efficient lookups.

constant_keys

A set of keys that represent constant data in the data dictionary.

Type:

set

variable_keys

A set of keys that represent variable data in the data dictionary.

Type:

set

affects_loss_keys

A set of keys that represent data affecting the loss computation.

Type:

set

add(key: str, tag: str, index: int = None, constant: bool = False, affects_loss: bool = True)

Adds a tag to the descriptor with its associated key, index, and properties.

This method registers a tag name and associates it with a data dictionary key, its index, and optional properties such as whether the data is constant or affects the loss computation.

Parameters:
  • key (str) – The key on which the tagged data is located in the data dictionary.

  • tag (str) – The identifier of the tag.

  • index (int, optional) – The index where the data is located. Defaults to None.

  • constant (bool, optional) – Whether the data is constant and is not learned. Defaults to False.

  • affects_loss (bool, optional) – Whether the data affects the loss computation. Defaults to True.

Raises:
  • TypeError – If a provided attribute has an incompatible type.

  • ValueError – If a key or index is already assigned for a tag or a duplicate index is used within a key.

exists(tag: str) → bool

Check if a tag is registered in the descriptor.

Parameters:

tag (str) – The tag identifier to check.

Returns:

True if the tag is registered, False otherwise.

Return type:

bool

location(tag: str) → tuple[str, int]

Get the key and index for a given tag.

Looks up the mapping for a registered tag and returns the associated dictionary key and the index.

Parameters:

tag (str) – The tag identifier. Must be registered.

Returns:

A tuple containing:
  • The key in the data dictionary which holds the data (str).

  • The tensor index where the data is present (int).

Return type:

tuple[str, int]

Raises:

ValueError – If the tag is not registered in the descriptor.

select(tag: str, data: dict[str, Tensor]) → Tensor

Extract the values for a specific tag.

Retrieves the key and index associated with a tag and selects the corresponding slice from the tensor stored under that key in the data dictionary.

Parameters:
  • tag (str) – The tag identifier. Must be registered.

  • data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

A tensor slice of shape (batch_size, 1) containing the values for the specified tag.

Return type:

Tensor

Raises:

ValueError – If the tag is not registered in the descriptor.
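
The tag bookkeeping described above can be sketched with a plain dictionary. DescriptorSketch is a hypothetical stand-in, not congrads.descriptor.Descriptor; its duplicate checks and the choice to return the whole tensor when a tag has no index are assumptions.

```python
import torch


class DescriptorSketch:
    """Hypothetical tag -> (key, index) registry mirroring the documented behaviour."""

    def __init__(self):
        self._locations = {}  # tag -> (key, index)
        self.constant_keys = set()
        self.variable_keys = set()
        self.affects_loss_keys = set()

    def add(self, key, tag, index=None, constant=False, affects_loss=True):
        if tag in self._locations:
            raise ValueError(f"Tag '{tag}' is already registered.")
        if (key, index) in self._locations.values():
            raise ValueError(f"Index {index} of key '{key}' is already tagged.")
        self._locations[tag] = (key, index)
        (self.constant_keys if constant else self.variable_keys).add(key)
        if affects_loss:
            self.affects_loss_keys.add(key)

    def exists(self, tag):
        return tag in self._locations

    def location(self, tag):
        if tag not in self._locations:
            raise ValueError(f"Tag '{tag}' is not registered.")
        return self._locations[tag]

    def select(self, tag, data):
        key, index = self.location(tag)
        if index is None:  # assumption: no index means the whole tensor
            return data[key]
        # Slicing index:index + 1 keeps the column dimension -> (batch_size, 1).
        return data[key][:, index:index + 1]
```

Slicing with index:index + 1 instead of indexing with index keeps the trailing dimension, which yields the (batch_size, 1) shape documented for select.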

Metrics

Module for managing metrics during training.

Provides the Metric and MetricManager classes for accumulating, aggregating, and resetting metrics over training batches. Supports grouping metrics and using custom accumulation functions.

class congrads.metrics.Metric(name: str, accumulator: Callable[[...], Tensor] = torch.nanmean)

Bases: object

Represents a single metric to be accumulated and aggregated.

Stores metric values over multiple batches and computes an aggregated result using a specified accumulation function.

accumulate(value: Tensor) → None

Accumulate a new value for the metric.

Parameters:

value (Tensor) – Metric values for the current batch.

aggregate() → Tensor

Compute the aggregated value of the metric.

Returns:

The aggregated metric value. Returns NaN if no values have been accumulated.

Return type:

Tensor

reset() → None

Reset the accumulated values and sample count for the metric.

class congrads.metrics.MetricManager

Bases: object

Manages multiple metrics and groups for training or evaluation.

Supports registering metrics, accumulating values by name, aggregating metrics by group, and resetting metrics by group.

accumulate(name: str, value: Tensor) → None

Accumulate a value for a specific metric by name.

Parameters:
  • name (str) – Name of the metric.

  • value (Tensor) – Metric values for the current batch.

aggregate(group: str = 'default') → dict[str, Tensor]

Aggregate all metrics in a specified group.

Parameters:

group (str, optional) – The group of metrics to aggregate. Defaults to “default”.

Returns:

Dictionary mapping metric names to their aggregated values.

Return type:

dict[str, Tensor]

register(name: str, group: str = 'default', accumulator: Callable[[...], Tensor] = torch.nanmean) → None

Register a new metric under a specified group.

Parameters:
  • name (str) – Name of the metric.

  • group (str, optional) – Group name for the metric. Defaults to “default”.

  • accumulator (Callable[..., Tensor], optional) – Function to aggregate accumulated values. Defaults to torch.nanmean.

reset(group: str = 'default') → None

Reset all metrics in a specified group.

Parameters:

group (str, optional) – The group of metrics to reset. Defaults to “default”.

reset_all() → None

Reset all metrics across all groups.
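
The accumulate, aggregate, and reset cycle can be sketched as follows. MetricSketch and MetricManagerSketch are hypothetical classes that mirror the documented interface; storing flattened, detached batch values and concatenating them at aggregation time is an assumed design, not necessarily what congrads does.

```python
from collections import defaultdict

import torch


class MetricSketch:
    def __init__(self, name, accumulator=torch.nanmean):
        self.name = name
        self.accumulator = accumulator
        self.values = []

    def accumulate(self, value):
        # Detach so stored values do not keep the autograd graph alive;
        # flatten so scalars and per-sample vectors can be concatenated later.
        self.values.append(value.detach().flatten())

    def aggregate(self):
        if not self.values:
            return torch.tensor(float("nan"))
        return self.accumulator(torch.cat(self.values))

    def reset(self):
        self.values.clear()


class MetricManagerSketch:
    def __init__(self):
        self.metrics = {}
        self.groups = defaultdict(set)  # group name -> metric names

    def register(self, name, group="default", accumulator=torch.nanmean):
        self.metrics[name] = MetricSketch(name, accumulator)
        self.groups[group].add(name)

    def accumulate(self, name, value):
        self.metrics[name].accumulate(value)

    def aggregate(self, group="default"):
        return {name: self.metrics[name].aggregate() for name in self.groups[group]}

    def reset(self, group="default"):
        for name in self.groups[group]:
            self.metrics[name].reset()

    def reset_all(self):
        for metric in self.metrics.values():
            metric.reset()
```

Because the default accumulator is torch.nanmean, NaN entries in a batch (for example, a metric that is undefined for some samples) are ignored rather than poisoning the aggregate.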

Networks

Module defining the network architectures and components.

class congrads.networks.MLPNetwork(n_inputs, n_outputs, n_hidden_layers=3, hidden_dim=35, activation=None)

Bases: Module

A multi-layer perceptron (MLP) neural network with configurable hidden layers.

forward(data: dict[str, Tensor])

Run a forward pass through the network.

Parameters:

data (dict[str, Tensor]) – Input data to be processed by the network.

Returns:

The input data dictionary, augmented with the network’s output under the key “output”.

Return type:

dict
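
A sketch of an MLP that follows the dictionary-in, dictionary-out convention described for forward. The input key "input", the use of ReLU when activation is None, and treating activation as a module class that is instantiated once per hidden layer are all assumptions; MLPSketch is hypothetical.

```python
import torch
from torch import nn


class MLPSketch(nn.Module):
    """Hypothetical MLP: n_hidden_layers blocks of Linear + activation, then a Linear head."""

    def __init__(self, n_inputs, n_outputs, n_hidden_layers=3, hidden_dim=35, activation=None):
        super().__init__()
        activation_cls = activation if activation is not None else nn.ReLU  # assumed default
        layers = []
        in_dim = n_inputs
        for _ in range(n_hidden_layers):
            # A fresh activation instance per layer avoids accidentally sharing parameters.
            layers += [nn.Linear(in_dim, hidden_dim), activation_cls()]
            in_dim = hidden_dim
        layers.append(nn.Linear(in_dim, n_outputs))
        self.net = nn.Sequential(*layers)

    def forward(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Add the prediction under "output" and hand the same dictionary back.
        data["output"] = self.net(data["input"])
        return data
```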

Transformations

Module defining transformations and components.

class congrads.transformations.ApplyOperator(tag: str, operator: callable, value: Number)

Bases: Transformation

A transformation that applies a binary operator to the input tensor, combining it with the given value.

class congrads.transformations.DenormalizeMinMax(tag: str, min: Number, max: Number)

Bases: Transformation

A transformation that denormalizes data using min-max scaling.

class congrads.transformations.IdentityTransformation(tag: str)

Bases: Transformation

A transformation that returns the input unchanged.

class congrads.transformations.Transformation(tag: str)

Bases: ABC

Abstract base class for tag data transformations.
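
A sketch of how such transformations might be structured. The __call__ method name, the argument order operator(data, value), and the formula data * (max - min) + min (the inverse of min-max scaling to [0, 1]) are assumptions; the classes are hypothetical stand-ins for the documented ones.

```python
from abc import ABC, abstractmethod
from numbers import Number


class TransformationSketch(ABC):
    def __init__(self, tag: str):
        self.tag = tag  # descriptor tag whose values this transformation acts on

    @abstractmethod
    def __call__(self, data):
        """Map the tagged values to their transformed counterpart."""


class IdentitySketch(TransformationSketch):
    def __call__(self, data):
        return data


class DenormalizeMinMaxSketch(TransformationSketch):
    def __init__(self, tag: str, min: Number, max: Number):
        super().__init__(tag)
        self.min, self.max = min, max

    def __call__(self, data):
        # Inverse of x_norm = (x - min) / (max - min).
        return data * (self.max - self.min) + self.min


class ApplyOperatorSketch(TransformationSketch):
    def __init__(self, tag: str, operator, value: Number):
        super().__init__(tag)
        self.operator, self.value = operator, value

    def __call__(self, data):
        return self.operator(data, self.value)
```

Since these classes only use arithmetic operators, they work equally on plain numbers and on tensors.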

Utils

This module holds utility functions and classes for the congrads package.

class congrads.utils.CSVLogger(file_path: str, overwrite: bool = False, merge: bool = True)

Bases: object

A utility class for logging key-value pairs to a CSV file, organized by epochs.

Supports merging with existing logs or overwriting them.

Parameters:
  • file_path (str) – The path to the CSV file for logging.

  • overwrite (bool) – If True, overwrites any existing file at file_path. Defaults to False.

  • merge (bool) – If True, merges new values with existing data in the file. Defaults to True.

Raises:
  • ValueError – If both overwrite and merge are True.

  • FileExistsError – If the file already exists and neither overwrite nor merge is True.

add_value(name: str, value: float, epoch: int)

Adds a value to the logger for a specific epoch and name.

Parameters:
  • name (str) – The name of the metric or value to log.

  • value (float) – The value to log.

  • epoch (int) – The epoch associated with the value.

load()

Loads data from the CSV file into the logger.

Converts the CSV data into the internal dictionary format for further updates or operations.

save()

Saves the logged values to the specified CSV file.

If the file exists and merge is enabled, merges the current data with the existing file.

static to_dataframe(values: dict[tuple[int, str], float]) → DataFrame

Converts a dictionary of values into a DataFrame.

Parameters:

values (dict[tuple[int, str], float]) – A dictionary of values keyed by (epoch, name).

Returns:

A DataFrame where epochs are rows, names are columns, and values are the cell data.

Return type:

pd.DataFrame

static to_dict(df: DataFrame) → dict[tuple[int, str], float]

Converts a CSVLogger DataFrame to a dictionary in the format {(epoch, name): value}.
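
A sketch of the two conversions with pandas, assuming that empty cells (metrics not logged in a given epoch) are stored as NaN and skipped when converting back. The function bodies are illustrative, not the CSVLogger implementation.

```python
import pandas as pd


def to_dataframe(values: dict[tuple[int, str], float]) -> pd.DataFrame:
    """Pivot {(epoch, name): value} into a table with one row per epoch."""
    rows: dict[int, dict[str, float]] = {}
    for (epoch, name), value in values.items():
        rows.setdefault(epoch, {})[name] = value
    # Missing (epoch, name) combinations become NaN cells.
    df = pd.DataFrame.from_dict(rows, orient="index").sort_index()
    df.index.name = "epoch"
    return df


def to_dict(df: pd.DataFrame) -> dict[tuple[int, str], float]:
    """Flatten the table back into {(epoch, name): value}, skipping empty cells."""
    return {
        (int(epoch), str(name)): float(value)
        for epoch, row in df.iterrows()
        for name, value in row.items()
        if pd.notna(value)
    }
```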

class congrads.utils.DictDatasetWrapper(base_dataset: Dataset, field_names: list[str] | None = None)

Bases: Dataset

A wrapper for PyTorch datasets that converts each sample into a dictionary.

This class takes any PyTorch dataset and returns its samples as dictionaries, where each element of the original sample is mapped to a key. This is useful for integration with the Congrads toolbox or other frameworks that expect dictionary-formatted data.

base_dataset

The underlying PyTorch dataset being wrapped.

Type:

Dataset

field_names

Names assigned to each field of a sample. If None, default names like ‘field0’, ‘field1’, … are generated.

Type:

list[str] | None

Parameters:
  • base_dataset (Dataset) – The PyTorch dataset to wrap.

  • field_names (list[str] | None, optional) – Custom names for each field. If provided, the list is truncated or extended to match the number of elements in a sample. Defaults to None.

Example

Wrapping a TensorDataset with custom field names:

>>> from torch.utils.data import TensorDataset
>>> import torch
>>> dataset = TensorDataset(torch.randn(5, 3), torch.randint(0, 2, (5,)))
>>> wrapped = DictDatasetWrapper(dataset, field_names=["features", "label"])
>>> wrapped[0]
{'features': tensor([...]), 'label': tensor(1)}

Wrapping a built-in dataset like CIFAR10:

>>> from torchvision.datasets import CIFAR10
>>> from torchvision import transforms
>>> cifar = CIFAR10(
...     root="./data", train=True, download=True, transform=transforms.ToTensor()
... )
>>> wrapped_cifar = DictDatasetWrapper(cifar, field_names=["input", "output"])
>>> wrapped_cifar[0]
{'input': tensor([...]), 'output': tensor(6)}

class congrads.utils.Seeder(base_seed: int)

Bases: object

A deterministic seed manager for reproducible experiments.

This class provides a way to consistently generate pseudo-random seeds derived from a fixed base seed. It ensures that different libraries (Python’s random, NumPy, and PyTorch) are initialized with reproducible seeds, making experiments deterministic across runs.

roll_seed() → int

Generate a new deterministic pseudo-random seed.

Each call returns an integer seed derived from the internal pseudo-random generator, which itself is initialized by the base seed.

Returns:

A pseudo-random integer seed in the range [0, 2**31 - 1].

Return type:

int

set_reproducible() → None

Configure global random states for reproducibility.

Seeds the following libraries with deterministically generated seeds based on the base seed:

  • Python’s built-in random

  • NumPy’s random number generator

  • PyTorch (CPU and GPU)

Also enforces deterministic behavior in PyTorch by:
  • Seeding all CUDA devices

  • Disabling CuDNN benchmarking

  • Enabling CuDNN deterministic mode
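
A minimal sketch of the seeding scheme described above. SeederSketch is hypothetical; drawing one fresh derived seed per library from a private random.Random instance is an assumption about how the seeds are derived.

```python
import random

import numpy as np
import torch


class SeederSketch:
    def __init__(self, base_seed: int):
        # A private generator makes the derived seeds depend only on base_seed.
        self._rng = random.Random(base_seed)

    def roll_seed(self) -> int:
        # randint is inclusive on both ends: [0, 2**31 - 1].
        return self._rng.randint(0, 2**31 - 1)

    def set_reproducible(self) -> None:
        random.seed(self.roll_seed())
        np.random.seed(self.roll_seed())
        torch_seed = self.roll_seed()
        torch.manual_seed(torch_seed)
        torch.cuda.manual_seed_all(torch_seed)  # safe to call even without CUDA
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
```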

class congrads.utils.ZeroLoss(reduction: str = 'mean')

Bases: _Loss

A loss function that always returns zero.

This custom loss function ignores the input and target tensors and returns a constant zero loss, which can be useful for debugging or when no meaningful loss computation is required.

Parameters:

reduction (str, optional) – Specifies the reduction to apply to the output. Defaults to “mean”. Although specified, it has no effect as the loss is always zero.

forward(predictions: Tensor, target: Tensor, **kwargs) → Tensor

Return a dummy loss of zero regardless of input and target.

congrads.utils.is_torch_loss(criterion) → bool

Return True if the object is a PyTorch loss function.

congrads.utils.preprocess_AdultCensusIncome(df: DataFrame) → DataFrame

Preprocesses the Adult Census Income dataset for machine learning with PyTorch.

The function sequentially:

  • Drops rows with missing values.

  • Encodes categorical variables as integer labels.

  • Maps the target ‘income’ column to 0/1.

  • Converts all data to float32.

  • Adds a MultiIndex to denote input vs. output columns.

Parameters:

df (pd.DataFrame) – Raw dataframe containing Adult Census Income data.

Returns:

Preprocessed dataframe.

Return type:

pd.DataFrame

congrads.utils.preprocess_BiasCorrection(df: DataFrame) → DataFrame

Preprocesses the given dataframe for bias correction by performing a series of transformations.

The function sequentially:

  • Drops rows with missing values.

  • Converts a date string to datetime format and adds year, month, and day columns.

  • Normalizes the columns with specific logic for input and output variables.

  • Adds a multi-index indicating which columns are input or output variables.

  • Samples 2500 examples from the dataset without replacement.

Parameters:

df (pd.DataFrame) – The input dataframe containing the data to be processed.

Returns:

The processed dataframe after applying the transformations.

Return type:

pd.DataFrame

congrads.utils.preprocess_FamilyIncome(df: DataFrame) → DataFrame

Preprocesses the given Family Income dataframe.

The function sequentially:

  • Drops rows with missing values.

  • Converts object columns to appropriate data types and removes string columns.

  • Removes certain unnecessary columns like ‘Agricultural Household indicator’ and related features.

  • Adds labels to columns indicating whether they are input or output variables.

  • Normalizes the columns individually.

  • Checks and removes rows that do not satisfy predefined constraints (household income > expenses, food expenses > sub-expenses).

  • Samples 2500 examples from the dataset without replacement.

Parameters:

df (pd.DataFrame) – The input Family Income dataframe containing the data to be processed.

Returns:

The processed dataframe after applying the transformations and constraints.

Return type:

pd.DataFrame

congrads.utils.process_data_monotonicity_constraint(data: Tensor, ordering: Tensor, identifiers: Tensor)

Reorders input samples to support monotonicity checking.

Reorders input samples such that:

  • Samples from the same run are grouped together.

  • Within each run, samples are sorted chronologically.

Parameters:
  • data (Tensor) – The input data.

  • ordering (Tensor) – Values by which samples are ordered within each run (e.g., time stamps).

  • identifiers (Tensor) – Identifiers specifying different runs.

Returns:

Sorted data, ordering, and identifiers.

Return type:

Tuple[Tensor, Tensor, Tensor]
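
The two-level reordering can be implemented with two stable sorts: sort by the secondary key (ordering) first, then stably by the primary key (identifiers), so the time order survives inside each run. This sketch assumes one ordering value and one identifier per sample; group_runs_then_sort is a hypothetical name, not the library function.

```python
import torch


def group_runs_then_sort(data, ordering, identifiers):
    """Sort samples by run identifier, then by `ordering` within each run."""
    order_key = ordering.reshape(-1)
    run_key = identifiers.reshape(-1)
    # 1) Sort by the secondary key (time) first ...
    idx = torch.sort(order_key, stable=True).indices
    # 2) ... then stably by the primary key (run id); ties keep their time order.
    idx = idx[torch.sort(run_key[idx], stable=True).indices]
    return data[idx], ordering[idx], identifiers[idx]
```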

congrads.utils.split_data_loaders(data: Dataset, loader_args: dict = None, train_loader_args: dict = None, valid_loader_args: dict = None, test_loader_args: dict = None, train_size: float = 0.8, valid_size: float = 0.1, test_size: float = 0.1, split_generator: Generator = None) → tuple[DataLoader, DataLoader, DataLoader]

Splits a dataset into training, validation, and test sets, and returns corresponding DataLoader objects.

Parameters:
  • data (Dataset) – The dataset to be split.

  • loader_args (dict, optional) – Default DataLoader arguments shared by all three loaders. They are merged with the loader-specific arguments, which take precedence for overlapping keys.

  • train_loader_args (dict, optional) – Training DataLoader arguments, merges with loader_args, overriding overlapping keys.

  • valid_loader_args (dict, optional) – Validation DataLoader arguments, merges with loader_args, overriding overlapping keys.

  • test_loader_args (dict, optional) – Test DataLoader arguments, merges with loader_args, overriding overlapping keys.

  • train_size (float, optional) – Proportion of data to be used for training. Defaults to 0.8.

  • valid_size (float, optional) – Proportion of data to be used for validation. Defaults to 0.1.

  • test_size (float, optional) – Proportion of data to be used for testing. Defaults to 0.1.

  • split_generator (Generator, optional) – Random number generator used to control the random split of the dataset, e.g. for reproducible splits.

Returns:

A tuple containing three DataLoader objects, one each for the training, validation, and test sets.

Return type:

tuple[DataLoader, DataLoader, DataLoader]

Raises:

ValueError – If the train_size, valid_size, and test_size are not between 0 and 1, or if their sum does not equal 1.
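
A sketch of the splitting and argument-merging rules, using torch.utils.data.random_split. Treating "between 0 and 1" as inclusive, rounding the split lengths down, and assigning the remainder to the test split are assumptions; split_loaders_sketch is hypothetical.

```python
import math

from torch.utils.data import DataLoader, Dataset, random_split


def split_loaders_sketch(data: Dataset, loader_args=None, train_loader_args=None,
                         valid_loader_args=None, test_loader_args=None,
                         train_size=0.8, valid_size=0.1, test_size=0.1,
                         split_generator=None):
    sizes = (train_size, valid_size, test_size)
    if not all(0.0 <= s <= 1.0 for s in sizes) or not math.isclose(sum(sizes), 1.0):
        raise ValueError("train_size, valid_size and test_size must lie in [0, 1] and sum to 1.")
    n = len(data)
    n_train = int(train_size * n)  # rounded down
    n_valid = int(valid_size * n)
    lengths = [n_train, n_valid, n - n_train - n_valid]  # remainder goes to the test split
    split_kwargs = {} if split_generator is None else {"generator": split_generator}
    subsets = random_split(data, lengths, **split_kwargs)
    shared = loader_args or {}
    specific = (train_loader_args, valid_loader_args, test_loader_args)
    # Loader-specific arguments win over the shared defaults for overlapping keys.
    return tuple(DataLoader(subset, **{**shared, **(extra or {})})
                 for subset, extra in zip(subsets, specific))
```

The dictionary unpacking {**shared, **extra} is what gives loader-specific keys precedence: later entries overwrite earlier ones.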

congrads.utils.torch_loss_wrapper(criterion: _Loss) → _Loss

Wraps a PyTorch loss function whose forward pass does not accept **kwargs, so that it can be called with additional keyword arguments.

Parameters:

criterion (_Loss) – The PyTorch loss function to wrap.

Returns:

The wrapped criterion that allows **kwargs in the forward pass.

Return type:

_Loss
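
A sketch of such a wrapper, together with an isinstance-based check in the spirit of is_torch_loss. Both KwargsTolerantLoss and is_torch_loss_sketch are hypothetical; the assumption is that extra keyword arguments are simply discarded before the wrapped criterion is called.

```python
import torch
from torch.nn.modules.loss import _Loss


class KwargsTolerantLoss(_Loss):
    """Hypothetical wrapper: accepts **kwargs but forwards only predictions and target."""

    def __init__(self, criterion: _Loss):
        super().__init__(reduction=criterion.reduction)
        self.criterion = criterion

    def forward(self, predictions: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra keyword arguments (e.g. a data dictionary) are accepted and ignored.
        return self.criterion(predictions, target)


def is_torch_loss_sketch(criterion) -> bool:
    # Built-in PyTorch losses such as nn.MSELoss derive from _Loss.
    return isinstance(criterion, _Loss)
```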

congrads.utils.validate_callable(name, value, allow_none=False)

Validate that a value is callable.

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • allow_none (bool) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not callable.

congrads.utils.validate_callable_iterable(name, value, allowed_iterables=(list, set, tuple), allow_none=False)

Validate that a value is an iterable containing only callable elements.

This function ensures that the given value is an iterable (e.g., a list or set) and that all its elements are callable.

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – The value to validate.

  • allowed_iterables (tuple of types, optional) – Iterable types that are allowed. Defaults to (list, set, tuple).

  • allow_none (bool, optional) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not an allowed iterable type or if any element is not callable.
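
A sketch of the documented checks. The error messages and treating a disallowed None as a TypeError are assumptions; validate_callable_iterable_sketch is hypothetical.

```python
def validate_callable_iterable_sketch(name, value, allowed_iterables=(list, set, tuple),
                                      allow_none=False):
    if value is None:
        if allow_none:
            return
        raise TypeError(f"Argument '{name}' must not be None.")
    if not isinstance(value, allowed_iterables):
        allowed = ", ".join(t.__name__ for t in allowed_iterables)
        raise TypeError(f"Argument '{name}' must be one of ({allowed}), got {type(value).__name__}.")
    for position, element in enumerate(value):
        if not callable(element):
            raise TypeError(f"Element {position} of '{name}' is not callable: {element!r}.")
```

Checking the container type first means a string such as "abc" is rejected outright instead of being iterated character by character.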

congrads.utils.validate_comparator_pytorch(name, value)

Validate that a value is a callable PyTorch comparator function.

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

Raises:

TypeError – If the value is not callable or not a PyTorch comparator.

congrads.utils.validate_iterable(name, value, expected_element_types, allowed_iterables=(list, set, tuple), allow_empty=False, allow_none=False)

Validate that a value is an iterable (e.g., list, set) with elements of the specified type(s).

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • expected_element_types (type or tuple of types) – Expected type(s) for the elements.

  • allowed_iterables (tuple of types) – Iterable types that are allowed (default: list, set, and tuple).

  • allow_empty (bool) – Whether to allow empty iterables. Defaults to False.

  • allow_none (bool) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not an allowed iterable type or if any element is not of the expected type(s).

congrads.utils.validate_loaders(name: str, loaders: tuple[DataLoader, DataLoader, DataLoader])

Validates that loaders is a tuple of three DataLoader instances.

Parameters:
  • name (str) – The name of the parameter being validated.

  • loaders (tuple[DataLoader, DataLoader, DataLoader]) – A tuple of three DataLoader instances.

Raises:

TypeError – If loaders is not a tuple of three DataLoader instances or contains invalid types.

congrads.utils.validate_type(name, value, expected_types, allow_none=False)

Validate that a value is of the specified type(s).

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • expected_types (type or tuple of types) – Expected type(s) for the value.

  • allow_none (bool) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not of the expected type(s).