How to efficiently calculate a running standard deviation?


The answer is to use Welford's algorithm, which is very clearly defined after the "naive methods" in:

It's more numerically stable than either the two-pass or the online simple sum-of-squares collectors suggested in other responses. The stability only really matters when you have lots of values that are close to each other relative to their magnitude, because subtracting two nearly equal large numbers leads to what is known as "catastrophic cancellation" in the floating-point literature.
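To make the cancellation concrete, here is a minimal sketch that feeds four values clustered around 1e9 to both the sum-of-squares formula and Welford's update. The function names `naive_variance` and `welford_variance` are illustrative, not from any library, and the data set is just a convenient example whose true population variance is 22.5.

```python
def naive_variance(values):
    """Population variance from running sums: mean of squares minus squared mean."""
    n = 0
    sum_x = 0.0
    sum_x2 = 0.0
    for x in values:
        n += 1
        sum_x += x
        sum_x2 += x * x
    mean = sum_x / n
    return sum_x2 / n - mean * mean


def welford_variance(values):
    """Population variance using Welford's running update."""
    n = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in values:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    return m2 / n


# Values that are close together relative to their magnitude.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]

naive = naive_variance(data)
welford = welford_variance(data)
print("naive:  ", naive)
print("welford:", welford)
```

In double precision the squares here are around 1e18, where adjacent floats are 128 apart, so the naive difference can only come out as a multiple of 128 (it may even be negative) and cannot equal 22.5. Welford's update only ever works with small deviations from the running mean, so it recovers the correct value.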

You might also want to brush up on the difference between dividing by the number of samples (N) and N-1 in the variance calculation (squared deviation). Dividing by N-1 leads to an unbiased estimate of variance from the sample, whereas dividing by N on average underestimates variance (because it doesn't take into account the variance between the sample mean and the true mean).
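As a small illustration of the two divisors, the sketch below computes both variances by hand for three arbitrary values and prints them next to the results from Python's standard `statistics` module (`pvariance` divides by N, `variance` divides by N-1).

```python
import statistics

data = [17.0, 19.0, 24.0]
n = len(data)
mean = sum(data) / n
sum_sq_dev = sum((x - mean) ** 2 for x in data)

population_variance = sum_sq_dev / n      # divide by N
sample_variance = sum_sq_dev / (n - 1)    # divide by N - 1

print(population_variance, statistics.pvariance(data))
print(sample_variance, statistics.variance(data))
```

The N-1 version is always the larger of the two, and the gap shrinks as N grows.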

I wrote two blog entries on the topic which go into more details, including how to delete previous values online:

You can also take a look at my Java implementation; the javadoc, source, and unit tests are all online:


The basic answer is to accumulate the sum of both x (call it 'sum_x1') and x^2 (call it 'sum_x2') as you go. The value of the standard deviation is then:

    stdev = sqrt((sum_x2 / n) - (mean * mean))

where

    mean = sum_x1 / n

As written, this divides by 'n' and so gives the population standard deviation; to get the sample standard deviation, use 'n - 1' as the divisor, which amounts to multiplying the variance by n / (n - 1) before taking the square root.

You may need to worry about the numerical stability of taking the difference between two large numbers if you are dealing with large samples. Go to the external references in other answers (Wikipedia, etc) for more information.
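The running-sums approach can be sketched as a small accumulator. The class name `RunningSums`, the `sample` flag, returning 0.0 when the result is undefined, and clamping a slightly negative variance to zero are all assumptions made for this illustration, not part of the answer above.

```python
import math


class RunningSums:
    """Keeps n, the sum of x, and the sum of x^2 (hypothetical helper)."""

    def __init__(self):
        self.n = 0
        self.sum_x1 = 0.0
        self.sum_x2 = 0.0

    def push(self, x):
        self.n += 1
        self.sum_x1 += x
        self.sum_x2 += x * x

    def stdev(self, sample=False):
        """Population stdev by default; sample stdev if sample=True."""
        if self.n == 0 or (sample and self.n < 2):
            return 0.0
        mean = self.sum_x1 / self.n
        variance = self.sum_x2 / self.n - mean * mean
        if sample:
            variance *= self.n / (self.n - 1)
        # Rounding can push a tiny variance slightly below zero.
        return math.sqrt(max(variance, 0.0))


acc = RunningSums()
for x in [17.0, 19.0, 24.0]:
    acc.push(x)
    print(acc.n, acc.stdev(), acc.stdev(sample=True))
```

Only three numbers of state are kept, so each update is O(1). The clamp hides the symptom of cancellation but not the loss of precision, so this form is best kept for data whose mean is not large compared with its spread.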


Here is a literal pure Python translation of the Welford's algorithm implementation from http://www.johndcook.com/standard_deviation.html:

https://github.com/liyanage/python-modules/blob/master/running_stats.py

    import math

    class RunningStats:

        def __init__(self):
            self.n = 0
            self.old_m = 0
            self.new_m = 0
            self.old_s = 0
            self.new_s = 0

        def clear(self):
            # Resetting n is enough: the first push re-initialises the rest.
            self.n = 0

        def push(self, x):
            self.n += 1

            if self.n == 1:
                self.old_m = self.new_m = x
                self.old_s = 0
            else:
                # Welford's update of the running mean (m) and the running
                # sum of squared deviations from the mean (s).
                self.new_m = self.old_m + (x - self.old_m) / self.n
                self.new_s = self.old_s + (x - self.old_m) * (x - self.new_m)

                self.old_m = self.new_m
                self.old_s = self.new_s

        def mean(self):
            return self.new_m if self.n else 0.0

        def variance(self):
            # Sample variance (divides by n - 1).
            return self.new_s / (self.n - 1) if self.n > 1 else 0.0

        def standard_deviation(self):
            return math.sqrt(self.variance())

Usage:

    rs = RunningStats()
    rs.push(17.0)
    rs.push(19.0)
    rs.push(24.0)

    mean = rs.mean()
    variance = rs.variance()
    stdev = rs.standard_deviation()

    print(f'Mean: {mean}, Variance: {variance}, Std. Dev.: {stdev}')