pytorch_runstats
torch_runstats implements memory-efficient online reductions on tensors. Notable features:
Arbitrary sample shapes beyond single scalars
Reduction over arbitrary dimensions of each sample
“Batched”/“binned” reduction into multiple running tallies using a per-sample bin index. This can be useful, for example, for accumulating statistics over samples by some kind of “type” index, or for accumulating statistics per graph in a pytorch_geometric-like batching scheme. (This feature is similar to torch_scatter.)
Option to ignore NaN values with correct sample counting
Note
The implementation currently makes heavy use of in-place operations for performance and memory efficiency. This probably doesn’t play well with the autograd engine, so this is currently likely the wrong library for accumulating running statistics you want to backpropagate through. (See TorchMetrics for a possible alternative.)
Examples
Basic
import torch
from torch_runstats import Reduction, RunningStats
# Interspersed ones and zeros with a ratio of 2:1 ones to zeros
data = torch.cat([torch.ones(5), torch.zeros(3), torch.ones(5), torch.zeros(2)])
data.unsqueeze_(-1)
rs = RunningStats(
    dim=(1,),
    reduction=Reduction.MEAN,
)
# Accumulate the statistics over the data in batches
# Note that each call to accumulate_batch also returns the statistic for the current batch:
print(rs.accumulate_batch(data[:5])) # => tensor([[1.]])
rs.accumulate_batch(data[5:7])
rs.accumulate_batch(data[7:13])
rs.accumulate_batch(data[13:])
print(rs.current_result()) # => tensor([[0.6667]])
# Accumulated data can be cleared
rs.reset()
# An empty object returns the identity for the reduction:
print(rs.current_result()) # => tensor([[0.]])
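To make the idea of an online reduction concrete, here is a minimal pure-Python sketch of one way a running mean can be updated batch by batch without keeping earlier samples. The helper `merge_mean` is hypothetical and is not part of torch_runstats; the library's actual update rule may differ.

```python
def merge_mean(running_mean, running_n, batch):
    """Fold one batch of scalars into a running mean (hypothetical helper)."""
    batch_n = len(batch)
    if batch_n == 0:
        return running_mean, running_n
    batch_mean = sum(batch) / batch_n
    total_n = running_n + batch_n
    # Shift the old mean toward the batch mean, weighted by the batch's share
    new_mean = running_mean + (batch_mean - running_mean) * batch_n / total_n
    return new_mean, total_n


# Same data and batch boundaries as the example above
data = [1.0] * 5 + [0.0] * 3 + [1.0] * 5 + [0.0] * 2
mean, n = 0.0, 0
for start, stop in [(0, 5), (5, 7), (7, 13), (13, 15)]:
    mean, n = merge_mean(mean, n, data[start:stop])

print(round(mean, 4), n)  # 0.6667 15
```

Only the current mean and the sample count are stored between batches, which is what keeps this kind of reduction memory-efficient.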
“Binned”
A main feature of torch_runstats is accumulating different samples in a batch into different “bins” (different running statistics) based on a provided index:
import torch
from torch_runstats import Reduction, RunningStats
data = torch.cat([torch.ones(5), torch.zeros(3), torch.ones(5), torch.zeros(2)])
data.unsqueeze_(-1)
sample_type = torch.cat([torch.zeros(8, dtype=torch.long), torch.ones(7, dtype=torch.long)])
rs = RunningStats(
    dim=(1,),
    reduction=Reduction.MEAN,
)
rs.accumulate_batch(data, accumulate_by=sample_type)
# The first entry is for "bin" (sample_type) 0, the second for 1:
print(rs.current_result()) # => tensor([[0.6250], [0.7143]])
# These values are what we expect:
print(5/8, 5/7) # => 0.625 0.7142857142857143
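The binning logic can be sketched in plain Python as a per-bin sum and count. The function `binned_mean` is a hypothetical illustration of the idea, not the library's implementation, and reporting 0.0 for an empty bin is an assumption of this sketch.

```python
def binned_mean(values, bin_index):
    """Mean of `values` per bin, where bin_index[i] is the bin of values[i]."""
    n_bins = max(bin_index) + 1
    sums = [0.0] * n_bins
    counts = [0] * n_bins
    for value, b in zip(values, bin_index):
        sums[b] += value
        counts[b] += 1
    # Assumption for this sketch: an empty bin reports 0.0
    return [s / c if c else 0.0 for s, c in zip(sums, counts)]


data = [1.0] * 5 + [0.0] * 3 + [1.0] * 5 + [0.0] * 2
sample_type = [0] * 8 + [1] * 7
result = binned_mean(data, sample_type)
print([round(r, 4) for r in result])  # [0.625, 0.7143]
```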
Reduce over arbitrary dimensions
A reduction can also be taken over a sample dimension:
import torch
from torch_runstats import Reduction, RunningStats
data = torch.cat([torch.ones(5, 3, 2), torch.zeros(3, 3, 2)], dim=0)
rs = RunningStats(
    dim=(3, 2),
    reduction=Reduction.MEAN,
    reduce_dims=0, # reduce the sample dimension of size 3
)
rs.accumulate_batch(data)
# Note that the reduction has a bin index (len 1),
# and the sample dimension of shape 2,
# but that the dimension of size 3 has been reduced out:
print(rs.current_result()) # => tensor([[0.6250, 0.6250]])
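As a rough illustration of which elements are pooled in this example, the following pure-Python sketch averages over both the batch dimension and the size-3 sample dimension, and keeps the size-2 dimension. It uses nested lists in place of tensors and is only a sketch of the arithmetic.

```python
# One sample has shape (3, 2); the batch has shape (8, 3, 2)
ones = [[1.0, 1.0] for _ in range(3)]
zeros = [[0.0, 0.0] for _ in range(3)]
data = [ones] * 5 + [zeros] * 3

result = []
for k in range(2):  # the kept dimension of size 2
    # Pool over every sample in the batch and every index j of the reduced dimension
    pooled = [sample[j][k] for sample in data for j in range(3)]
    result.append(sum(pooled) / len(pooled))

print(result)  # [0.625, 0.625]
```

Each output entry is the mean of 8 × 3 = 24 values, 15 of which are one.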
Ignore NaNs
When the ignore_nan option is enabled, RunningStats will only count and reduce over non-NaN elements:
import torch
from torch_runstats import Reduction, RunningStats
NaN = float("nan")
data = torch.Tensor([
    [1.0, NaN, NaN],
    [NaN, NaN, NaN],
    [1.0, NaN, 1.0],
    [1.0, 3.0, 1.0],
    [1.0, NaN, NaN]
])
accumulate_by = torch.LongTensor([0, 0, 1, 1, 1])
rs = RunningStats(
    dim=(3,),
    reduction=Reduction.MEAN,
    reduce_dims=0, # reduce the sample dimension of size 3
    ignore_nan=True
)
rs.accumulate_batch(data, accumulate_by=accumulate_by)
# In the first bin, the mean was taken over only one element,
# the single non-NaN value, giving 1.0
#
# In the second bin, we see that we got the mean of the non-NaN
# elements: (1 * 5 + 3) / 6 = 1.33333...
print(rs.current_result()) # => tensor([1.0000, 1.3333])
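The counting behaviour described in the comments above can be checked with a small pure-Python sketch. It assumes that ignoring a NaN means skipping it in both the sum and the count, as the text describes; it is not the library's code.

```python
import math

NaN = float("nan")
data = [
    [1.0, NaN, NaN],
    [NaN, NaN, NaN],
    [1.0, NaN, 1.0],
    [1.0, 3.0, 1.0],
    [1.0, NaN, NaN],
]
accumulate_by = [0, 0, 1, 1, 1]

n_bins = max(accumulate_by) + 1
sums = [0.0] * n_bins
counts = [0] * n_bins
for row, b in zip(data, accumulate_by):
    for x in row:
        if not math.isnan(x):  # a NaN adds nothing to the sum or the count
            sums[b] += x
            counts[b] += 1

means = [s / c for s, c in zip(sums, counts)]
print(counts)                        # [1, 6]
print([round(m, 4) for m in means])  # [1.0, 1.3333]
```

Dividing by the number of non-NaN elements, rather than by the total number of elements, is what "correct sample counting" amounts to here.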
Class Reference
Currently supported Reductions are:
- class torch_runstats.Reduction(value)
Enum indicating a reduction over \(N\) values \(x_i\).
Currently supported reductions:
Reduction.MEAN: \(\frac{1}{N}\sum_i^{N}{x_i}\)
Reduction.RMS: \(\sqrt{\frac{1}{N}\sum_i^{N}{x_i^2}}\)
Support for bincounting integers and combined one-pass mean/standard deviation are planned.
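One way an RMS can be computed online, shown here as a hedged pure-Python sketch, is to keep a running mean of squares and take the square root only when the result is read. The helper `merge_mean_square` is hypothetical, and whether torch_runstats stores its state this way is an assumption.

```python
import math


def merge_mean_square(running_ms, running_n, batch):
    """Fold a batch of scalars into a running mean of squares (hypothetical helper)."""
    batch_n = len(batch)
    if batch_n == 0:
        return running_ms, running_n
    batch_ms = sum(x * x for x in batch) / batch_n
    total_n = running_n + batch_n
    new_ms = running_ms + (batch_ms - running_ms) * batch_n / total_n
    return new_ms, total_n


ms, n = 0.0, 0
for batch in ([3.0, 4.0], [12.0]):
    ms, n = merge_mean_square(ms, n, batch)

rms = math.sqrt(ms)  # square root applied once, at read time
direct = math.sqrt((3.0**2 + 4.0**2 + 12.0**2) / 3)
print(round(rms, 4), round(direct, 4))  # 7.5056 7.5056
```

Averaging the per-batch RMS values directly would give a different, incorrect answer, which is why the sketch accumulates squares instead.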
The core of the library is the RunningStats class:
- class torch_runstats.RunningStats(dim: Union[int, Tuple[int, ...]] = 1, reduction: Reduction = Reduction.MEAN, reduce_dims: Union[int, Sequence[int]] = (), ignore_nan: bool = False)
Compute running statistics over batches of samples.
- Parameters:
dim – the shape of a single sample. If an integer, interpreted as (dim,).
reduction – the statistic to compute.
reduce_dims – extra dimensions within each sample to reduce over. If an integer, interpreted as (reduce_dims,). This is a tuple of dimension indexes that are interpreted as dimension indexes within each sample: reduce_dims=(1,) implies that in a batch of size (N, A, B, C) with dim = (A, B, C), the N and B dimensions will be reduced over. (To reduce over A instead, you would use reduce_dims=(0,) to reduce over the first non-batch dimension.) By default an empty tuple, i.e., reduce only over the batch dimension.
ignore_nan – if True, NaNs in the data will be ignored, both in the accumulation and the sample count. If False (default), NaNs will propagate as normal.
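The relationship between dim, reduce_dims and the per-bin output shape can be sketched as follows. The helper `output_shape` is hypothetical; it encodes the rule implied by the examples on this page (reduced sample dimensions are dropped) and is not taken from the library's source.

```python
def output_shape(dim, reduce_dims=()):
    """Per-bin output shape after dropping the reduced sample dimensions."""
    if isinstance(dim, int):
        dim = (dim,)
    if isinstance(reduce_dims, int):
        reduce_dims = (reduce_dims,)
    # Indexes in reduce_dims count from the first non-batch dimension
    return tuple(size for i, size in enumerate(dim) if i not in reduce_dims)


print(output_shape((4, 5, 6), reduce_dims=(1,)))  # (4, 6)
print(output_shape((3, 2), reduce_dims=0))        # (2,)
print(output_shape(3, reduce_dims=0))             # ()
print(output_shape((1,)))                         # (1,)
```

With a single bin, the full result shape is then `(1,) + output_shape(...)`, which matches the `tensor([[0.6250, 0.6250]])` output in the example with dim=(3, 2) and reduce_dims=0.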
- accumulate_batch(batch: torch.Tensor, accumulate_by: Optional[torch.Tensor] = None) → torch.Tensor
Accumulate a batch of samples into the running statistics.
- Parameters:
batch – tensor of shape (N_samples,) + self.dim. The batch of samples to process.
accumulate_by – tensor of indexes of shape (N_samples,). If provided, the nth sample will be accumulated into the accumulate_by[n]th bin. If None (the default), all samples will be accumulated into the first (0th) bin. The indexes should be non-negative integers.
- Returns:
tensor of shape (N_bins,) + self.output_dim giving the aggregated statistics for this input batch. Accumulated statistics up to this point can be retrieved with current_result().
N_bins is accumulate_by.max() + 1 (the number of bins in the batch), not the overall number of bins self.n_bins.
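The difference between the number of bins in a single batch and the overall number of bins can be illustrated with a toy stand-in. The class `BinnedMean` below is hypothetical and handles scalars only; it mirrors the documented return shapes but is not the library's implementation, and reporting 0.0 for an empty bin is an assumption.

```python
class BinnedMean:
    """Hypothetical scalar-only stand-in for a running binned mean."""

    def __init__(self):
        self.sums = [0.0]
        self.counts = [0]

    def accumulate_batch(self, values, accumulate_by):
        n_batch_bins = max(accumulate_by) + 1
        # Grow the running state if this batch refers to new bins
        while len(self.sums) < n_batch_bins:
            self.sums.append(0.0)
            self.counts.append(0)
        batch_sums = [0.0] * n_batch_bins
        batch_counts = [0] * n_batch_bins
        for v, b in zip(values, accumulate_by):
            batch_sums[b] += v
            batch_counts[b] += 1
            self.sums[b] += v
            self.counts[b] += 1
        # The return value covers only the bins seen up to this batch's maximum index
        return [s / c if c else 0.0 for s, c in zip(batch_sums, batch_counts)]

    def current_result(self):
        return [s / c if c else 0.0 for s, c in zip(self.sums, self.counts)]


rs = BinnedMean()
first = rs.accumulate_batch([1.0, 2.0, 3.0], [0, 2, 2])
second = rs.accumulate_batch([5.0], [0])
print(first)                # [1.0, 0.0, 2.5]
print(second)               # [5.0]
print(rs.current_result())  # [3.0, 0.0, 2.5]
```

The second call returns a single entry because its largest bin index is 0, while the running result still reports all three bins.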
- batch_result(batch: torch.Tensor, accumulate_by: Optional[torch.Tensor] = None) → torch.Tensor
Accumulate a batch of samples into the running statistics.
- Parameters:
batch – tensor of shape (N_samples,) + self.dim. The batch of samples to process.
accumulate_by – tensor of indexes of shape (N_samples,). If provided, the nth sample will be accumulated into the accumulate_by[n]th bin. If None (the default), all samples will be accumulated into the first (0th) bin. The indexes should be non-negative integers.
- Returns:
tensor of shape (N_bins,) + self.output_dim giving the aggregated statistics for this input batch. Accumulated statistics up to this point can be retrieved with current_result().
N_bins is accumulate_by.max() + 1 (the number of bins in the batch), not the overall number of bins self.n_bins.
- current_result() → torch.Tensor
Get the current value of the running statistic.
- Returns:
A tensor of shape (self.n_bins,) + self.output_dim. The nth bin contains the accumulated statistics for all processed samples whose accumulate_by was n.
- property dim: Tuple[int, ...]
The shape of a single input sample for this RunningStats.
- property n: torch.Tensor
The number of samples processed so far in each bin.
- Returns:
A LongTensor of shape (self.n_bins,).
- property n_bins: int
The number of accumulate_by bins currently maintained by this object.
- property output_dim: Tuple[int, ...]
The shape of the output statistic for a single bin.
- property reduce_dims: Tuple[int, ...]
Indexes of dimensions in each sample that will be reduced.
- reset(reset_n_bins: bool = False) → None
Forget all previously accumulated state.
This method does not forget self.n_bins unless reset_n_bins is True.
- Parameters:
reset_n_bins – whether to reset this object to one accumulation bin. This defaults to False on the assumption that a reset object will likely be used to process data with a similar or equal number of bins.
- to(device=None, dtype=None) → None
Move this RunningStats to a new dtype and/or device.
- Parameters:
dtype – like torch.Tensor.to
device – like torch.Tensor.to