API documentation

This section documents the public API of Congrads.

Callbacks

Callback and Operation Framework for Modular Training Pipelines.

This module provides a structured system for defining and executing callbacks and operations at different stages of a training lifecycle. It is designed to support:

  • Stateless, reusable operations that produce outputs merged into the event-local data.

  • Callbacks that group operations and/or custom logic for specific stages of training, epochs, batches, and steps.

  • A central CallbackManager to orchestrate multiple callbacks, maintain shared context, and execute stage-specific pipelines in deterministic order.

Stages supported:
  • on_train_start

  • on_train_end

  • on_epoch_start

  • on_epoch_end

  • on_batch_start

  • on_batch_end

  • on_test_start

  • on_test_end

  • on_train_batch_start

  • on_train_batch_end

  • on_valid_batch_start

  • on_valid_batch_end

  • on_test_batch_start

  • on_test_batch_end

  • after_train_forward

  • after_valid_forward

  • after_test_forward

Usage:
  1. Define Operations by subclassing Operation and implementing the compute method.

  2. Create a Callback subclass or instance and register Operations to stages via add(stage, operation).

  3. Register callbacks with CallbackManager.

  4. Invoke CallbackManager.run(stage, data) at appropriate points in the training loop, passing in event-local data.
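The four usage steps can be illustrated with a self-contained toy model of the documented behaviour. ToyOperation, ToyCallback, ToyCallbackManager, DoubleLoss and CountCalls are hypothetical stand-ins written only for this sketch, not the congrads classes. They mirror what this section describes: operations run in insertion order, callbacks run in registration order, returned dictionaries are merged into the event-local data, None counts as an empty result, and a shared context is visible to every callback.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class ToyOperation(ABC):
    """Hypothetical stand-in for Operation: a stateless unit of work."""

    @abstractmethod
    def compute(self, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any] | None:
        ...


class ToyCallback:
    """Hypothetical stand-in for Callback: maps each stage to its operations."""

    def __init__(self) -> None:
        self._ops: dict[str, list[ToyOperation]] = {}

    def add(self, stage: str, op: ToyOperation) -> ToyCallback:
        self._ops.setdefault(stage, []).append(op)
        return self  # enables method chaining

    def run_stage(self, stage: str, data: dict[str, Any], ctx: dict[str, Any]) -> dict[str, Any]:
        for op in self._ops.get(stage, []):  # insertion order
            outputs = op.compute(data, ctx) or {}  # None counts as {}
            data = {**data, **outputs}  # merge outputs into the event-local data
        return data


class ToyCallbackManager:
    """Hypothetical stand-in for CallbackManager."""

    def __init__(self) -> None:
        self.ctx: dict[str, Any] = {}  # shared context for cross-callback communication
        self._callbacks: list[ToyCallback] = []

    def add(self, callback: ToyCallback) -> ToyCallbackManager:
        self._callbacks.append(callback)
        return self

    def run(self, stage: str, data: dict[str, Any]) -> dict[str, Any]:
        for callback in self._callbacks:  # registration order
            data = callback.run_stage(stage, data, self.ctx)
        return data


class DoubleLoss(ToyOperation):
    def compute(self, data, ctx):
        return {"loss_x2": data["loss"] * 2}


class CountCalls(ToyOperation):
    def compute(self, data, ctx):
        ctx["calls"] = ctx.get("calls", 0) + 1
        return None  # adds nothing to the event-local data


callback = ToyCallback().add("on_batch_end", DoubleLoss()).add("on_batch_end", CountCalls())
manager = ToyCallbackManager().add(callback)
result = manager.run("on_batch_end", {"loss": 0.5})
print(result)  # {'loss': 0.5, 'loss_x2': 1.0}
```

With the real library, the same pattern would subclass Operation, register it on a Callback via add(stage, op), and call CallbackManager.run(stage, data) from the training loop.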

class congrads.callbacks.base.Callback

Bases: ABC

Abstract base class representing a callback that can have multiple operations registered to different stages of the training lifecycle.

Each stage method executes all operations registered for that stage in insertion order. Operations can modify the event-local data dictionary.

add(stage: Literal['on_train_start', 'on_train_end', 'on_epoch_start', 'on_epoch_end', 'on_test_start', 'on_test_end', 'on_batch_start', 'on_batch_end', 'on_train_batch_start', 'on_train_batch_end', 'on_valid_batch_start', 'on_valid_batch_end', 'on_test_batch_start', 'on_test_batch_end', 'after_train_forward', 'after_valid_forward', 'after_test_forward'], op: Operation) → Self

Register an operation to execute at the given stage.

Parameters:
  • stage (Stage) – Lifecycle stage at which to run the operation.

  • op (Operation) – Operation instance to add.

Returns:

Returns self for method chaining.

Return type:

Self

after_test_forward(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘after_test_forward’ stage.

after_train_forward(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘after_train_forward’ stage.

after_valid_forward(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘after_valid_forward’ stage.

on_batch_end(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_batch_end’ stage.

on_batch_start(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_batch_start’ stage.

on_epoch_end(data: dict[str, Any], ctx: dict[str, Any])

Execute operations registered for the ‘on_epoch_end’ stage.

on_epoch_start(data: dict[str, Any], ctx: dict[str, Any])

Execute operations registered for the ‘on_epoch_start’ stage.

on_test_batch_end(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_test_batch_end’ stage.

on_test_batch_start(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_test_batch_start’ stage.

on_test_end(data: dict[str, Any], ctx: dict[str, Any])

Execute operations registered for the ‘on_test_end’ stage.

on_test_start(data: dict[str, Any], ctx: dict[str, Any])

Execute operations registered for the ‘on_test_start’ stage.

on_train_batch_end(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_train_batch_end’ stage.

on_train_batch_start(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_train_batch_start’ stage.

on_train_end(data: dict[str, Any], ctx: dict[str, Any])

Execute operations registered for the ‘on_train_end’ stage.

on_train_start(data: dict[str, Any], ctx: dict[str, Any])

Execute operations registered for the ‘on_train_start’ stage.

on_valid_batch_end(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_valid_batch_end’ stage.

on_valid_batch_start(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any]

Execute operations registered for the ‘on_valid_batch_start’ stage.

class congrads.callbacks.base.CallbackManager(callbacks: Iterable[Callback] | None = None)

Bases: object

Orchestrates multiple callbacks and executes them at specific lifecycle stages.

  • Callbacks are executed in registration order.

  • Event-local data flows through all callbacks.

  • Shared context is available for cross-callback communication.

add(callback: Callback) → Self

Register a single callback.

Parameters:

callback (Callback) – Callback instance to add.

Returns:

Returns self for fluent chaining.

Return type:

Self

property callbacks: tuple[Callback, ...]

Return a read-only tuple of registered callbacks.

Returns:

Registered callbacks.

Return type:

tuple[Callback, …]

extend(callbacks: Iterable[Callback]) → None

Register multiple callbacks at once.

Parameters:

callbacks (Iterable[Callback]) – Iterable of callbacks to add.

run(stage: Literal['on_train_start', 'on_train_end', 'on_epoch_start', 'on_epoch_end', 'on_test_start', 'on_test_end', 'on_batch_start', 'on_batch_end', 'on_train_batch_start', 'on_train_batch_end', 'on_valid_batch_start', 'on_valid_batch_end', 'on_test_batch_start', 'on_test_batch_end', 'after_train_forward', 'after_valid_forward', 'after_test_forward'], data: dict[str, Any]) → dict[str, Any]

Execute all registered callbacks for a specific stage.

Parameters:
  • stage (Stage) – Lifecycle stage to run (e.g., “on_batch_start”).

  • data (dict[str, Any]) – Event-local data dictionary to pass through callbacks.

Returns:

The final merged data dictionary after executing all callbacks.

Return type:

dict[str, Any]

Raises:
  • ValueError – If a callback does not implement the requested stage.

  • RuntimeError – If any callback raises an exception during execution.

class congrads.callbacks.base.Operation

Bases: ABC

Abstract base class representing a stateless unit of work executed inside a callback stage.

Subclasses must implement the compute method, which returns a dictionary of outputs to merge into the running event data.

abstract compute(data: dict[str, Any], ctx: dict[str, Any]) → dict[str, Any] | None

Perform the operation’s computation.

Parameters:
  • data (dict[str, Any]) – Event-local dictionary containing the current data.

  • ctx (dict[str, Any]) – Shared context dictionary.

Returns:

Outputs to merge into the running data.

Returning None is equivalent to {}.

Return type:

dict[str, Any] or None

Checkpoints

Module for managing PyTorch model checkpoints.

Provides the CheckpointManager class to save and load model and optimizer states during training, track the best metric values, and optionally report checkpoint events.

class congrads.checkpoints.CheckpointManager(criteria_function: Callable[[dict[str, Tensor], dict[str, Tensor]], bool], network: Module, optimizer: Optimizer, metric_manager: MetricManager, save_dir: str = 'checkpoints', create_dir: bool = False, report_save: bool = False)

Bases: object

Manage saving and loading checkpoints for PyTorch models and optimizers.

Handles checkpointing based on a criteria function, restores metric states, and optionally reports when a checkpoint is saved.

evaluate_criteria(epoch: int, metric_group: str = 'during_training')

Evaluate the criteria function to determine whether a better model has been found.

Aggregates the current values of the selected metric group (by default ‘during_training’) and applies the criteria function. If the criteria function indicates improvement, the best metric values are updated, a checkpoint is saved, and a message is optionally printed.

Parameters:
  • epoch (int) – The current epoch number.

  • metric_group (str, optional) – The metric group to evaluate. Defaults to ‘during_training’.
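The criteria_function passed to the constructor receives two metric dictionaries and returns True when the model has improved. The following is a minimal sketch of such a function. It assumes that the first argument holds the freshly aggregated metrics and the second the best values recorded so far; that argument order, the metric key "Loss/valid", and the name lower_valid_loss are all hypothetical, not taken from the library.

```python
from __future__ import annotations

from typing import Any, Mapping


def lower_valid_loss(current: Mapping[str, Any], best: Mapping[str, Any]) -> bool:
    """Hypothetical criteria function: report improvement when the validation loss drops.

    Assumes `current` holds the newly aggregated metrics and `best` the best
    values recorded so far. Values may be 0-d tensors or plain numbers; float()
    accepts both. "Loss/valid" is a placeholder metric name.
    """
    key = "Loss/valid"
    if key not in best:  # defensive: no best value recorded yet
        return True
    return float(current[key]) < float(best[key])
```

Following the documented constructor signature, it would be passed as the first argument, for example CheckpointManager(lower_valid_loss, network, optimizer, metric_manager, create_dir=True).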

load(filename: str)

Load a checkpoint and restore the training state.

Loads the checkpoint from the specified file and restores the network weights, optimizer state, and best metric values.

Parameters:

filename (str) – Name of the checkpoint file.

Returns:

A dictionary containing the loaded checkpoint information, including epoch, loss, and other relevant training state.

Return type:

dict

resume(filename: str = 'checkpoint.pth', ignore_missing: bool = False) → int

Resume training from a saved checkpoint file.

Parameters:
  • filename (str) – The name of the checkpoint file to load. Defaults to “checkpoint.pth”.

  • ignore_missing (bool) – If True, does not raise an error if the checkpoint file is missing and continues without loading, starting from epoch 0. Defaults to False.

Returns:

The epoch number from the loaded checkpoint, or 0 if ignore_missing is True and no checkpoint was found.

Return type:

int

Raises:
  • TypeError – If a provided attribute has an incompatible type.

  • FileNotFoundError – If the specified checkpoint file does not exist.
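The file-handling contract of resume can be summarised with a small stand-in. resume_epoch and its load_epoch callable are hypothetical: load_epoch represents the part that actually restores state and reports the stored epoch, whereas the real method resolves the filename itself and also restores the network, optimizer, and metric state. Only the missing-file logic documented above is reproduced.

```python
from __future__ import annotations

import os
from typing import Callable


def resume_epoch(path: str, load_epoch: Callable[[str], int], ignore_missing: bool = False) -> int:
    """Hypothetical stand-in for the missing-file behaviour of CheckpointManager.resume."""
    if not os.path.exists(path):
        if ignore_missing:
            return 0  # no checkpoint: start from epoch 0
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return load_epoch(path)  # restore state and return the stored epoch
```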

save(epoch: int, filename: str = 'checkpoint.pth')

Save a checkpoint.

Parameters:
  • epoch (int) – Current epoch number.

  • filename (str) – Name of the checkpoint file. Defaults to ‘checkpoint.pth’.

Constraints

Module providing constraint classes for guiding neural network training.

This module defines constraints that enforce specific conditions on network outputs to steer learning. Available constraint types include:

  • Constraint: Base class for all constraint types, defining the interface and core behavior.

  • ImplicationConstraint: Enforces one condition only if another condition is met, useful for modeling implications between outputs.

  • ScalarConstraint: Enforces scalar-based comparisons on a network’s output.

  • BinaryConstraint: Enforces a binary comparison between two tags using a comparison function (e.g., less than, greater than).

  • SumConstraint: Enforces a comparison (>, <, >=, or <=) between the weighted sums of two groups of tags’ outputs, useful for bounding total output.

These constraints can steer the learning process by applying logical implications or numerical bounds.

Usage:
  1. Define a custom constraint class by inheriting from Constraint.

  2. Apply the constraint to your neural network during training.

  3. Use helper classes like IdentityTransformation for transformations and comparisons in constraints.

class congrads.constraints.registry.ANDConstraint(*constraints: Constraint, name: str = None, enforce: bool = False, rescale_factor: Number = 1.5)

Bases: Constraint

A composite constraint that enforces the logical AND of multiple constraints.

This class combines multiple sub-constraints and evaluates them jointly:

  • The satisfaction of the AND constraint is True only if all sub-constraints are satisfied (elementwise logical AND).

  • The corrective direction is computed by weighting each sub-constraint’s direction with its satisfaction mask and summing across all sub-constraints.

calculate_direction(data: dict[str, Tensor])

Compute the corrective direction by aggregating sub-constraint directions.

Each sub-constraint contributes its corrective direction, weighted by its satisfaction mask. The directions are summed across constraints for each affected layer.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A mapping from layer identifiers to correction tensors. Each entry represents the aggregated correction to apply to that layer, based on the satisfaction-weighted sum of sub-constraint directions.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor])

Evaluate whether all sub-constraints are satisfied.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A tuple (total_satisfaction, mask) where:
  • total_satisfaction: A boolean or numeric tensor indicating elementwise whether all constraints are satisfied (logical AND).

  • mask: A tensor of ones with the same shape as total_satisfaction. Typically used as a weighting mask in downstream processing.

Return type:

tuple[Tensor, Tensor]
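The satisfaction rule of check_constraint can be sketched with plain Python lists standing in for tensors, where 1 means satisfied and 0 violated. and_check is a hypothetical helper; it reproduces only the documented elementwise AND and the all-ones mask, not the direction aggregation.

```python
from __future__ import annotations


def and_check(sub_results: list[list[int]]) -> tuple[list[int], list[int]]:
    """Combine per-sample results (1 = satisfied) of several sub-constraints with AND."""
    total = [int(all(column)) for column in zip(*sub_results)]
    mask = [1] * len(total)  # ones with the same shape as the result
    return total, mask


print(and_check([[1, 1, 0, 1], [1, 0, 0, 1]]))  # ([1, 0, 0, 1], [1, 1, 1, 1])
```

Replacing all with any gives the elementwise OR described for ORConstraint.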

class congrads.constraints.registry.BinaryConstraint(operand_left: str | Transformation, comparator: Literal['>', '<', '>=', '<='], operand_right: str | Transformation, name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: Constraint

A constraint that enforces a binary comparison between two tags.

This class ensures that the output of one tag satisfies a comparison operation with the output of another tag (e.g., less than, greater than, etc.). It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

Parameters:
  • operand_left (Union[str, Transformation]) – Name of the left tag or a transformation to apply.

  • comparator (Literal[">", "<", ">=", "<="]) – Comparison operator used in the constraint.

  • operand_right (Union[str, Transformation]) – Name of the right tag or a transformation to apply.

  • name (str, optional) – A unique name for the constraint. If not provided, a name is auto-generated in the format “<operand_left> <comparator> <operand_right>”.

  • enforce (bool, optional) – If False, only monitor the constraint without adjusting the loss. Defaults to True.

  • rescale_factor (Number, optional) – Factor to scale the constraint-adjusted loss. Defaults to 1.5.

Raises:

TypeError – If a provided attribute has an incompatible type.

Notes

  • The tags must be defined in the descriptor mapping.

  • The constraint name is composed using the left tag, comparator, and right tag.

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute adjustment directions for the tags involved in the binary constraint.

The returned directions indicate how to adjust each tag’s output to satisfy the constraint. Currently only dense layers are supported.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

A mapping from layer names to tensors specifying the normalized adjustment directions for each tag involved in the constraint.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Evaluate whether the binary constraint is satisfied for the current predictions.

The constraint compares the outputs of two tags using the specified comparator function. A result of 1 indicates the constraint is satisfied for a sample, and 0 indicates it is violated.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result (Tensor): Binary tensor indicating constraint satisfaction (1 for satisfied, 0 for violated) for each sample.

  • mask (Tensor): Tensor of ones with the same shape as result, used for constraint aggregation.

Return type:

tuple[Tensor, Tensor]
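The per-sample comparison and the auto-generated name can be sketched with plain lists in place of tensors. binary_check, default_name, and the tag names y_min and y_max are hypothetical; the comparator strings are the four documented ones, mapped onto Python’s operator module.

```python
from __future__ import annotations

import operator

# Map the documented comparator strings to Python comparison functions.
COMPARATORS = {">": operator.gt, "<": operator.lt, ">=": operator.ge, "<=": operator.le}


def binary_check(left: list[float], comparator: str, right: list[float]) -> tuple[list[int], list[int]]:
    """Per-sample comparison of two tag outputs: 1 if satisfied, 0 if violated."""
    compare = COMPARATORS[comparator]
    result = [int(compare(a, b)) for a, b in zip(left, right)]
    return result, [1] * len(result)  # second element: all-ones mask


def default_name(operand_left: str, comparator: str, operand_right: str) -> str:
    """Name format documented for BinaryConstraint when `name` is omitted."""
    return f"{operand_left} {comparator} {operand_right}"


print(binary_check([0.1, 0.7, 0.3], "<=", [0.5, 0.6, 0.3]))  # ([1, 0, 1], [1, 1, 1])
print(default_name("y_min", "<=", "y_max"))  # y_min <= y_max
```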

class congrads.constraints.registry.EncodedGroupedMonotonicityConstraint(base: MonotonicityConstraint, tag_group: str, name: str | None = None)

Bases: Constraint

Applies a base monotonicity constraint to groups via interval encoding.

This constraint enforces a monotonic relationship between a prediction tag (base.tag_prediction) and a reference tag (base.tag_reference) for each group defined by tag_group.

Instead of looping over groups, each group’s predictions and targets are shifted by a large offset to place them in non-overlapping intervals. This allows the base constraint to be applied in a single, vectorized operation.

This is a vectorized alternative to PerGroupMonotonicityConstraint, which enforces the same logic via explicit per-group iteration.

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute ranking adjustments for monotonicity enforcement.

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Evaluate whether the monotonicity constraint is satisfied by mapping each group onto a non-overlapping interval.
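The interval-encoding idea can be sketched on plain lists. encode_groups and is_ascending are hypothetical helpers; the sketch assumes group indices 0, 1, 2, … and an offset larger than the overall range of each encoded list, so that shifted groups cannot overlap. It uses a whole-batch check for simplicity, whereas the library reports per-sample satisfaction. Without encoding, a second group that restarts at low predictions looks like a violation; after encoding, one global check is equivalent to checking each group separately.

```python
from __future__ import annotations


def encode_groups(values: list[float], groups: list[int], offset: float) -> list[float]:
    """Shift every sample by group_index * offset so that groups occupy disjoint intervals.

    Assumes group indices 0, 1, 2, ... and an offset larger than the overall
    range of `values`, so no two groups can overlap after shifting.
    """
    return [v + g * offset for v, g in zip(values, groups)]


def is_ascending(pred: list[float], ref: list[float]) -> bool:
    """Whole-batch check: ordering samples by `ref` yields non-decreasing `pred`."""
    order = sorted(range(len(ref)), key=ref.__getitem__)
    ordered = [pred[i] for i in order]
    return all(a <= b for a, b in zip(ordered, ordered[1:]))


pred = [0.2, 0.9, 0.1, 0.4]
ref = [1.0, 3.0, 5.0, 6.0]
groups = [0, 0, 1, 1]

print(is_ascending(pred, ref))  # False: group 1 restarts at a lower prediction
print(is_ascending(encode_groups(pred, groups, 10.0), encode_groups(ref, groups, 10.0)))  # True
```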

class congrads.constraints.registry.ImplicationConstraint(head: Constraint, body: Constraint, name: str = None)

Bases: Constraint

Represents an implication constraint between two constraints (head and body).

The implication constraint ensures that the body constraint only applies when the head constraint is satisfied. If the head constraint is not satisfied, the body constraint does not apply.

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute adjustment directions for tags to satisfy the constraint.

Uses the body constraint directions as the update vector. Only applies updates if the head constraint is satisfied. Currently, this method only works for dense layers due to tag-to-index translation limitations.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Dictionary mapping tags to tensors specifying the adjustment direction for each tag.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Check whether the implication constraint is satisfied.

Evaluates the head and body constraints. The body constraint is enforced only if the head constraint is satisfied. If the head constraint is not satisfied, the body constraint does not affect the result.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result: Tensor indicating satisfaction of the implication constraint (1 if satisfied, 0 otherwise).

  • head_satisfaction: Tensor indicating satisfaction of the head constraint alone.

Return type:

tuple[Tensor, Tensor]
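Read as a logical implication, a sample is violated only where the head holds and the body does not; wherever the head fails, the body has no effect and the sample counts as satisfied. A minimal sketch of that reading on plain lists, where implication_check is a hypothetical helper returning the documented (result, head_satisfaction) pair:

```python
from __future__ import annotations


def implication_check(head: list[int], body: list[int]) -> tuple[list[int], list[int]]:
    """Per-sample implication: violated only where the head holds and the body does not."""
    result = [int((not h) or bool(b)) for h, b in zip(head, body)]
    return result, list(head)  # second element: head satisfaction alone


print(implication_check([1, 1, 0, 0], [1, 0, 1, 0]))  # ([1, 0, 1, 1], [1, 1, 0, 0])
```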

class congrads.constraints.registry.ORConstraint(*constraints: Constraint, name: str = None, enforce: bool = False, rescale_factor: Number = 1.5)

Bases: Constraint

A composite constraint that enforces the logical OR of multiple constraints.

This class combines multiple sub-constraints and evaluates them jointly:

  • The satisfaction of the OR constraint is True if at least one sub-constraint is satisfied (elementwise logical OR).

  • The corrective direction is computed by weighting each sub-constraint’s direction with its satisfaction mask and summing across all sub-constraints.

calculate_direction(data: dict[str, Tensor])

Compute the corrective direction by aggregating sub-constraint directions.

Each sub-constraint contributes its corrective direction, weighted by its satisfaction mask. The directions are summed across constraints for each affected layer.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A mapping from layer identifiers to correction tensors. Each entry represents the aggregated correction to apply to that layer, based on the satisfaction-weighted sum of sub-constraint directions.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor])

Evaluate whether any sub-constraint is satisfied.

Parameters:

data – Model predictions and associated batch/context information.

Returns:

A tuple (total_satisfaction, mask) where:
  • total_satisfaction: A boolean or numeric tensor indicating elementwise whether any constraints are satisfied (logical OR).

  • mask: A tensor of ones with the same shape as total_satisfaction. Typically used as a weighting mask in downstream processing.

Return type:

tuple[Tensor, Tensor]

class congrads.constraints.registry.PairwiseMonotonicityConstraint(tag_prediction: str, tag_reference: str, rescale_factor_lower: float = 1.5, rescale_factor_upper: float = 1.75, stable: bool = True, direction: Literal['ascending', 'descending'] = 'ascending', name: str = None, enforce: bool = True)

Bases: MonotonicityConstraint

Constraint that enforces a monotonic relationship between two tags.

This constraint ensures that the activations of a prediction tag (tag_prediction) are monotonically ascending or descending with respect to a target tag (tag_reference).

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute ranking adjustments for monotonicity enforcement.

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Evaluate whether the monotonicity constraint is satisfied.
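The class name suggests a check over pairs of samples. One plausible pairwise formulation is sketched below on plain lists; it is an assumption based on the name, and the library’s actual check may differ. pairwise_check is hypothetical: sample i is satisfied when no other sample j is ordered against it, meaning prediction and reference move in opposite directions (ascending) or in the same direction (descending).

```python
from __future__ import annotations


def pairwise_check(pred: list[float], ref: list[float], direction: str = "ascending") -> tuple[list[int], list[int]]:
    """Sample i is satisfied if no other sample j is ordered against it."""
    sign = 1 if direction == "ascending" else -1
    n = len(pred)
    result = []
    for i in range(n):
        # A negative signed product means pred and ref disagree on the order of i and j.
        violated = any((ref[j] - ref[i]) * (pred[j] - pred[i]) * sign < 0 for j in range(n))
        result.append(int(not violated))
    return result, [1] * n


print(pairwise_check([0.1, 0.5, 0.3], [1.0, 2.0, 3.0]))  # ([1, 0, 0], [1, 1, 1])
```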

class congrads.constraints.registry.PerGroupMonotonicityConstraint(base: MonotonicityConstraint, tag_group: str, name: str | None = None)

Bases: Constraint

Applies a monotonicity constraint independently per group of samples.

This class wraps an existing MonotonicityConstraint instance (base) and enforces it separately for each unique group defined by tag_group.

Each group is treated as an independent mini-batch:
  • The base constraint is applied to the group’s subset of data.

  • Violations and directions are computed per group and then reassembled into the original batch order.

This is an explicit alternative to EncodedGroupedMonotonicityConstraint, which enforces the same logic using vectorized interval encoding.

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute ranking adjustments for monotonicity enforcement.

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Evaluate whether the monotonicity constraint is satisfied per group.
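The split, apply, and reassemble pattern can be sketched on plain lists. per_group_check and group_is_ascending are hypothetical: the latter is a deliberately simple base check (every sample of a group gets 1 if the whole group is ascending) standing in for a real MonotonicityConstraint. Groups are interleaved in the example to show that results are scattered back to their original batch positions.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Callable

BaseCheck = Callable[[list[float], list[float]], list[int]]


def group_is_ascending(pred: list[float], ref: list[float]) -> list[int]:
    """Toy base check: all samples of a group get 1 if the group is ascending, else 0."""
    order = sorted(range(len(ref)), key=ref.__getitem__)
    ok = all(pred[a] <= pred[b] for a, b in zip(order, order[1:]))
    return [int(ok)] * len(pred)


def per_group_check(pred: list[float], ref: list[float], groups: list[int], base_check: BaseCheck) -> list[int]:
    """Apply `base_check` to each group separately, then restore the batch order."""
    members: dict[int, list[int]] = defaultdict(list)
    for i, g in enumerate(groups):
        members[g].append(i)
    result = [0] * len(pred)
    for idx in members.values():
        sub = base_check([pred[i] for i in idx], [ref[i] for i in idx])
        for i, r in zip(idx, sub):
            result[i] = r  # scatter back to the original positions
    return result


print(per_group_check([0.2, 0.5, 0.9, 0.4], [1.0, 5.0, 3.0, 6.0], [0, 1, 0, 1], group_is_ascending))
# [1, 0, 1, 0]
```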

class congrads.constraints.registry.RankedMonotonicityConstraint(tag_prediction: str, tag_reference: str, rescale_factor_lower: float = 1.5, rescale_factor_upper: float = 1.75, stable: bool = True, direction: Literal['ascending', 'descending'] = 'ascending', name: str = None, enforce: bool = True)

Bases: MonotonicityConstraint

Constraint that enforces a monotonic relationship between two tags.

This constraint ensures that the activations of a prediction tag (tag_prediction) are monotonically ascending or descending with respect to a target tag (tag_reference).

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute ranking adjustments for monotonicity enforcement.

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Evaluate whether the monotonicity constraint is satisfied.
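The class name suggests comparing ranks. One plausible rank-based formulation is sketched below; it is an assumption drawn from the name, not the library’s confirmed algorithm. ranks and ranked_check are hypothetical helpers working on plain lists: a sample is satisfied when its prediction rank equals its reference rank, with reference ranks reversed for the descending direction. Ties keep their input order here because Python’s sorted is stable.

```python
from __future__ import annotations


def ranks(values: list[float]) -> list[int]:
    """Rank of each element (0 = smallest); ties keep their input order."""
    order = sorted(range(len(values)), key=values.__getitem__)
    result = [0] * len(values)
    for position, i in enumerate(order):
        result[i] = position
    return result


def ranked_check(pred: list[float], ref: list[float], direction: str = "ascending") -> tuple[list[int], list[int]]:
    """A sample is satisfied when its prediction rank matches its reference rank."""
    rank_pred = ranks(pred)
    rank_ref = ranks(ref)
    if direction == "descending":
        rank_ref = [len(ref) - 1 - r for r in rank_ref]  # largest reference comes first
    result = [int(a == b) for a, b in zip(rank_pred, rank_ref)]
    return result, [1] * len(result)


print(ranks([0.1, 0.5, 0.3]))  # [0, 2, 1]
print(ranked_check([0.1, 0.5, 0.3], [1.0, 2.0, 3.0]))  # ([1, 0, 0], [1, 1, 1])
```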

class congrads.constraints.registry.ScalarConstraint(operand: str | Transformation, comparator: Literal['>', '<', '>=', '<='], scalar: Number, name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: Constraint

A constraint that enforces scalar-based comparisons on a specific tag.

This class ensures that the output of a specified tag satisfies a scalar comparison operation (e.g., less than, greater than, etc.). It uses a comparator function to validate the condition and calculates adjustment directions accordingly.

Parameters:
  • operand (Union[str, Transformation]) – Name of the tag or a transformation to apply.

  • comparator (Literal[">", "<", ">=", "<="]) – Comparison operator used in the constraint.

  • scalar (Number) – The scalar value to compare against.

  • name (str, optional) – A unique name for the constraint. If not provided, a name is auto-generated in the format “<tag> <comparator> <scalar>”.

  • enforce (bool, optional) – If False, only monitor the constraint without adjusting the loss. Defaults to True.

  • rescale_factor (Number, optional) – Factor to scale the constraint-adjusted loss. Defaults to 1.5.

Raises:

TypeError – If a provided attribute has an incompatible type.

Notes

  • The tag must be defined in the descriptor mapping.

  • The constraint name is composed using the tag, comparator, and scalar value.

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute adjustment directions to satisfy the scalar constraint.

Only works for dense layers due to tag-to-index translation.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Dictionary mapping layers to tensors specifying the adjustment direction for each tag.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Check if the scalar constraint is satisfied for a given tag.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result: Tensor indicating whether the tag satisfies the constraint.

  • ones_like(result): Tensor of ones with same shape as result.

Return type:

tuple[Tensor, Tensor]

class congrads.constraints.registry.SumConstraint(operands_left: list[str | Transformation], comparator: Literal['>', '<', '>=', '<='], operands_right: list[str | Transformation], weights_left: list[Number] = None, weights_right: list[Number] = None, name: str = None, enforce: bool = True, rescale_factor: Number = 1.5)

Bases: Constraint

A constraint that enforces a weighted summation comparison between two groups of tags.

This class evaluates whether the weighted sum of outputs from one set of tags satisfies a comparison operation with the weighted sum of outputs from another set of tags.

calculate_direction(data: dict[str, Tensor]) → dict[str, Tensor]

Compute adjustment directions for tags involved in the weighted sum constraint.

The directions indicate how to adjust each tag’s output to satisfy the constraint. Only dense layers are currently supported.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

Mapping from layer names to normalized tensors specifying adjustment directions for each tag involved in the constraint.

Return type:

dict[str, Tensor]

check_constraint(data: dict[str, Tensor]) → tuple[Tensor, Tensor]

Evaluate whether the weighted sum constraint is satisfied.

Computes the weighted sum of outputs from the left and right tags, applies the specified comparator function, and returns a binary result for each sample.

Parameters:

data (dict[str, Tensor]) – Dictionary that holds batch data, model predictions and context.

Returns:

  • result (Tensor): Binary tensor indicating whether the constraint is satisfied (1) or violated (0) for each sample.

  • mask (Tensor): Tensor of ones, used for constraint aggregation.

Return type:

tuple[Tensor, Tensor]
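The weighted-sum comparison can be sketched with plain lists, where each inner list holds one tag’s outputs across the batch. weighted_sum, sum_check, and the tags a, b, and c are hypothetical, and treating omitted weights as 1 is an assumption rather than documented behaviour.

```python
from __future__ import annotations

import operator

COMPARATORS = {">": operator.gt, "<": operator.lt, ">=": operator.ge, "<=": operator.le}


def weighted_sum(columns: list[list[float]], weights: list[float] | None) -> list[float]:
    """Per-sample weighted sum over tags; each column holds one tag's outputs."""
    if weights is None:
        weights = [1.0] * len(columns)  # assumption: omitted weights act as 1
    return [sum(w * col[i] for w, col in zip(weights, columns)) for i in range(len(columns[0]))]


def sum_check(left, comparator, right, weights_left=None, weights_right=None):
    """Compare left and right weighted sums per sample: 1 if satisfied, 0 if violated."""
    lhs = weighted_sum(left, weights_left)
    rhs = weighted_sum(right, weights_right)
    compare = COMPARATORS[comparator]
    result = [int(compare(a, b)) for a, b in zip(lhs, rhs)]
    return result, [1] * len(result)


# Two samples; left tags a and b, right tag c (hypothetical tag outputs).
a, b, c = [1.0, 2.0], [3.0, 4.0], [2.0, 5.0]
print(sum_check([a, b], ">=", [c], weights_left=[1.0, 0.5]))  # ([1, 0], [1, 1])
```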

Core

Defines the BatchRunner, which executes individual batches for training, validation, and testing.

Responsibilities:
  • Move batch data to the appropriate device

  • Run forward passes through the network

  • Compute base and constraint-adjusted losses

  • Perform backpropagation during training

  • Accumulate metrics for loss and other monitored quantities

  • Trigger callbacks at key points in the batch lifecycle

class congrads.core.batch_runner.BatchRunner(network: Module, criterion, optimizer: Optimizer, constraint_engine: ConstraintEngine, metric_manager: MetricManager | None, callback_manager: CallbackManager | None, device: device)

Bases: object

Executes a single batch for training, validation, or testing.

The BatchRunner handles moving data to the correct device, running the network forward, computing base and constraint-adjusted losses, performing backpropagation during training, accumulating metrics, and dispatching callbacks at key points in the batch lifecycle.

test_batch(batch: dict[str, Tensor]) → Tensor

Run a single test batch.

Steps performed:
  1. Move batch to device and run “on_test_batch_start” callback.

  2. Forward pass through the network.

  3. Compute base loss using the criterion and accumulate metric.

  4. Evaluate constraints via the ConstraintEngine (does not modify loss).

  5. Run “on_test_batch_end” callback.

Parameters:

batch – Dictionary of input and target tensors for the batch.

Returns:

The base loss computed for the batch.

Return type:

Tensor

train_batch(batch: dict[str, Tensor]) → Tensor

Run a single training batch.

Steps performed:
  1. Move batch to device and run “on_train_batch_start” callback.

  2. Forward pass through the network.

  3. Compute base loss using the criterion and accumulate metric.

  4. Apply constraint-based adjustments to the loss.

  5. Perform backward pass and optimizer step.

  6. Run “on_train_batch_end” callback.

Parameters:

batch – Dictionary of input and target tensors for the batch.

Returns:

The base loss computed before constraint adjustments.

Return type:

Tensor

valid_batch(batch: dict[str, Tensor]) → Tensor

Run a single validation batch.

Steps performed:
  1. Move batch to device and run “on_valid_batch_start” callback.

  2. Forward pass through the network.

  3. Compute base loss using the criterion and accumulate metric.

  4. Evaluate constraints via the ConstraintEngine (does not modify loss).

  5. Run “on_valid_batch_end” callback.

Parameters:

batch – Dictionary of input and target tensors for the batch.

Returns:

The base loss computed for the batch.

Return type:

Tensor

This module provides the CongradsCore class, which implements the main training functionality.

It is designed to integrate constraint-guided optimization into neural network training. It extends traditional training processes by enforcing specific constraints on the model’s outputs, ensuring that the network satisfies domain-specific requirements during both training and evaluation.

The CongradsCore class serves as the central engine for managing the training, validation, and testing phases of a neural network model, incorporating constraints that influence the loss function and model updates. The model is trained with standard loss functions while also incorporating constraint-based adjustments, which are tracked and logged throughout the process.

Key features:
  • Support for various constraints that can influence the training process.

  • Integration with PyTorch’s DataLoader for efficient batch processing.

  • Metric management for tracking loss and constraint satisfaction.

  • Checkpoint management for saving and evaluating model states.

class congrads.core.congradscore.CongradsCore(descriptor: Descriptor, constraints: list[Constraint], network: Module, criterion: _Loss, optimizer: Optimizer, device: device, dataloader_train: DataLoader, dataloader_valid: DataLoader | None = None, dataloader_test: DataLoader | None = None, metric_manager: MetricManager | None = None, callback_manager: CallbackManager | None = None, checkpoint_manager: CheckpointManager | None = None, network_uses_grad: bool = False, epsilon: float = 1e-06, constraint_aggregator: Callable[..., Tensor] = torch.sum, enforce_all: bool = True, disable_progress_bar_epoch: bool = False, disable_progress_bar_batch: bool = False, epoch_runner_cls: type[EpochRunner] | None = None, batch_runner_cls: type[BatchRunner] | None = None, constraint_engine_cls: type[ConstraintEngine] | None = None)

Bases: object

The CongradsCore class is the central training engine for constraint-guided optimization.

It integrates standard neural network training with additional constraint-driven adjustments to the loss function, ensuring that the network satisfies domain-specific constraints during training.

fit(*, start_epoch: int = 0, max_epochs: int = 100, test_model: bool = True, final_checkpoint_name: str = 'checkpoint_final.pth') → None

Run the full training loop, including optional validation, testing, and checkpointing.

This method performs training over multiple epochs with the following steps:
  1. Trigger “on_train_start” callbacks if a callback manager is present.

  2. For each epoch:
       • Trigger “on_epoch_start” callbacks.

       • Run a training epoch via the EpochRunner.

       • Run a validation epoch via the EpochRunner.

       • Evaluate checkpoint criteria if a checkpoint manager is present.

       • Trigger “on_epoch_end” callbacks.

  3. Trigger “on_train_end” callbacks after all epochs.

  4. Optionally run a test epoch via the EpochRunner if test_model is True, with corresponding “on_test_start” and “on_test_end” callbacks.

  5. Save a final checkpoint using the checkpoint manager.

Parameters:
  • start_epoch – Index of the starting epoch (default 0). Useful for resuming training.

  • max_epochs – Maximum number of epochs to run (default 100).

  • test_model – Whether to run a test epoch after training (default True).

  • final_checkpoint_name – Filename for the final checkpoint saved at the end of training (default “checkpoint_final.pth”).

Returns:

None
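The step order of fit() can be sketched as a plain loop that only records which hooks and phases fire. This is a hypothetical outline, not the CongradsCore implementation: `run_stage`, `train_epoch`, `valid_epoch`, `test_epoch`, `save_checkpoint`, and `evaluate_checkpoint` are stand-in callables, and it assumes epochs run from start_epoch up to (but excluding) max_epochs.

```python
from typing import Callable, Optional


def fit_sketch(
    run_stage: Optional[Callable[[str], None]],
    train_epoch: Callable[[], None],
    valid_epoch: Callable[[], None],
    test_epoch: Callable[[], None],
    save_checkpoint: Callable[[str], None],
    evaluate_checkpoint: Optional[Callable[[int], None]] = None,
    *,
    start_epoch: int = 0,
    max_epochs: int = 100,
    test_model: bool = True,
    final_checkpoint_name: str = "checkpoint_final.pth",
) -> None:
    """Hypothetical outline of the documented fit() step order."""

    def hook(stage: str) -> None:
        # Callbacks only fire when a callback manager (here: run_stage) exists.
        if run_stage is not None:
            run_stage(stage)

    hook("on_train_start")
    # Assumption: max_epochs is an exclusive upper bound on the epoch index.
    for epoch in range(start_epoch, max_epochs):
        hook("on_epoch_start")
        train_epoch()
        valid_epoch()
        if evaluate_checkpoint is not None:
            evaluate_checkpoint(epoch)
        hook("on_epoch_end")
    hook("on_train_end")

    if test_model:
        hook("on_test_start")
        test_epoch()
        hook("on_test_end")

    save_checkpoint(final_checkpoint_name)
```

Passing a list's `append` as `run_stage` is a quick way to inspect the order in which stages are triggered.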

Manages the evaluation and optional enforcement of constraints on neural network outputs.

Responsibilities:
  • Compute and log the Constraint Satisfaction Rate (CSR) for training, validation, and test batches.

  • Optionally adjust the loss during training based on constraint directions and rescale factors.

  • Handle gradient computation and the application of CGGD (constraint-guided gradient descent).

class congrads.core.constraint_engine.ConstraintEngine(*, constraints: list[Constraint], descriptor: Descriptor, metric_manager: MetricManager, device: device, epsilon: float, aggregator: callable, enforce_all: bool)

Bases: object

Manages constraint evaluation and enforcement for a neural network.

The ConstraintEngine coordinates constraints defined in Constraint objects, computes gradients for layers that affect the loss, logs metrics, and optionally modifies the loss during training according to the constraints. It supports separate phases for training, validation, and testing.

test(data: dict[str, Tensor], loss: Tensor) Tensor

Evaluate constraints during testing without modifying the loss.

Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint, but does not apply rescale adjustments to the loss.

Parameters:
  • data – Dictionary containing input and prediction tensors for the batch.

  • loss – The original loss tensor computed from the network output.

Returns:

The original loss tensor, unchanged.

Return type:

Tensor

train(data: dict[str, Tensor], loss: Tensor) Tensor

Apply all active constraints during training.

Computes the original loss gradients for layers that affect the loss, evaluates each constraint, logs the Constraint Satisfaction Rate (CSR), and adjusts the loss according to constraint satisfaction.

Parameters:
  • data – Dictionary containing input and prediction tensors for the batch.

  • loss – The original loss tensor computed from the network output.

Returns:

The loss tensor after applying constraint-based adjustments.

Return type:

Tensor

validate(data: dict[str, Tensor], loss: Tensor) Tensor

Evaluate constraints during validation without modifying the loss.

Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint, but does not apply rescale adjustments to the loss.

Parameters:
  • data – Dictionary containing input and prediction tensors for the batch.

  • loss – The original loss tensor computed from the network output.

Returns:

The original loss tensor, unchanged.

Return type:

Tensor
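The contract shared by validate() and test() (log a CSR per constraint, return the loss unchanged) can be illustrated with plain Python. This sketch assumes the CSR is the fraction of samples in a batch that satisfy a constraint; the helper names are hypothetical, and the training-time loss adjustment is deliberately left out because its exact rescaling rule is not documented here.

```python
import math
from typing import Dict, List


def constraint_satisfaction_rate(satisfied: List[bool]) -> float:
    """Fraction of samples that satisfy a constraint (assumed CSR definition)."""
    if not satisfied:
        return math.nan  # no samples, nothing to measure
    return sum(satisfied) / len(satisfied)


def evaluate_without_enforcement(
    checks: Dict[str, List[bool]], loss: float, log: Dict[str, List[float]]
) -> float:
    """Hypothetical validate/test step: log one CSR per constraint, keep the loss."""
    for name, satisfied in checks.items():
        log.setdefault(name, []).append(constraint_satisfaction_rate(satisfied))
    return loss  # evaluation phases never modify the loss
```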

Defines the EpochRunner class for running full training, validation, and test epochs.

This module handles:
  • Switching the network between training and evaluation modes

  • Iterating over DataLoaders with optional progress bars

  • Delegating per-batch processing to a BatchRunner instance

  • Optional gradient tracking control for evaluation phases

  • Warnings when validation or test loaders are not provided

class congrads.core.epoch_runner.EpochRunner(network: Module, batch_runner: BatchRunner, train_loader: DataLoader, valid_loader: DataLoader | None = None, test_loader: DataLoader | None = None, *, network_uses_grad: bool = False, disable_progress_bar: bool = False)

Bases: object

Runs full epochs over DataLoaders.

Responsibilities:
  • Model mode switching

  • Iteration over DataLoader

  • Delegation to BatchRunner

  • Progress bars

test() None

Run a test epoch over the test DataLoader.

Sets the network to evaluation mode and iterates over batches, delegating each batch to the BatchRunner for processing. Skips testing if no test_loader is provided.

train() None

Run a training epoch over the training DataLoader.

Sets the network to training mode and iterates over batches, delegating each batch to the BatchRunner for processing.

validate() None

Run a validation epoch over the validation DataLoader.

Sets the network to evaluation mode and iterates over batches, delegating each batch to the BatchRunner for processing. Skips validation if no valid_loader is provided.
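The skip-and-warn behaviour for missing loaders can be sketched without PyTorch. This is a hypothetical simplification: `process_batch` stands in for the BatchRunner, a loader is any iterable of batches, and mode switching, gradient control, and progress bars are omitted.

```python
import warnings
from typing import Any, Callable, Iterable, Optional


def run_epoch(
    loader: Optional[Iterable[Any]],
    process_batch: Callable[[Any], None],
    phase: str,
) -> int:
    """Hypothetical epoch loop; returns the number of batches processed."""
    if loader is None:
        # Mirrors the documented behaviour: skip the phase and emit a warning.
        warnings.warn(f"No {phase} loader provided, skipping {phase} epoch.")
        return 0
    count = 0
    for batch in loader:
        process_batch(batch)  # delegation to the batch-level runner
        count += 1
    return count
```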

Datasets

This module defines several PyTorch dataset classes for loading and working with various datasets.

Each dataset class extends the torch.utils.data.Dataset class and provides functionality for downloading, loading, and transforming specific datasets where applicable.

Classes:

  • BiasCorrection: A dataset class for the Bias Correction dataset focused on temperature forecast data.

  • FamilyIncome: A dataset class for the Family Income and Expenditure dataset.

  • SectionedGaussians: A synthetic dataset generating smoothly varying Gaussian signals across multiple sections.

  • SyntheticMonotonicity: A synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.

  • SyntheticClusters: A dataset class for generating synthetic clustered 2D data with labels.

Where applicable, each dataset class provides methods for downloading the data (only for non-synthetic datasets that are not already present), checking the integrity of the dataset, loading the data from CSV files or generating it synthetically, and applying transformations to the data.

class congrads.datasets.registry.BiasCorrection(root: str | Path, transform: Callable, download: bool = False)

Bases: Dataset

A dataset class for accessing the Bias Correction dataset.

This class extends the Dataset class and provides functionality for downloading, loading, and transforming the Bias Correction dataset. The dataset is focused on temperature forecast data and is made available for use with PyTorch. If download is set to True, the dataset will be downloaded if it is not already available. The data is then loaded, and a transformation function is applied to it.

Parameters:
  • root (Union[str, Path]) – The root directory where the dataset will be stored or loaded from.

  • transform (Callable) – A function to transform the dataset (e.g., preprocessing).

  • download (bool, optional) – Whether to download the dataset if it’s not already present. Defaults to False.

Raises:

RuntimeError – If the dataset is not found and download is False, or if all mirrors fail to provide the dataset.

property data_folder: str

Returns the path to the folder where the dataset is stored.

Returns:

The path to the dataset folder.

Return type:

str

download() None

Downloads and extracts the dataset.

This method attempts to download the dataset from the mirrors and extract it into the appropriate folder. Mirrors are tried in sequence: if downloading from one mirror fails, the next one is attempted.

Raises:

RuntimeError – If all mirrors fail to provide the dataset.

mirrors = ['https://archive.ics.uci.edu/static/public/514/']
resources = [('bias+correction+of+numerical+prediction+model+temperature+forecast.zip', '3deee56d461a2686887c4ae38fe3ccf3')]
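The mirror fallback of download() can be sketched with the standard library. This assumes the second element of each `resources` entry is an MD5 checksum (it has the 32-hex-character form used by torchvision-style datasets, but that is an assumption), and `fetch` is a hypothetical callable standing in for the real HTTP download of a file from a mirror.

```python
import hashlib
from typing import Callable, Sequence


def download_with_fallback(
    mirrors: Sequence[str],
    filename: str,
    md5: str,
    fetch: Callable[[str, str], bytes],
) -> bytes:
    """Try each mirror in order; return the first payload whose MD5 matches."""
    errors = []
    for mirror in mirrors:
        try:
            payload = fetch(mirror, filename)  # hypothetical downloader
        except OSError as exc:  # network failure: move on to the next mirror
            errors.append(f"{mirror}: {exc}")
            continue
        if hashlib.md5(payload).hexdigest() != md5:
            errors.append(f"{mirror}: checksum mismatch")
            continue
        return payload
    # Matches the documented failure mode when no mirror succeeds.
    raise RuntimeError("All mirrors failed:\n" + "\n".join(errors))
```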
class congrads.datasets.registry.FamilyIncome(root: str | Path, transform: Callable, download: bool = False)

Bases: Dataset

A dataset class for accessing the Family Income and Expenditure dataset.

This class extends the Dataset class and provides functionality for downloading, loading, and transforming the Family Income and Expenditure dataset for use with PyTorch. If download is set to True, the dataset will be downloaded if it is not already available. The data is then loaded, and a user-defined transformation function is applied to it.

Parameters:
  • root (Union[str, Path]) – The root directory where the dataset will be stored or loaded from.

  • transform (Callable) – A function to transform the dataset (e.g., preprocessing).

  • download (bool, optional) – Whether to download the dataset if it’s not already present. Defaults to False.

Raises:

RuntimeError – If the dataset is not found and download is False, or if all mirrors fail to provide the dataset.

property data_folder: str

Returns the path to the folder where the dataset is stored.

Returns:

The path to the dataset folder.

Return type:

str

download() None

Downloads and extracts the dataset.

This method attempts to download the dataset from the mirrors and extract it into the appropriate folder. Mirrors are tried in sequence: if downloading from one mirror fails, the next one is attempted.

Raises:

RuntimeError – If all mirrors fail to provide the dataset.

mirrors = ['https://www.kaggle.com/api/v1/datasets/download/grosvenpaul/family-income-and-expenditure']
resources = [('archive.zip', '7d74bc7facc3d7c07c4df1c1c6ac563e')]
class congrads.datasets.registry.SectionedGaussians(sections: list[dict], n_samples: int = 1000, n_runs: int = 1, seed: int | None = None, device='cpu', ground_truth_steepness: float = 0.0, blend_k: float = 10.0)

Bases: Dataset

Synthetic dataset generating smoothly varying Gaussian signals across multiple sections.

Each section defines a subrange of x-values with its own Gaussian distribution (mean and standard deviation). Instead of abrupt transitions, the parameters are blended smoothly between sections using a sigmoid function.

The resulting signal can represent a continuous process where statistical properties gradually evolve over time or position.

Features:
  • Input: Gaussian signal samples (y-values)

  • Context: Concatenation of time (x) and normalized energy feature

  • Target: Exponential decay ground truth from 1 at x_min to 0 at x_max

sections

List of section definitions.

Type:

list[dict]

n_samples

Total number of samples across all sections.

Type:

int

n_runs

Number of random waveforms generated from the base configuration.

Type:

int

time

Sampled x-values, shape [n_samples, 1].

Type:

torch.Tensor

signal

Generated Gaussian signal values, shape [n_samples, 1].

Type:

torch.Tensor

energies

Normalized energy feature, shape [n_samples, 1].

Type:

torch.Tensor

context

Concatenation of time and energy, shape [n_samples, 2].

Type:

torch.Tensor

x_min

Minimum x-value across all sections.

Type:

float

x_max

Maximum x-value across all sections.

Type:

float

ground_truth_steepness

Exponential decay steepness for target output.

Type:

float

blend_k

Sharpness parameter controlling how rapidly means and standard deviations transition between sections.

Type:

float
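The smooth transition between sections can be illustrated for a single boundary between two sections. This sketch assumes a blend weight `w = sigmoid(blend_k * (x - boundary))` that interpolates linearly between the two sections' means (and, the same way, their standard deviations); the dataset's actual parameterisation may differ, and `blended_param` is a hypothetical name.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def blended_param(x: float, left: float, right: float, boundary: float, blend_k: float) -> float:
    """Smoothly move from the left section's value to the right section's value."""
    w = sigmoid(blend_k * (x - boundary))  # ~0 well left of the boundary, ~1 well right
    return (1.0 - w) * left + w * right
```

Under this assumption, a larger blend_k makes the transition sharper, approaching a step change at the boundary.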

class congrads.datasets.registry.SyntheticClusters(cluster_centers, cluster_sizes, cluster_std, cluster_labels)

Bases: Dataset

PyTorch dataset for generating synthetic clustered 2D data with labels.

Each cluster is defined by its center, size, spread (standard deviation), and label. The dataset samples points from a Gaussian distribution centered at the cluster mean.

Parameters:
  • cluster_centers (list[tuple[float, float]]) – Coordinates of each cluster center, e.g. [(x1, y1), (x2, y2), …].

  • cluster_sizes (list[int]) – Number of points to generate in each cluster.

  • cluster_std (list[float]) – Standard deviation (spread) of each cluster.

  • cluster_labels (list[int]) – Class label for each cluster (e.g., 0 or 1).

Raises:

AssertionError – If the input lists do not all have the same length.

data

A concatenated tensor of all generated points with shape (N, 2).

Type:

torch.Tensor

labels

A concatenated tensor of class labels with shape (N,), where N is the total number of generated points.

Type:

torch.Tensor
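Per-cluster sampling can be sketched with the standard library's `random.gauss`, assuming isotropic noise (the same standard deviation on both axes) around each center. This is a plain-Python illustration with a hypothetical `sample_clusters` helper, not the class's tensor-based implementation.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sample_clusters(
    centers: Sequence[Tuple[float, float]],
    sizes: Sequence[int],
    stds: Sequence[float],
    labels: Sequence[int],
    seed: Optional[int] = None,
) -> Tuple[List[Tuple[float, float]], List[int]]:
    """Draw 2D Gaussian points per cluster and return (points, labels)."""
    # Same length check the class documents (it raises AssertionError).
    assert len(centers) == len(sizes) == len(stds) == len(labels)
    rng = random.Random(seed)
    points: List[Tuple[float, float]] = []
    point_labels: List[int] = []
    for (cx, cy), n, std, label in zip(centers, sizes, stds, labels):
        for _ in range(n):
            points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
            point_labels.append(label)
    return points, point_labels
```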

class congrads.datasets.registry.SyntheticMonotonicity(n_samples=200, x_range=(0.0, 5.0), noise_base=0.05, noise_scale=0.15, noise_sharpness=4.0, noise_center=2.5, osc_amplitude=0.08, osc_frequency=6.0, osc_prob=0.5, seed=None)

Bases: Dataset

Synthetic 1D dataset with monotone ground truth (log(1+x)), plus configurable structured noise.

True function:

y_true(x) = log(1 + x)

Observed:

y(x) = y_true(x) + heteroscedastic_noise(x) + local oscillatory perturbation

Parameters:
  • n_samples (int) – number of data points (default 200)

  • x_range (tuple) – range of x values (default (0.0, 5.0))

  • noise_base (float) – baseline noise level (default 0.05)

  • noise_scale (float) – scale of heteroscedastic noise (default 0.15)

  • noise_sharpness (float) – steepness of heteroscedastic transition (default 4.0)

  • noise_center (float) – center point of heteroscedastic increase (default 2.5)

  • osc_amplitude (float) – amplitude of oscillatory perturbation (default 0.08)

  • osc_frequency (float) – frequency of oscillation (default 6.0)

  • osc_prob (float) – probability each sample receives oscillation (default 0.5)

  • seed (int or None) – random seed
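One way to realise the observation model above is sketched below. The functional forms are assumptions, not the class's documented internals: x is drawn uniformly from x_range, the heteroscedastic noise is Gaussian with standard deviation `noise_base + noise_scale * sigmoid(noise_sharpness * (x - noise_center))`, and the oscillation `osc_amplitude * sin(osc_frequency * x)` is added to each sample with probability osc_prob.

```python
import math
import random
from typing import List, Optional, Tuple


def make_monotone_data(
    n_samples: int = 200,
    x_range: Tuple[float, float] = (0.0, 5.0),
    noise_base: float = 0.05,
    noise_scale: float = 0.15,
    noise_sharpness: float = 4.0,
    noise_center: float = 2.5,
    osc_amplitude: float = 0.08,
    osc_frequency: float = 6.0,
    osc_prob: float = 0.5,
    seed: Optional[int] = None,
) -> Tuple[List[float], List[float]]:
    rng = random.Random(seed)
    xs = [rng.uniform(*x_range) for _ in range(n_samples)]
    ys = []
    for x in xs:
        y_true = math.log1p(x)  # monotone ground truth log(1 + x)
        # Assumed noise level: rises smoothly around noise_center.
        gate = 1.0 / (1.0 + math.exp(-noise_sharpness * (x - noise_center)))
        y = y_true + rng.gauss(0.0, noise_base + noise_scale * gate)
        if rng.random() < osc_prob:  # assumed local oscillatory perturbation
            y += osc_amplitude * math.sin(osc_frequency * x)
        ys.append(y)
    return xs, ys
```

Setting every noise and oscillation parameter to zero recovers the noiseless ground truth, which is a handy sanity check.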

Descriptor

Descriptor utilities for mapping semantic tags to model layers such as inputs and outputs.

This module defines the Descriptor class, which manages a structured mapping between:

  • Layers (data dictionary keys representing model layers such as inputs, outputs, or intermediate features)

  • Tags (named references to full layers or specific feature indices)

Layers describe properties of a model output tensor, such as whether it is constant and whether it contributes to the loss. Tags provide semantic names for selecting either:

  • The full layer output (index=None)

  • A single feature column (index=int)

  • Multiple feature columns (index=Iterable[int])

This abstraction allows constraints and other logic to reference model outputs by descriptive tag names instead of hard-coded tensor indices.

class congrads.descriptor.Descriptor

Bases: object

Registry for model output layers and semantic tags.

A Descriptor maintains two mappings:

  • Layers: represent keys in the model’s data dictionary and store metadata such as whether the layer is constant and whether it affects the loss.

  • Tags: map semantic names to a specific layer and an optional index (or indices) within that layer.

Tags allow selecting either:
  • The full layer tensor (index=None)

  • A single feature column (index=int)

  • Multiple feature columns (index=tuple[int, …])

The descriptor does not validate tensor shapes at registration time. Selection logic assumes tensors follow the shape convention: [batch_size, features].

add_layer(key: str, constant: bool = False, affects_loss: bool = True, gradients_from: str | None = None)

Register a new layer.

A layer corresponds to a key in the model’s data dictionary and describes metadata about that output tensor.

Parameters:
  • key (str) – Name of the layer (must match a key in the data dictionary).

  • constant (bool, optional) – Whether the layer represents constant data. Defaults to False.

  • affects_loss (bool, optional) – Whether this layer contributes to the loss computation. Defaults to True.

  • gradients_from (str | None, optional) – If specified, the key of another layer whose gradients should be used instead. Defaults to None.

Raises:
  • TypeError – If arguments have incorrect types.

  • ValueError – If a layer with the same key is already registered.

add_tag(tag: str, layer: str, index: int | Iterable[int] | None = None)

Register a semantic tag for a layer or part of a layer.

A tag maps a descriptive name to a specific selection within a registered layer.

Index behavior:
  • None → select the full layer tensor

  • int → select a single feature column

  • Iterable[int] → select multiple feature columns

Parameters:
  • tag (str) – Unique name of the tag.

  • layer (str) – Name of a previously registered layer.

  • index (int | Iterable[int] | None, optional) – Feature index or indices within the layer tensor.

Raises:
  • TypeError – If arguments have incorrect types.

  • ValueError – If the layer is not registered, if the tag already exists, or if the index contains duplicates.

property affects_loss_layers: set[str]

Return the set of registered layer keys that affect the loss.

property constant_layers: set[str]

Return the set of registered layer keys marked as constant.

get_layer(key: str) Layer

Return the Layer object associated with a key.

Parameters:

key (str) – Registered layer key.

Returns:

The Layer object containing metadata for the specified key.

Return type:

Layer

Raises:

ValueError – If the layer key is not registered.

get_tag(tag: str) Tag

Return the Tag object associated with a tag name.

Parameters:

tag (str) – Registered tag name.

Returns:

The Tag object containing the layer and index for the specified tag.

Return type:

Tag

Raises:

ValueError – If the tag is not registered.

has_layer(key: str) bool

Return whether a layer is registered.

has_tag(tag: str) bool

Return whether a tag is registered.

location(tag: str) tuple[str, int | tuple[int, ...] | None]

Return the layer name and index associated with a tag.

Parameters:

tag (str) – Registered tag name.

Returns:

  • Layer name

  • Index specification (None, int, or tuple[int, …])

Return type:

tuple[str, int | tuple[int, …] | None]

Raises:

ValueError – If the tag is not registered.

select(tag: str, data: dict[str, Tensor]) Tensor

Select data from a layer using a registered tag.

The tensor is retrieved from the provided data dictionary using the tag’s associated layer key.

Selection behavior:
  • index=None → return full tensor

  • index=int → return a single feature column with shape [batch_size, 1]

  • index=tuple → return selected feature columns

Parameters:
  • tag (str) – Registered tag name.

  • data (dict[str, Tensor]) – Dictionary containing model outputs.

Returns:

Selected tensor slice.

Return type:

Tensor

Raises:

ValueError – If the tag is not registered, or if indexed selection is requested but the tensor does not have at least two dimensions.

property variable_layers: set[str]

Return the set of registered layer keys marked as variable (that is, not constant).
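The tag-to-selection mapping can be mimicked with plain dictionaries. The hypothetical `MiniDescriptor` below stores a `(layer, index)` pair per tag and slices row-major nested lists shaped [batch_size, features]. It follows the documented index semantics (None, int, or multiple indices, with a [batch_size, 1] result for a single index) and the documented ValueError cases, but it is not the library class and works on lists instead of tensors.

```python
from typing import Dict, Iterable, List, Tuple, Union

Index = Union[None, int, Tuple[int, ...]]


class MiniDescriptor:
    def __init__(self) -> None:
        self.layers: Dict[str, bool] = {}  # layer key -> constant flag
        self.tags: Dict[str, Tuple[str, Index]] = {}

    def add_layer(self, key: str, constant: bool = False) -> None:
        if key in self.layers:
            raise ValueError(f"Layer '{key}' already registered.")
        self.layers[key] = constant

    def add_tag(self, tag: str, layer: str, index: Union[None, int, Iterable[int]] = None) -> None:
        if layer not in self.layers:
            raise ValueError(f"Layer '{layer}' is not registered.")
        if tag in self.tags:
            raise ValueError(f"Tag '{tag}' already exists.")
        if index is not None and not isinstance(index, int):
            index = tuple(index)  # normalise iterables to a tuple
            if len(set(index)) != len(index):
                raise ValueError("Index contains duplicates.")
        self.tags[tag] = (layer, index)

    def select(self, tag: str, data: Dict[str, List[List[float]]]) -> List[List[float]]:
        layer, index = self.tags[tag]
        rows = data[layer]
        if index is None:
            return rows  # full layer
        if isinstance(index, int):
            return [[row[index]] for row in rows]  # shape [batch_size, 1]
        return [[row[i] for i in index] for row in rows]  # selected columns
```

Constraints can then refer to names such as "income" rather than to hard-coded column positions.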

Metrics

Module for managing metrics during training.

Provides the Metric and MetricManager classes for accumulating, aggregating, and resetting metrics over training batches. Supports grouping metrics and using custom accumulation functions.

class congrads.metrics.Metric(name: str, accumulator: ~collections.abc.Callable[[...], ~torch.Tensor] = <built-in method nanmean of type object>)

Bases: object

Represents a single metric to be accumulated and aggregated.

Stores metric values over multiple batches and computes an aggregated result using a specified accumulation function.

accumulate(value: Tensor) None

Accumulate a new value for the metric.

Parameters:

value (Tensor) – Metric values for the current batch.

aggregate() Tensor

Compute the aggregated value of the metric.

Returns:

The aggregated metric value. Returns NaN if no values

have been accumulated.

Return type:

Tensor

reset() None

Reset the accumulated values and sample count for the metric.

class congrads.metrics.MetricManager

Bases: object

Manages multiple metrics and groups for training or evaluation.

Supports registering metrics, accumulating values by name, aggregating metrics by group, and resetting metrics by group.

accumulate(name: str, value: Tensor) None

Accumulate a value for a specific metric by name.

Parameters:
  • name (str) – Name of the metric.

  • value (Tensor) – Metric values for the current batch.

aggregate(group: str = 'default') dict[str, Tensor]

Aggregate all metrics in a specified group.

Parameters:

group (str, optional) – The group of metrics to aggregate. Defaults to “default”.

Returns:

Dictionary mapping metric names to their

aggregated values.

Return type:

dict[str, Tensor]

register(name: str, group: str = 'default', accumulator: ~collections.abc.Callable[[...], ~torch.Tensor] = <built-in method nanmean of type object>) None

Register a new metric under a specified group.

Parameters:
  • name (str) – Name of the metric.

  • group (str, optional) – Group name for the metric. Defaults to “default”.

  • accumulator (Callable[..., Tensor], optional) – Function to aggregate accumulated values. Defaults to torch.nanmean.

reset(group: str = 'default') None

Reset all metrics in a specified group.

Parameters:

group (str, optional) – The group of metrics to reset. Defaults to “default”.

reset_all() None

Reset all metrics across all groups.
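The register/accumulate/aggregate/reset cycle can be sketched in plain Python. This hypothetical version stores per-batch values as flat lists and uses a NaN-ignoring mean as the default accumulator, standing in for `torch.nanmean`; like Metric.aggregate(), it yields NaN for a metric with no values.

```python
import math
from typing import Callable, Dict, List


def nanmean(values: List[float]) -> float:
    """Mean that ignores NaN entries; NaN if nothing remains."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else math.nan


class MiniMetricManager:
    def __init__(self) -> None:
        self.values: Dict[str, List[float]] = {}
        self.accumulators: Dict[str, Callable[[List[float]], float]] = {}
        self.groups: Dict[str, List[str]] = {}

    def register(self, name: str, group: str = "default",
                 accumulator: Callable[[List[float]], float] = nanmean) -> None:
        self.values[name] = []
        self.accumulators[name] = accumulator
        self.groups.setdefault(group, []).append(name)

    def accumulate(self, name: str, batch_values: List[float]) -> None:
        self.values[name].extend(batch_values)  # keep every per-batch value

    def aggregate(self, group: str = "default") -> Dict[str, float]:
        return {n: self.accumulators[n](self.values[n]) for n in self.groups.get(group, [])}

    def reset(self, group: str = "default") -> None:
        for n in self.groups.get(group, []):
            self.values[n] = []
```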

Networks

Module defining the network architectures and components.

class congrads.networks.registry.MLPNetwork(n_inputs, n_outputs, n_hidden_layers=3, hidden_dim=35, activation=None)

Bases: Module

A multi-layer perceptron (MLP) neural network with configurable hidden layers.

forward(data: dict[str, Tensor])

Run a forward pass through the network.

Parameters:

data (dict[str, Tensor]) – Input data to be processed by the network.

Returns:

The input data dictionary, augmented with the network’s output under the key “output”.

Return type:

dict

Transformations

Module defining transformations and components.

class congrads.transformations.base.Transformation(tag: str)

Bases: ABC

Abstract base class for tag data transformations.

Module holding specific transformation implementations.

class congrads.transformations.registry.ApplyOperator(tag: str, operator: callable, value: Number)

Bases: Transformation

A transformation that applies a binary operator to the input tensor.

class congrads.transformations.registry.DenormalizeMinMax(tag: str, min: Number, max: Number)

Bases: Transformation

A transformation that denormalizes data using min-max scaling.

class congrads.transformations.registry.IdentityTransformation(tag: str)

Bases: Transformation

A transformation that returns the input unchanged.
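The three transformations reduce to simple element-wise functions. The plain-Python stand-ins below are hypothetical and operate on single floats rather than tagged tensors. They assume that ApplyOperator computes `operator(x, value)` with the input as the left operand, and that DenormalizeMinMax inverts standard min-max scaling via `x * (max - min) + min`.

```python
import operator
from typing import Callable


class Identity:
    def __call__(self, x: float) -> float:
        return x  # input returned unchanged


class ApplyOp:
    def __init__(self, op: Callable[[float, float], float], value: float) -> None:
        self.op, self.value = op, value

    def __call__(self, x: float) -> float:
        return self.op(x, self.value)  # assumed operand order: op(x, value)


class DenormMinMax:
    def __init__(self, min_: float, max_: float) -> None:
        self.min_, self.max_ = min_, max_

    def __call__(self, x: float) -> float:
        # Inverse of x_norm = (x - min) / (max - min)
        return x * (self.max_ - self.min_) + self.min_
```

For example, `ApplyOp(operator.mul, 2)` doubles its input, and `DenormMinMax(10.0, 20.0)` maps 0.5 back to 15.0.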

Utils

Preprocessing functions for various datasets.

This module provides preprocessing pipelines for multiple datasets:
  • BiasCorrection: Temperature bias correction dataset

  • FamilyIncome: Family income and expenses dataset

  • AdultCensusIncome: Adult Census Income dataset

Each preprocessing function applies appropriate transformations including normalization, feature engineering, constraint filtering, and sampling.

congrads.utils.preprocessors.preprocess_AdultCensusIncome(df: DataFrame) DataFrame

Preprocesses the Adult Census Income dataset for machine learning with PyTorch.

Sequential steps:
  • Drop rows with missing values.

  • Encode categorical variables to integer labels.

  • Map the target ‘income’ column to 0/1.

  • Convert all data to float32.

  • Add a multi-index to denote input vs. output columns.

Parameters:

df (pd.DataFrame) – Raw dataframe containing Adult Census Income data.

Returns:

Preprocessed dataframe.

Return type:

pd.DataFrame

congrads.utils.preprocessors.preprocess_BiasCorrection(df: DataFrame) DataFrame

Preprocesses the given dataframe for bias correction by performing a series of transformations.

The function sequentially:

  • Drops rows with missing values.

  • Converts a date string to datetime format and adds year, month, and day columns.

  • Normalizes the columns with specific logic for input and output variables.

  • Adds a multi-index indicating which columns are input or output variables.

  • Samples 2500 examples from the dataset without replacement.

Parameters:

df (pd.DataFrame) – The input dataframe containing the data to be processed.

Returns:

The processed dataframe after applying the transformations.

Return type:

pd.DataFrame

congrads.utils.preprocessors.preprocess_FamilyIncome(df: DataFrame) DataFrame

Preprocesses the given Family Income dataframe.

The function sequentially:

  • Drops rows with missing values.

  • Converts object columns to appropriate data types and removes string columns.

  • Removes certain unnecessary columns like ‘Agricultural Household indicator’ and related features.

  • Adds labels to columns indicating whether they are input or output variables.

  • Normalizes the columns individually.

  • Checks and removes rows that do not satisfy predefined constraints (household income > expenses, food expenses > sub-expenses).

  • Samples 2500 examples from the dataset without replacement.

Parameters:

df (pd.DataFrame) – The input Family Income dataframe containing the data to be processed.

Returns:

The processed dataframe after applying the transformations and constraints.

Return type:

pd.DataFrame

This module holds utility functions and classes for the congrads package.

class congrads.utils.utility.CSVLogger(file_path: str, overwrite: bool = False, merge: bool = True)

Bases: object

A utility class for logging key-value pairs to a CSV file, organized by epochs.

Supports merging with existing logs or overwriting them.

Parameters:
  • file_path (str) – The path to the CSV file for logging.

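  • overwrite (bool) – If True, overwrites any existing file at file_path. Defaults to False. Because merge defaults to True, setting overwrite=True also requires merge=False.

  • merge (bool) – If True, merges new values with existing data in the file. Defaults to True.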

Raises:
  • ValueError – If both overwrite and merge are True.

  • FileExistsError – If the file already exists and neither overwrite nor merge is True.

add_value(name: str, value: float, epoch: int)

Adds a value to the logger for a specific epoch and name.

Parameters:
  • name (str) – The name of the metric or value to log.

  • value (float) – The value to log.

  • epoch (int) – The epoch associated with the value.

load()

Loads data from the CSV file into the logger.

Converts the CSV data into the internal dictionary format for further updates or operations.

save()

Saves the logged values to the specified CSV file.

If the file exists and merge is enabled, merges the current data with the existing file.

static to_dataframe(values: dict[tuple[int, str], float]) DataFrame

Converts a dictionary of values into a DataFrame.

Parameters:

values (dict[tuple[int, str], float]) – A dictionary of values keyed by (epoch, name).

Returns:

A DataFrame where epochs are rows, names are columns, and values are the cell data.

Return type:

pd.DataFrame

static to_dict(df: DataFrame) dict[tuple[int, str], float]

Converts a CSVLogger DataFrame to a dictionary in the format {(epoch, name): value}.
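The {(epoch, name): value} layout maps onto a wide table with one row per epoch and one column per name. Below is a stdlib-only sketch of that round trip, using the `csv` module instead of pandas. The function names and the `epoch` column header are assumptions, and missing (epoch, name) pairs are written as empty cells.

```python
import csv
import io
from typing import Dict, Tuple

Values = Dict[Tuple[int, str], float]


def to_csv_text(values: Values) -> str:
    """Pivot (epoch, name) -> value into rows per epoch, columns per name."""
    names = sorted({name for _, name in values})
    epochs = sorted({epoch for epoch, _ in values})
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["epoch", *names])
    for epoch in epochs:
        writer.writerow([epoch, *(values.get((epoch, n), "") for n in names)])
    return buf.getvalue()


def from_csv_text(text: str) -> Values:
    """Inverse pivot; empty cells (missing values) are skipped."""
    out: Values = {}
    for row in csv.DictReader(io.StringIO(text)):
        epoch = int(row.pop("epoch"))
        for name, cell in row.items():
            if cell != "":
                out[(epoch, name)] = float(cell)
    return out
```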

class congrads.utils.utility.DictDatasetWrapper(base_dataset: Dataset, field_names: list[str] | None = None)

Bases: Dataset

A wrapper for PyTorch datasets that converts each sample into a dictionary.

This class takes any PyTorch dataset and returns its samples as dictionaries, where each element of the original sample is mapped to a key. This is useful for integration with the Congrads toolbox or other frameworks that expect dictionary-formatted data.

base_dataset

The underlying PyTorch dataset being wrapped.

Type:

Dataset

field_names

Names assigned to each field of a sample. If None, default names like ‘field0’, ‘field1’, … are generated.

Type:

list[str] | None

Parameters:
  • base_dataset (Dataset) – The PyTorch dataset to wrap.

  • field_names (list[str] | None, optional) – Custom names for each field. If provided, the list is truncated or extended to match the number of elements in a sample. Defaults to None.

Example

Wrapping a TensorDataset with custom field names:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from congrads.utils.utility import DictDatasetWrapper
>>> dataset = TensorDataset(torch.randn(5, 3), torch.randint(0, 2, (5,)))
>>> wrapped = DictDatasetWrapper(dataset, field_names=["features", "label"])
>>> wrapped[0]
{'features': tensor([...]), 'label': tensor(...)}

Wrapping a built-in dataset like CIFAR10:

>>> from torchvision.datasets import CIFAR10
>>> from torchvision import transforms
>>> cifar = CIFAR10(
...     root="./data", train=True, download=True, transform=transforms.ToTensor()
... )
>>> wrapped_cifar = DictDatasetWrapper(cifar, field_names=["input", "output"])
>>> wrapped_cifar[0]
{'input': tensor([...]), 'output': tensor(6)}
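The field-naming rule can be sketched independently of PyTorch. These hypothetical helpers assume that when field_names is too short, the missing names are filled with the same ‘field{i}’ defaults used when field_names is None, and that surplus names are dropped.

```python
from typing import Any, Dict, List, Optional, Sequence


def resolve_field_names(n_fields: int, field_names: Optional[List[str]] = None) -> List[str]:
    """Truncate or extend field_names so there is exactly one name per field."""
    names = list(field_names or [])[:n_fields]
    # Assumed fill rule: default names for any positions without a custom name.
    names += [f"field{i}" for i in range(len(names), n_fields)]
    return names


def as_dict(sample: Sequence[Any], field_names: Optional[List[str]] = None) -> Dict[str, Any]:
    """Map each element of a tuple-like sample to its field name."""
    return dict(zip(resolve_field_names(len(sample), field_names), sample))
```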
class congrads.utils.utility.LossWrapper(loss_fn: Callable)

Bases: object

Wraps a loss function to optionally accept batch-level data.

This adapter allows both standard PyTorch loss functions (e.g. nn.MSELoss) and custom loss functions that accept an additional data keyword argument to be used interchangeably.

The wrapped loss can always be called with the same signature:

loss(output, target, data=batch)

If the underlying loss function does not accept data, the argument is silently ignored.
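A minimal sketch of how such an adapter could decide whether to forward `data`, using `inspect.signature`. This is an assumed mechanism for illustration (the class name is hypothetical), not necessarily how LossWrapper works. It inspects `forward` when the loss is a module-like object, because a module's own `__call__` typically accepts `**kwargs` and would otherwise look as if it accepted `data`.

```python
import inspect
from typing import Any, Callable


class MiniLossWrapper:
    def __init__(self, loss_fn: Callable[..., Any]) -> None:
        self.loss_fn = loss_fn
        # Inspect forward() for module-like objects, else the callable itself.
        target = getattr(loss_fn, "forward", loss_fn)
        params = inspect.signature(target).parameters.values()
        # Forward `data` if the function names it or accepts **kwargs.
        self.accepts_data = any(
            p.name == "data" or p.kind is inspect.Parameter.VAR_KEYWORD for p in params
        )

    def __call__(self, output: Any, target: Any, data: Any = None) -> Any:
        if self.accepts_data:
            return self.loss_fn(output, target, data=data)
        return self.loss_fn(output, target)  # `data` silently ignored
```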

class congrads.utils.utility.Seeder(base_seed: int)

Bases: object

A deterministic seed manager for reproducible experiments.

This class provides a way to consistently generate pseudo-random seeds derived from a fixed base seed. It ensures that different libraries (Python’s random, NumPy, and PyTorch) are initialized with reproducible seeds, making experiments deterministic across runs.

roll_seed() int

Generate a new deterministic pseudo-random seed.

Each call returns an integer seed derived from the internal pseudo-random generator, which itself is initialized by the base seed.

Returns:

A pseudo-random integer seed in the range [0, 2**31 - 1].

Return type:

int

set_reproducible() None

Configure global random states for reproducibility.

Seeds the following libraries with deterministically generated seeds based on the base seed:
  • Python’s built-in random

  • NumPy’s random number generator

  • PyTorch (CPU and GPU)

Also enforces deterministic behavior in PyTorch by:
  • Seeding all CUDA devices

  • Disabling CuDNN benchmarking

  • Enabling CuDNN deterministic mode
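The seed-rolling idea can be sketched with the standard library alone: a private `random.Random` seeded with the base seed yields the same sequence of derived seeds on every run. The class name is hypothetical, and the NumPy, PyTorch, and CuDNN settings applied by set_reproducible() are only noted in a comment because they need those libraries.

```python
import random


class MiniSeeder:
    def __init__(self, base_seed: int) -> None:
        self._rng = random.Random(base_seed)  # private generator, global state untouched

    def roll_seed(self) -> int:
        """Next derived seed in [0, 2**31 - 1] (randint is inclusive)."""
        return self._rng.randint(0, 2**31 - 1)

    def set_reproducible(self) -> None:
        random.seed(self.roll_seed())
        # The documented method additionally seeds NumPy and PyTorch (CPU and
        # all CUDA devices), disables CuDNN benchmarking, and enables CuDNN
        # deterministic mode.
```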

class congrads.utils.utility.ZeroLoss(reduction: str = 'mean')

Bases: _Loss

A loss function that always returns zero.

This custom loss function ignores the input and target tensors and returns a constant zero loss, which can be useful for debugging or when no meaningful loss computation is required.

Parameters:

reduction (str, optional) – Specifies the reduction to apply to the output. Defaults to “mean”. Although specified, it has no effect as the loss is always zero.

forward(predictions: Tensor, target: Tensor, **kwargs) Tensor

Return a dummy loss of zero regardless of input and target.

congrads.utils.utility.process_data_monotonicity_constraint(data: Tensor, ordering: Tensor, identifiers: Tensor)

Reorders input samples to support monotonicity checking.

Samples are reordered such that:
  1. Samples from the same run are grouped together.

  2. Within each run, samples are sorted chronologically.

Parameters:
  • data (Tensor) – The input data.

  • ordering (Tensor) – Values by which samples are sorted within each run (e.g., time).

  • identifiers (Tensor) – Identifiers specifying different runs.

Returns:

Sorted data, ordering, and identifiers.

Return type:

Tuple[Tensor, Tensor, Tensor]
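The two-level ordering (group by run, then sort within each run) amounts to a lexicographic sort on (identifier, ordering). The list-based sketch below uses a hypothetical function name and assumes identifiers are sortable and runs appear in ascending identifier order; the real function operates on tensors.

```python
from typing import List, Sequence, Tuple


def reorder_for_monotonicity(
    data: Sequence[float], ordering: Sequence[float], identifiers: Sequence[int]
) -> Tuple[List[float], List[float], List[int]]:
    # Sort indices by run identifier first, then by the ordering value within a run.
    idx = sorted(range(len(data)), key=lambda i: (identifiers[i], ordering[i]))
    return (
        [data[i] for i in idx],
        [ordering[i] for i in idx],
        [identifiers[i] for i in idx],
    )
```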

congrads.utils.utility.split_data_loaders(data: Dataset, loader_args: dict = None, train_loader_args: dict = None, valid_loader_args: dict = None, test_loader_args: dict = None, train_size: float = 0.8, valid_size: float = 0.1, test_size: float = 0.1, split_generator: Generator = None) tuple[DataLoader, DataLoader, DataLoader]

Splits a dataset into training, validation, and test sets, and returns corresponding DataLoader objects.

Parameters:
  • data (Dataset) – The dataset to be split.

  • loader_args (dict, optional) – Default DataLoader arguments shared by all three loaders. They are merged with the loader-specific arguments, which take precedence on overlapping keys.

  • train_loader_args (dict, optional) – Training DataLoader arguments, merges with loader_args, overriding overlapping keys.

  • valid_loader_args (dict, optional) – Validation DataLoader arguments, merges with loader_args, overriding overlapping keys.

  • test_loader_args (dict, optional) – Test DataLoader arguments, merges with loader_args, overriding overlapping keys.

  • train_size (float, optional) – Proportion of data to be used for training. Defaults to 0.8.

  • valid_size (float, optional) – Proportion of data to be used for validation. Defaults to 0.1.

  • test_size (float, optional) – Proportion of data to be used for testing. Defaults to 0.1.

  • split_generator (Generator, optional) – Random number generator used to control how the dataset is split, e.g., to obtain a reproducible split.

Returns:

A tuple containing three DataLoader objects, one each for the training, validation, and test sets.

Return type:

tuple[DataLoader, DataLoader, DataLoader]

Raises:

ValueError – If any of train_size, valid_size, or test_size is not between 0 and 1, or if their sum does not equal 1.
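The sketch below illustrates the documented behavior: proportion checks, a random split, and the merge of shared and loader-specific DataLoader arguments. It is a hypothetical split_data_loaders_sketch, not Congrads' implementation, and it assumes that proportions are converted to integer lengths (with rounding leftovers assigned to the test set) and that torch.utils.data.random_split performs the split.

```python
import math
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, random_split


def split_data_loaders_sketch(
    data: Dataset,
    loader_args: Optional[dict] = None,
    train_loader_args: Optional[dict] = None,
    valid_loader_args: Optional[dict] = None,
    test_loader_args: Optional[dict] = None,
    train_size: float = 0.8,
    valid_size: float = 0.1,
    test_size: float = 0.1,
    split_generator: Optional[torch.Generator] = None,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    sizes = (train_size, valid_size, test_size)
    # Each proportion must lie in [0, 1] and together they must sum to 1
    # (compared with a tolerance to absorb floating-point rounding).
    if any(not 0.0 <= s <= 1.0 for s in sizes) or not math.isclose(sum(sizes), 1.0):
        raise ValueError("train_size, valid_size and test_size must lie in [0, 1] and sum to 1.")

    # Convert proportions to integer lengths; rounding leftovers go to the test set.
    n = len(data)
    n_train = int(train_size * n)
    n_valid = int(valid_size * n)
    lengths = [n_train, n_valid, n - n_train - n_valid]

    generator = split_generator if split_generator is not None else torch.default_generator
    train_set, valid_set, test_set = random_split(data, lengths, generator=generator)

    # Loader-specific arguments override the shared defaults on overlapping keys.
    base = loader_args or {}
    return (
        DataLoader(train_set, **{**base, **(train_loader_args or {})}),
        DataLoader(valid_set, **{**base, **(valid_loader_args or {})}),
        DataLoader(test_set, **{**base, **(test_loader_args or {})}),
    )
```

With this merge rule, passing loader_args={"batch_size": 32} together with train_loader_args={"shuffle": True} gives a shuffled training loader while all three loaders share a batch size of 32.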

Validation utilities for type checking and argument validation.

This module provides utility functions for validating function arguments, including type validation, callable validation, and PyTorch-specific validation functions.

congrads.utils.validation.validate_callable(name, value, allow_none=False)

Validate that a value is callable.

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • allow_none (bool) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not callable.
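A minimal sketch of the contract described above, using Python's built-in callable(). The name validate_callable_sketch and the error message wording are hypothetical; Congrads' actual messages may differ.

```python
def validate_callable_sketch(name: str, value, allow_none: bool = False) -> None:
    """Hypothetical re-implementation of the documented contract."""
    # None passes only when the caller explicitly allows it.
    if value is None and allow_none:
        return
    if not callable(value):
        raise TypeError(
            f"Argument '{name}' must be callable, got {type(value).__name__}."
        )
```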

congrads.utils.validation.validate_callable_iterable(name, value, allowed_iterables=(list, set, tuple), allow_none=False)

Validate that a value is an iterable containing only callable elements.

This function ensures that the given value is an iterable (e.g., a list, set, or tuple) and that all of its elements are callable.

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – The value to validate.

  • allowed_iterables (tuple of types, optional) – Iterable types that are allowed. Defaults to (list, set, tuple).

  • allow_none (bool, optional) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not an allowed iterable type or if any element is not callable.
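The two checks above (container type, then each element) can be sketched as follows. The function name validate_callable_iterable_sketch is hypothetical and the messages are illustrative only.

```python
def validate_callable_iterable_sketch(
    name: str, value, allowed_iterables=(list, set, tuple), allow_none: bool = False
) -> None:
    """Hypothetical sketch: check the container type, then every element."""
    if value is None and allow_none:
        return
    # The container itself must be one of the allowed iterable types.
    if not isinstance(value, allowed_iterables):
        allowed = ", ".join(t.__name__ for t in allowed_iterables)
        raise TypeError(
            f"Argument '{name}' must be one of ({allowed}), got {type(value).__name__}."
        )
    # Every element must be callable.
    for index, element in enumerate(value):
        if not callable(element):
            raise TypeError(f"Element {index} of '{name}' is not callable: {element!r}.")
```

Note that a dict or str is rejected by the container check even though both are iterable, because neither appears in the default allowed_iterables.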

congrads.utils.validation.validate_comparator(name, value, comparator_map: dict)

Validate that a value is a callable PyTorch comparator function.

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • comparator_map (dict) – Map of the supported comparator functions.

Raises:

TypeError – If the value is not a valid comparator.
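This section does not document the contents of comparator_map, so the sketch below rests on an assumption: the map associates a symbolic name with a PyTorch comparison function (torch.gt, torch.ge, torch.lt, torch.le), and a value is valid when it is one of the map's functions. Both COMPARATOR_MAP and validate_comparator_sketch are hypothetical.

```python
import torch

# Hypothetical comparator map; the map Congrads actually passes in is not
# documented in this section.
COMPARATOR_MAP = {">": torch.gt, ">=": torch.ge, "<": torch.lt, "<=": torch.le}


def validate_comparator_sketch(name: str, value, comparator_map: dict) -> None:
    # Assumption: a value is valid when it is one of the map's functions.
    if not callable(value) or value not in comparator_map.values():
        allowed = ", ".join(str(key) for key in comparator_map)
        raise TypeError(
            f"Argument '{name}' must be a supported comparator ({allowed})."
        )
```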

congrads.utils.validation.validate_iterable(name, value, expected_element_types, allowed_iterables=(list, set, tuple), allow_empty=False, allow_none=False)

Validate that a value is an iterable (e.g., list, set) with elements of the specified type(s).

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • expected_element_types (type or tuple of types) – Expected type(s) for the elements.

  • allowed_iterables (tuple of types) – Iterable types that are allowed. Defaults to (list, set, tuple).

  • allow_empty (bool) – Whether to allow empty iterables. Defaults to False.

  • allow_none (bool) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not an allowed iterable type or if any element is not of the expected type(s).
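A sketch of the documented checks, in order: optional None, container type, emptiness, then element types. The name validate_iterable_sketch is hypothetical. The Raises section does not say which exception an empty iterable triggers when allow_empty is False, so the ValueError used here is an assumption.

```python
def validate_iterable_sketch(
    name: str,
    value,
    expected_element_types,
    allowed_iterables=(list, set, tuple),
    allow_empty: bool = False,
    allow_none: bool = False,
) -> None:
    """Hypothetical sketch of the documented element-type check."""
    if value is None and allow_none:
        return
    if not isinstance(value, allowed_iterables):
        raise TypeError(f"Argument '{name}' has unsupported type {type(value).__name__}.")
    # Assumption: the exception type for a disallowed empty iterable is not
    # documented; this sketch uses ValueError.
    if len(value) == 0 and not allow_empty:
        raise ValueError(f"Argument '{name}' must not be empty.")
    for element in value:
        if not isinstance(element, expected_element_types):
            raise TypeError(
                f"Elements of '{name}' must be {expected_element_types}, "
                f"got {type(element).__name__}."
            )
```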

congrads.utils.validation.validate_loaders(name: str, loaders: tuple[DataLoader, DataLoader, DataLoader])

Validates that loaders is a tuple of three DataLoader instances.

Parameters:
  • name (str) – The name of the parameter being validated.

  • loaders (tuple[DataLoader, DataLoader, DataLoader]) – A tuple of three DataLoader instances.

Raises:

TypeError – If loaders is not a tuple of three DataLoader instances or contains invalid types.
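The sketch below checks the shape of the (train, valid, test) tuple, such as the one split_data_loaders returns. The name validate_loaders_sketch is hypothetical, and treating a list of three loaders as invalid follows the "tuple" wording above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def validate_loaders_sketch(name: str, loaders) -> None:
    """Hypothetical sketch: require a tuple of exactly three DataLoaders."""
    # The container must be a tuple with exactly three entries.
    if not isinstance(loaders, tuple) or len(loaders) != 3:
        raise TypeError(f"Argument '{name}' must be a tuple of three DataLoader instances.")
    # Each entry (train, valid, test) must be a DataLoader.
    for role, loader in zip(("train", "valid", "test"), loaders):
        if not isinstance(loader, DataLoader):
            raise TypeError(
                f"The {role} entry of '{name}' must be a DataLoader, "
                f"got {type(loader).__name__}."
            )


# Example: a (train, valid, test) tuple built from a toy dataset.
dataset = TensorDataset(torch.arange(6, dtype=torch.float32))
loaders = (DataLoader(dataset), DataLoader(dataset), DataLoader(dataset))
validate_loaders_sketch("loaders", loaders)  # passes silently
```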

congrads.utils.validation.validate_type(name, value, expected_types, allow_none=False)

Validate that a value is of the specified type(s).

Parameters:
  • name (str) – Name of the argument for error messages.

  • value – Value to validate.

  • expected_types (type or tuple of types) – Expected type(s) for the value.

  • allow_none (bool) – Whether to allow the value to be None. Defaults to False.

Raises:

TypeError – If the value is not of the expected type(s).
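A minimal isinstance-based sketch of this check. The name validate_type_sketch is hypothetical, and whether Congrads formats its error message this way is not stated here.

```python
def validate_type_sketch(name: str, value, expected_types, allow_none: bool = False) -> None:
    """Hypothetical sketch of the documented type check."""
    if value is None and allow_none:
        return
    if not isinstance(value, expected_types):
        # Normalize to a tuple so the error message lists every accepted type.
        types = expected_types if isinstance(expected_types, tuple) else (expected_types,)
        expected = " or ".join(t.__name__ for t in types)
        raise TypeError(
            f"Argument '{name}' must be of type {expected}, got {type(value).__name__}."
        )
```

Because the check relies on isinstance, a bool passes an int check (bool is a subclass of int). Whether Congrads' own validate_type behaves the same is not stated in this section.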