pytorch_runstats

torch_runstats implements memory-efficient online reductions on tensors. Notable features:

  • Arbitrary sample shapes beyond single scalars

  • Reduction over arbitrary dimensions of each sample

  • “Batched”/“binned” reduction into multiple running tallies using a per-sample bin index. This can be useful, for example, in accumulating statistics over samples by some kind of “type” index or for accumulating statistics per-graph in a pytorch_geometric-like batching scheme. (This feature is similar to torch_scatter.)

  • Option to ignore NaN values with correct sample counting
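To make "online" concrete, here is a minimal pure-Python sketch of a running mean that keeps only a count and the current mean, never the samples themselves. It is an illustration of the general idea, not the library's implementation; the class name `OnlineMean` and its attributes are hypothetical.

```python
class OnlineMean:
    """Hypothetical illustration: running mean with constant-size state."""

    def __init__(self):
        self.n = 0       # number of samples seen so far
        self.mean = 0.0  # current running mean (0.0 while empty)

    def accumulate_batch(self, batch):
        """Fold a batch of floats into the running mean; return the batch mean."""
        k = len(batch)
        if k == 0:
            return 0.0
        batch_mean = sum(batch) / k
        # Move the running mean toward the batch mean, weighted by batch size
        self.mean += (batch_mean - self.mean) * k / (self.n + k)
        self.n += k
        return batch_mean


data = [1.0] * 5 + [0.0] * 3 + [1.0] * 5 + [0.0] * 2
om = OnlineMean()
for start, stop in [(0, 5), (5, 7), (7, 13), (13, 15)]:
    om.accumulate_batch(data[start:stop])
print(round(om.mean, 4))  # => 0.6667
```

The state stays the same size no matter how many batches are processed, which is the property the feature list refers to as memory efficiency.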

Note

The implementation currently makes heavy use of in-place operations for performance and memory efficiency. This probably doesn’t play nicely with the autograd engine, so this is currently likely the wrong library for accumulating running statistics you want to backpropagate through. (See TorchMetrics for a possible alternative.)

Examples

Basic

import torch
from torch_runstats import Reduction, RunningStats

# Interspersed ones and zeros with a ratio of 2:1 ones to zeros
data = torch.cat([torch.ones(5), torch.zeros(3), torch.ones(5), torch.zeros(2)])
data.unsqueeze_(-1)

rs = RunningStats(
   dim=(1,),
   reduction=Reduction.MEAN,
)
# Accumulate the statistics over the data in batches
# Note that each call to accumulate_batch also returns the statistic for the current batch:
print(rs.accumulate_batch(data[:5]))  # => tensor([[1.]])
rs.accumulate_batch(data[5:7])
rs.accumulate_batch(data[7:13])
rs.accumulate_batch(data[13:])
print(rs.current_result())  # => tensor([[0.6667]])

# Accumulated data can be cleared
rs.reset()
# An empty object returns the identity for the reduction:
print(rs.current_result())  # => tensor([[0.]])

“Binned”

A main feature of torch_runstats is accumulating different samples in a batch into different “bins” — different running statistics — based on a provided index:

import torch
from torch_runstats import Reduction, RunningStats

data = torch.cat([torch.ones(5), torch.zeros(3), torch.ones(5), torch.zeros(2)])
data.unsqueeze_(-1)
sample_type = torch.cat([torch.zeros(8, dtype=torch.long), torch.ones(7, dtype=torch.long)])

rs = RunningStats(
   dim=(1,),
   reduction=Reduction.MEAN,
)
rs.accumulate_batch(data, accumulate_by=sample_type)
# The first entry is for "bin" (sample_type) 0, the second for 1:
print(rs.current_result())  # => tensor([[0.6250], [0.7143]])
# These values are what we expect:
print(5/8, 5/7)  # => 0.625 0.7142857142857143
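The grouping behind the binned result can be pictured with a small pure-Python sketch. This is only an illustration of the idea of per-bin tallies, not how torch_runstats is implemented; the function name `binned_mean` is hypothetical, and reporting 0.0 for an empty bin is an assumption of the sketch.

```python
def binned_mean(values, bin_index):
    """Hypothetical helper: mean of `values` grouped by non-negative integer bins."""
    n_bins = max(bin_index) + 1
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for value, b in zip(values, bin_index):
        sums[b] += value
        counts[b] += 1
    # Assumption of this sketch: a bin with no samples reports 0.0
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


data = [1.0] * 5 + [0.0] * 3 + [1.0] * 5 + [0.0] * 2
sample_type = [0] * 8 + [1] * 7
result = binned_mean(data, sample_type)
print(result)  # => [0.625, 0.7142857142857143]
```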

Reduce over arbitrary dimensions

A reduction can also be taken over a sample dimension:

import torch
from torch_runstats import Reduction, RunningStats

data = torch.cat([torch.ones(5, 3, 2), torch.zeros(3, 3, 2)], dim=0)

rs = RunningStats(
   dim=(3, 2),
   reduction=Reduction.MEAN,
   reduce_dims=0,  # reduce the sample dimension of size 3
)
rs.accumulate_batch(data)
# Note that the reduction has a bin index (len 1),
# and the sample dimension of shape 2,
# but that the dimension of size 3 has been reduced out:
print(rs.current_result())  # => tensor([[0.6250, 0.6250]])

Ignore NaNs

When the ignore_nan option is enabled, RunningStats will only count and reduce over non-NaN elements:

import torch
from torch_runstats import Reduction, RunningStats

NaN = float("nan")

data = torch.tensor([
   [1.0, NaN, NaN],
   [NaN, NaN, NaN],
   [1.0, NaN, 1.0],
   [1.0, 3.0, 1.0],
   [1.0, NaN, NaN]
])
accumulate_by = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long)

rs = RunningStats(
   dim=(3,),
   reduction=Reduction.MEAN,
   reduce_dims=0,  # reduce the sample dimension of size 3
   ignore_nan=True
)
rs.accumulate_batch(data, accumulate_by=accumulate_by)
# In the first bin, we see that the mean was taken over only one element,
# the single non-NaN value, giving a value of 1.0
#
# In the second bin, we see that we got the mean of the non-NaN
# elements: (1 * 5 + 3) / 6 = 1.33333...
print(rs.current_result())  # => tensor([1.0000, 1.3333])
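The arithmetic in the comments above can be checked with a pure-Python sketch that skips NaNs in both the sum and the count. It illustrates the counting rule only and is not the library's implementation; `binned_nanmean` is a hypothetical name.

```python
import math


def binned_nanmean(rows, bin_index):
    """Hypothetical helper: per-bin mean over all non-NaN elements of each row."""
    n_bins = max(bin_index) + 1
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for row, b in zip(rows, bin_index):
        for x in row:
            if not math.isnan(x):
                # NaNs contribute to neither the sum nor the count
                sums[b] += x
                counts[b] += 1
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


nan = float("nan")
rows = [
    [1.0, nan, nan],
    [nan, nan, nan],
    [1.0, nan, 1.0],
    [1.0, 3.0, 1.0],
    [1.0, nan, nan],
]
result = binned_nanmean(rows, [0, 0, 1, 1, 1])
print([round(m, 4) for m in result])  # => [1.0, 1.3333]
```

If NaNs were counted as samples, the second bin would be divided by 9 instead of 6, which is why correct sample counting matters here.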

Class Reference

Currently supported Reductions are:

class torch_runstats.Reduction(value)

Enum indicating a reduction over \(N\) values \(x_i\).

Currently supported reductions:

  • Reduction.MEAN: \(\frac{1}{N}\sum_i^{N}{x_i}\)

  • Reduction.RMS: \(\sqrt{\frac{1}{N}\sum_i^{N}{x_i^2}}\)

Support for bincounting integers and for a combined one-pass mean/standard deviation is planned.
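One common way to maintain an RMS online is to keep a running mean of the squared values and take the square root only when the result is requested. The sketch below shows that approach in pure Python. It is an assumption about a reasonable strategy, not a description of the library's internals, and the class name `OnlineRMS` is hypothetical.

```python
import math


class OnlineRMS:
    """Hypothetical illustration: running RMS stored as a running mean of squares."""

    def __init__(self):
        self.n = 0
        self.mean_sq = 0.0

    def accumulate_batch(self, batch):
        k = len(batch)
        if k == 0:
            return
        batch_mean_sq = sum(x * x for x in batch) / k
        # Same weighted update as for a running mean, applied to x**2
        self.mean_sq += (batch_mean_sq - self.mean_sq) * k / (self.n + k)
        self.n += k

    def current_result(self):
        # The square root is applied only when reading the statistic
        return math.sqrt(self.mean_sq)


rms = OnlineRMS()
rms.accumulate_batch([3.0, 4.0])
rms.accumulate_batch([0.0, 0.0])
print(rms.current_result())  # => 2.5
```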

The core of the library is the RunningStats class:

class torch_runstats.RunningStats(dim: Union[int, Tuple[int, ...]] = 1, reduction: Reduction = Reduction.MEAN, reduce_dims: Union[int, Sequence[int]] = (), ignore_nan: bool = False)

Compute running statistics over batches of samples.

Parameters:
  • dim – the shape of a single sample. If an integer, interpreted as (dim,).

  • reduction – the statistic to compute

  • reduce_dims

    extra dimensions within each sample to reduce over. If an integer, interpreted as (reduce_dims,).

    This is a tuple of dimension indexes that are interpreted as dimension indexes within each sample: reduce_dims=(1,) implies that in a batch of size (N, A, B, C) with dim = (A, B, C) the N and B dimensions will be reduced over. (To reduce over A instead, you would use reduce_dims=(0,) to reduce over the first non-batch dimension.)

    By default an empty tuple, i.e., reduce only over the batch dimension.

  • ignore_nan – if True, NaNs in the data will be ignored, both in the accumulation and the sample count. If False (default), NaNs will propagate as normal.
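The relationship between dim, reduce_dims, and the per-bin output shape can be sketched as follows. This helper is hypothetical and only mirrors the shapes shown in the examples above; it assumes non-negative dimension indexes and is not part of the torch_runstats API.

```python
def output_shape(dim, reduce_dims=()):
    """Hypothetical helper: per-bin output shape after dropping reduced sample dims."""
    # Integers are interpreted as one-element tuples, as described above
    if isinstance(dim, int):
        dim = (dim,)
    if isinstance(reduce_dims, int):
        reduce_dims = (reduce_dims,)
    # Assumption of this sketch: indexes are non-negative positions within `dim`
    return tuple(d for i, d in enumerate(dim) if i not in reduce_dims)


print(output_shape((3, 2), reduce_dims=0))        # => (2,)
print(output_shape((4, 5, 6), reduce_dims=(1,)))  # => (4, 6)
print(output_shape(3, reduce_dims=0))             # => ()
```

The batch dimension is never part of dim, so it does not appear here; it is always reduced.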

accumulate_batch(batch: torch.Tensor, accumulate_by: Optional[torch.Tensor] = None) torch.Tensor

Accumulate a batch of samples into the running statistics.

Parameters:
  • batch – tensor of shape (N_samples,) + self.dim. The batch of samples to process.

  • accumulate_by – tensor of indexes of shape (N_samples,). If provided, the nth sample will be accumulated into the accumulate_by[n]th bin. If None (the default), all samples will be accumulated into the first (0th) bin. The indexes should be non-negative integers.

Returns:

tensor of shape (N_bins,) + self.output_dim giving the aggregated statistics for this input batch. Accumulated statistics up to this point can be retrieved with current_result().

N_bins is accumulate_by.max() + 1 — the number of bins in the batch — and not the overall number of bins self.n_bins.
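The difference between the number of bins in one batch and the overall number of bins can be seen in a small pure-Python sketch that tracks per-bin sample counts. The class `BinnedCount` is hypothetical, and starting with a single bin is an assumption of the sketch rather than a documented detail.

```python
class BinnedCount:
    """Hypothetical illustration: per-bin sample counts that grow with new bins."""

    def __init__(self):
        self.counts = [0]  # assumption: start with a single bin

    @property
    def n_bins(self):
        return len(self.counts)

    def accumulate_batch(self, accumulate_by):
        batch_bins = max(accumulate_by) + 1
        # Grow the running state if this batch uses a larger bin index
        self.counts.extend([0] * (batch_bins - self.n_bins))
        batch_counts = [0] * batch_bins
        for b in accumulate_by:
            batch_counts[b] += 1
            self.counts[b] += 1
        return batch_counts


bc = BinnedCount()
first = bc.accumulate_batch([0, 2, 2])
second = bc.accumulate_batch([1, 0])
print(first, second)           # => [1, 0, 2] [1, 1]
print(len(second), bc.n_bins)  # => 2 3
```

The second batch only mentions bins 0 and 1, so its per-batch result has two entries even though three bins are being tracked overall.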

batch_result(batch: torch.Tensor, accumulate_by: Optional[torch.Tensor] = None) torch.Tensor

Accumulate a batch of samples into the running statistics.

Parameters:
  • batch – tensor of shape (N_samples,) + self.dim. The batch of samples to process.

  • accumulate_by – tensor of indexes of shape (N_samples,). If provided, the nth sample will be accumulated into the accumulate_by[n]th bin. If None (the default), all samples will be accumulated into the first (0th) bin. The indexes should be non-negative integers.

Returns:

tensor of shape (N_bins,) + self.output_dim giving the aggregated statistics for this input batch. Accumulated statistics up to this point can be retrieved with current_result().

N_bins is accumulate_by.max() + 1 — the number of bins in the batch — and not the overall number of bins self.n_bins.

current_result() torch.Tensor

Get the current value of the running statistic.

Returns:

A tensor of shape (self.n_bins,) + self.output_dim. The nth bin contains the accumulated statistics for all processed samples whose accumulate_by was n.

property dim: Tuple[int, ...]

The shape of a single input sample for this RunningStats.

property n: torch.Tensor

The number of samples processed so far in each bin.

Returns:

A LongTensor of shape (self.n_bins,)

property n_bins: int

The number of accumulate_by bins currently maintained by this object.

property output_dim: Tuple[int, ...]

The shape of the output statistic for a single bin.

property reduce_dims: Tuple[int, ...]

Indexes of dimensions in each sample that will be reduced.

property reduction: Reduction

The reduction computed by this object.

reset(reset_n_bins: bool = False) None

Forget all previously accumulated state.

This method does not forget self.n_bins unless reset_n_bins is True.

Parameters:

reset_n_bins – whether to reset this object to one accumulation bin. This defaults to False on the assumption that a reset object will likely be used to process data with a similar or equal number of bins.
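The two reset modes can be illustrated with a pure-Python sketch of binned running state. The class `BinnedSum` is hypothetical and only mimics the semantics described above; it is not the library's implementation.

```python
class BinnedSum:
    """Hypothetical illustration of reset semantics for binned running state."""

    def __init__(self):
        self.sums = [0.0]  # assumption: a fresh object has one bin

    @property
    def n_bins(self):
        return len(self.sums)

    def accumulate_batch(self, values, accumulate_by):
        needed = max(accumulate_by) + 1
        self.sums.extend([0.0] * (needed - self.n_bins))
        for v, b in zip(values, accumulate_by):
            self.sums[b] += v

    def reset(self, reset_n_bins=False):
        # Clear the state; keep the bin count unless asked to drop back to one bin
        self.sums = [0.0] * (1 if reset_n_bins else self.n_bins)


s = BinnedSum()
s.accumulate_batch([1.0, 2.0, 3.0], [0, 1, 2])
s.reset()
print(s.n_bins, s.sums)  # => 3 [0.0, 0.0, 0.0]
s.reset(reset_n_bins=True)
print(s.n_bins, s.sums)  # => 1 [0.0]
```

Keeping the bin count across a plain reset avoids regrowing the state when the next pass over the data uses the same bins.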

to(device=None, dtype=None) None

Move this RunningStats to a new dtype and/or device.

Parameters:
  • dtype – like torch.Tensor.to

  • device – like torch.Tensor.to