The Holt-Winters method is a popular and effective approach to forecasting seasonal time series. But different implementations will give different forecasts, depending on how the method is initialized and how the smoothing parameters are selected.

I looked for a good implementation of the Holt-Winters method in Java or Python but could not find anything useful, so I went ahead and implemented one in Java. The code should be easy to translate to Python. I used the NIST method to calculate the forecast.

More details are available at: NIST website

Code snippet (visit GitHub for the latest code):

```
public static double[] forecast(int[] y, double alpha, double beta,
        double gamma, int period, int m, boolean debug) {
    ....
    int seasons = y.length / period;
    double a0 = calculateInitialLevel(y, period);
    double b0 = calculateInitialTrend(y, period);
    double[] initialSeasonalIndices = calculateSeasonalIndices(y, period, seasons);
    double[] forecast = calculateHoltWinters(y, a0, b0, alpha, beta, gamma,
            initialSeasonalIndices, period, m, debug);
    return forecast;
}
```
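The three initialization helpers called above are not shown in this post. As a rough illustration only, here is a minimal sketch of how the initial trend and the initial seasonal indices are commonly computed in the NIST-style approach. The class and method names (`HoltWintersInitSketch`, `initialTrend`, `seasonalIndices`) are hypothetical, the input is assumed to be `double[]` rather than `int[]`, and the actual helpers in the GitHub project may differ. The initial level is left out because the post does not say how it is chosen.

```java
// Hypothetical sketch: not the helpers from the GitHub project.
final class HoltWintersInitSketch {

    /** Initial trend: average per-step change between the first two seasons. */
    static double initialTrend(double[] y, int period) {
        double sum = 0.0;
        for (int i = 0; i < period; i++) {
            // Change over one full season, expressed per time step
            sum += (y[period + i] - y[i]) / period;
        }
        return sum / period;
    }

    /** Initial seasonal indices: one multiplicative factor per position in the season. */
    static double[] seasonalIndices(double[] y, int period, int seasons) {
        // Mean of each complete season
        double[] seasonMean = new double[seasons];
        for (int s = 0; s < seasons; s++) {
            for (int i = 0; i < period; i++) {
                seasonMean[s] += y[s * period + i];
            }
            seasonMean[s] /= period;
        }
        // Divide each observation by its season's mean, then average
        // the ratios that share the same position within the season
        double[] indices = new double[period];
        for (int i = 0; i < period; i++) {
            for (int s = 0; s < seasons; s++) {
                indices[i] += y[s * period + i] / seasonMean[s];
            }
            indices[i] /= seasons;
        }
        return indices;
    }
}
```

With this construction the indices average to 1 across a season, so they only redistribute the level within the season rather than shifting it.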

The calculateHoltWinters method implements the Holt-Winters equations.
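For reference, these are the multiplicative Holt-Winters recursions as the code applies them, written with the code's variable names in mind: S is the level (`St`), b the trend (`Bt`), I the seasonal index (`It`), F the forecast (`Ft`), and L stands for `period`. Note the roles of the parameters in this code: alpha smooths the level, gamma the trend, and beta the seasonal indices.

```latex
\begin{aligned}
S_t     &= \alpha \frac{y_t}{I_{t-L}} + (1-\alpha)\,(S_{t-1} + b_{t-1}) \\
b_t     &= \gamma\,(S_t - S_{t-1}) + (1-\gamma)\, b_{t-1} \\
I_t     &= \beta \frac{y_t}{S_t} + (1-\beta)\, I_{t-L} \\
F_{t+m} &= (S_t + m\, b_t)\, I_{t-L+m}
\end{aligned}
```

During the first season (t < L) the code has no earlier seasonal index to divide by, so it updates the level with the raw observation instead.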

```
private static double[] calculateHoltWinters(int[] y, double a0, double b0, double alpha, double beta,
        double gamma, double[] initialSeasonalIndices, int period, int m, boolean debug) {
    double[] St = new double[y.length];     // smoothed level
    double[] Bt = new double[y.length];     // trend
    double[] It = new double[y.length];     // seasonal indices
    double[] Ft = new double[y.length + m]; // forecasts
    // Initialize base values
    St[1] = a0;
    Bt[1] = b0;
    for (int i = 0; i < period; i++) {
        It[i] = initialSeasonalIndices[i];
    }
    Ft[m] = (St[0] + (m * Bt[0])) * It[0]; // This is actually 0 since St[0] and Bt[0] are both 0
    Ft[m + 1] = (St[1] + (m * Bt[1])) * It[1]; // Forecast starts from period + 2
    // Start calculations
    for (int i = 2; i < y.length; i++) {
        // Calculate overall smoothing
        if ((i - period) >= 0) {
            St[i] = alpha * y[i] / It[i - period] + (1.0 - alpha) * (St[i - 1] + Bt[i - 1]);
        } else {
            // No seasonal index from one period back yet, so use the raw observation
            St[i] = alpha * y[i] + (1.0 - alpha) * (St[i - 1] + Bt[i - 1]);
        }
        // Calculate trend smoothing
        Bt[i] = gamma * (St[i] - St[i - 1]) + (1 - gamma) * Bt[i - 1];
        // Calculate seasonal smoothing
        if ((i - period) >= 0) {
            It[i] = beta * y[i] / St[i] + (1.0 - beta) * It[i - period];
        }
        // Calculate forecast (the It index stays in bounds only while m <= period)
        if ((i + m) >= period) {
            Ft[i + m] = (St[i] + (m * Bt[i])) * It[i - period + m];
        }
    }
    return Ft;
}
```
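The forecast line reads the seasonal index at `i - period + m`, which is only available while `m <= period`. One common way to support longer horizons is to wrap the horizon back into the most recent season. The sketch below shows only that index calculation. `SeasonalIndexLookup` and `position` are hypothetical names, and this is a suggestion rather than something the code above already does.

```java
// Hypothetical helper: picks the latest available seasonal index for horizon m.
final class SeasonalIndexLookup {

    /**
     * Position in the seasonal-index array to use when forecasting m steps
     * ahead from time i. For 1 <= m <= period this equals i - period + m;
     * for larger m it wraps around so the result never exceeds i.
     */
    static int position(int i, int m, int period) {
        return i - period + ((m - 1) % period) + 1;
    }
}
```

If this were used in the loop, the guard before the forecast would have to check that the returned position is non-negative instead of checking `(i + m) >= period`.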

A simple way to invoke the forecast method:

```
public static void testRunNISTData() {
    int[] y = {362,385,432,341,382,409,498,387,473,513,582,474,544,582,681,557,628,707,773,592,627,725,854,661};
    double alpha = 0.06;
    double beta = 0.98;
    double gamma = 0.48;
    int period = 4;
    int m = 4;
    double[] prediction = HoltWintersTripleExponentialImpl.forecast(y, alpha, beta, gamma, period, m, true);
    System.out.println(String.format("MSE: %f", TSAError.calculateMSE(y, prediction, period, m, false)));
}
```
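The smoothing parameters in this example are fixed by hand. One simple way to select them is a brute-force grid search that keeps the combination with the lowest in-sample error. The sketch below is self-contained and is not the code from the GitHub project: `HoltWintersGridSearch`, `oneStepMse` and `gridSearch` are hypothetical names, the input is assumed to be `double[]`, the initial level is assumed to be the mean of the first season, and the error is the one-step-ahead MSE from the second season onward. Because of those assumptions its MSE values are not comparable with the figure reported in this post.

```java
// Hypothetical sketch of parameter selection by grid search.
final class HoltWintersGridSearch {

    /** One-step-ahead in-sample MSE of multiplicative Holt-Winters. */
    static double oneStepMse(double[] y, int period, double alpha, double beta, double gamma) {
        int n = y.length;
        int seasons = n / period;

        // Initial trend: average per-step change between the first two seasons
        double trend = 0.0;
        for (int i = 0; i < period; i++) {
            trend += (y[period + i] - y[i]) / period;
        }
        trend /= period;

        // Initial seasonal indices: ratio to the season mean, averaged over seasons
        double[] seasonal = new double[period];
        for (int s = 0; s < seasons; s++) {
            double mean = 0.0;
            for (int i = 0; i < period; i++) {
                mean += y[s * period + i];
            }
            mean /= period;
            for (int i = 0; i < period; i++) {
                seasonal[i] += y[s * period + i] / (mean * seasons);
            }
        }

        // Initial level: mean of the first season (a simplifying assumption)
        double level = 0.0;
        for (int i = 0; i < period; i++) {
            level += y[i] / period;
        }

        double sse = 0.0;
        for (int t = period; t < n; t++) {
            int k = t % period; // slot holding the index from one period ago
            double forecast = (level + trend) * seasonal[k];
            double error = y[t] - forecast;
            sse += error * error;

            // alpha smooths the level, gamma the trend, beta the seasonal index
            double newLevel = alpha * y[t] / seasonal[k] + (1.0 - alpha) * (level + trend);
            trend = gamma * (newLevel - level) + (1.0 - gamma) * trend;
            seasonal[k] = beta * y[t] / newLevel + (1.0 - beta) * seasonal[k];
            level = newLevel;
        }
        return sse / (n - period);
    }

    /** Tries every grid point in [0, 1]^3; returns {alpha, beta, gamma, mse}. */
    static double[] gridSearch(double[] y, int period, int steps) {
        double[] best = {Double.NaN, Double.NaN, Double.NaN, Double.POSITIVE_INFINITY};
        for (int a = 0; a <= steps; a++) {
            for (int b = 0; b <= steps; b++) {
                for (int g = 0; g <= steps; g++) {
                    double alpha = a / (double) steps;
                    double beta = b / (double) steps;
                    double gamma = g / (double) steps;
                    double mse = oneStepMse(y, period, alpha, beta, gamma);
                    if (mse < best[3]) {
                        best = new double[] {alpha, beta, gamma, mse};
                    }
                }
            }
        }
        return best;
    }
}
```

A call such as `HoltWintersGridSearch.gridSearch(series, 4, 10)` tries 11 values per parameter. A finer grid or a numerical optimizer would give a better fit, and parameters tuned on in-sample error can overfit a short series.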

The code can be downloaded from GitHub: https://github.com/nchandra/ExponentialSmoothing

Forecast calculated for NIST data (forecast starts from period 6):

MSE: 1384.316511
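The `TSAError.calculateMSE` method used above is not shown in this post. As a rough idea of what a mean squared error calculation looks like, here is a minimal sketch that averages the squared differences between the observations and the forecasts at the same positions. `ForecastErrorSketch`, `meanSquaredError` and the `from` parameter are hypothetical, and the real method may choose the comparison window differently, so this sketch is not guaranteed to reproduce the number above.

```java
// Hypothetical sketch: not the TSAError class from the GitHub project.
final class ForecastErrorSketch {

    /**
     * Mean squared error between actual[i] and forecast[i] for every index
     * from 'from' up to the end of the shorter array.
     */
    static double meanSquaredError(int[] actual, double[] forecast, int from) {
        double sum = 0.0;
        int count = 0;
        for (int i = from; i < actual.length && i < forecast.length; i++) {
            double error = actual[i] - forecast[i];
            sum += error * error;
            count++;
        }
        if (count == 0) {
            throw new IllegalArgumentException("No overlapping points to compare");
        }
        return sum / count;
    }
}
```

The `from` argument is there so that positions where no forecast has been produced yet (the forecast array is zero there) can be skipped.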

A Ruby port is available at: Github

The code is available under Apache License, Version 2.0. Feel free to use it. Please leave your comments below.

Nice!

You should create a simple Github project for this or at least post it on Github/Gist.

Github project hosted at https://github.com/nchandra/ExponentialSmoothing

I took the liberty of posting on Apache Commons User list - http://markmail.org/thread/a5uvdnsh26tpvk7m

It would be nice if you could send them a patch.

Hi, thanks for the algorithm. What about parameter estimation? Are there any open-source Java libraries that would make it easy to estimate the parameters?

Cheers

Two improvements to your code are possible:

- The input data should be double instead of int, for more flexibility.

- Your code does not work for m > period, because the index into the It array goes out of bounds.