Arithmetic Coding

Introduction

Issues with symbol coders

We know that Huffman coding is the best per-symbol encoder. But there is a fundamental limitation with being per-symbol: the overhead can be up to 1 bit per symbol over the entropy, which can be very significant. One simple case of this is:

p = {A: 0.1, B: 0.9}, H(p) = 0.47

Even though the entropy is roughly half a bit, no per-symbol coder can do better than 1 bit per symbol (every codeword is at least 1 bit long). Thus, we effectively use about twice as many bits as necessary to encode this source if we use Huffman coding.

Possible idea: combining symbols

One possible solution to our problem is to make a new source out of blocks of multiple symbols. For example, if we make a new source out of pairs of symbols, we get:

p2 = {AA: 0.01, AB: 0.09, BA: 0.09, BB: 0.81} 

This should perform better, as we now have a maximum overhead of 1 bit per 2 symbols. In this case, Huffman coding over the pairs gives an average of 0.65 bits/symbol, which is better than 1.0 but still not equal to the entropy of 0.47. We can of course continue and combine three symbols, etc., and eventually Huffman coding will converge to the entropy. However, this is computationally inefficient, as the alphabet size grows exponentially with the block length.
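
To sanity-check these numbers, here is a small self-contained sketch (using only the Python standard library rather than the SCL Huffman implementation; the helper names are made up for illustration) that computes the entropy of p and the average Huffman codelength for the paired source p2:

import heapq
import itertools
from math import log2

def entropy(p):
    # Shannon entropy in bits/symbol
    return sum(-q * log2(q) for q in p.values())

def huffman_codelengths(p):
    # standard Huffman construction with a min-heap;
    # returns a dict mapping each symbol to its codeword length
    heap = [(q, i, {s: 0}) for i, (s, q) in enumerate(p.items())]
    heapq.heapify(heap)
    tie_break = len(heap)
    while len(heap) > 1:
        q1, _, lens1 = heapq.heappop(heap)
        q2, _, lens2 = heapq.heappop(heap)
        merged = {s: l + 1 for s, l in {**lens1, **lens2}.items()}
        heapq.heappush(heap, (q1 + q2, tie_break, merged))
        tie_break += 1
    return heap[0][2]

p = {'A': 0.1, 'B': 0.9}
print(entropy(p))  # ~0.469 bits/symbol, vs 1.0 bit/symbol for per-symbol Huffman

# the paired source: probabilities multiply for i.i.d. symbols
p2 = {a + b: p[a] * p[b] for a, b in itertools.product(p, repeat=2)}
lengths = huffman_codelengths(p2)
print(sum(p2[s] * lengths[s] for s in p2) / 2)  # ~0.645 bits/symbol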

Arithmetic coding solves this problem: its overhead is theoretically just 2 bits over the optimal codelength for the entire sequence. In practice too, arithmetic coding and its variants achieve excellent computational and compression performance.

Other benefits of Arithmetic coding

Along with being computationally efficient, arithmetic coding offers several other desirable properties:

  1. Adaptability -> Arithmetic coding can use a different distribution for each symbol, and still be optimal
  2. Model/coding separation -> As arithmetic coding is essentially optimal for any given distribution, it separates the compression problem into two parts: first, coming up with a model (i.e., a distribution) for the data, and second, encoding the data using that distribution. Because of the optimality of arithmetic coding, for most purposes we can focus on the modeling part

Theoretical Arithmetic coding

The core idea of Arithmetic encoding can be explained using the following two steps:

  1. STEP I: Represent the entire input sequence as an interval [low, high) within the interval [0,1]
|----------------[.......)------------|
0               low     high          1 
  2. STEP II: Represent the [low, high) range using a single number (the state) within the range which has a short binary expansion. For example, for low = 0.1, high = 0.6, one possibility for the state is state = 0.25 = b0.01. As state = b0.01, the final arithmetic code for the input becomes 01 (a short check follows this list). The state must satisfy:
low <= state < high
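
As a quick sanity check of the numbers in STEP II (this snippet is purely illustrative):

low, high = 0.1, 0.6

# state = b0.01 -> 0*(1/2) + 1*(1/4) = 0.25
state = 0.25
assert low <= state < high

# the binary expansion of the state is only 2 bits long,
# so the arithmetic code for this (hypothetical) input is simply 01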

The decoder then has to perform the reverse operation to infer the input data. We will next try to understand how to decide the interval and then how to represent this interval using bits.

Infinite precision Arithmetic Encoder

NOTE: For this discussion we are going to assume we have infinite precision, and that we can represent any floating point number exactly

STEP I: getting the [low, high) range

The process to get the [low, high) range corresponding to the input is also known as the cake cutting method. It will soon be obvious why!

We start with low=0, high=1, and then recursively shrink the range into a smaller range based on the input symbols. Let's take a concrete example, as this will be much clearer that way. First, define a sample distribution and a sample input:

from core.prob_dist import ProbabilityDist
from core.data_block import DataBlock

# define a sample distribution
prob = ProbabilityDist({'A': 0.2, 'B': 0.4, 'C': 0.4})

# define a sample input
data = DataBlock(['B', 'A', 'C', 'B'])

The code block to recursively compute the low, high range is given below:

# initialize low, high values
low, high = 0.0, 1.0

# recursively shrink the range
for s in data.data_list:
    rng = (high - low)
    low = low + prob.cumulative_prob_dict[s]*rng
    high = low + prob.probability(s)*rng

As one can see from the python snippet above:

  1. We split the current range rng = (high - low) into slices which are proportional to the probabilities of the symbols in the distribution. For example, initially, when low=0.0, high=1.0, the slices are [0.0, 0.2), [0.2, 0.6), [0.6, 1.0), corresponding to A, B, C respectively.
  2. We continue this process until we have consumed all the input symbols. The progression of the low, high values for the sample input is shown below (a snippet to reproduce this trace follows the output):
initial range: low 0.0000, high: 1.0000
0: symbol B, low 0.2000, high: 0.6000
1: symbol A, low 0.2000, high: 0.2800
2: symbol C, low 0.2480, high: 0.2800
3: symbol B, low 0.2544, high: 0.2672
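
The trace above can be reproduced by re-running the loop with a print statement; a minimal sketch reusing prob and data from before:

low, high = 0.0, 1.0
print(f"initial range: low {low:.4f}, high: {high:.4f}")

# recursively shrink the range, printing low, high after each symbol
for i, s in enumerate(data.data_list):
    rng = (high - low)
    low = low + prob.cumulative_prob_dict[s]*rng
    high = low + prob.probability(s)*rng
    print(f"{i}: symbol {s}, low {low:.4f}, high: {high:.4f}")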

Notice that the final range [0.2544, 0.2672) losslessly represents the entire input sequence ['B', 'A', 'C', 'B']. Thus, if the decoder knows this range, it can recover the entire sequence.

STEP II: Communicating the range

One way to communicate the range information is to send a number which lies inside the range [0.2544, 0.2672). This can be achieved as follows:

  1. We know that (low + high)/2 lies in the interval [low, high). Thus we want to communicate this floating point number; let us call it mid.
mid = (low + high)/2
  2. The floating point value mid can have an infinitely long binary expansion (for example, 1/3 in binary is b0.010101...), so communicating it exactly may be impossible. In our example:
from utils.bitarray_utils import float_to_bitarrays

# low ~ 0.2544, high ~ 0.2672
mid = (low + high)/2 #mid = 0.26080000000000003

_,float_bitarray = float_to_bitarrays(mid, max_precision=20)
# mid = b0.01000010110000111100...
  3. Note that if we truncate the binary expansion of mid, the resulting value will be close to mid but will now be feasible to represent. Thus, the final step is to truncate the binary expansion of mid to a sufficient number of bits so that the resulting fraction (let's call it the state) still lies inside [low, high). We also want to be mindful not to use too many bits, as after all we want to compress the input :).

If we truncate the binary expansion of mid to k bits after the decimal point, then the resulting fraction mid_k satisfies:

0 <= (mid - mid_k) < 2^{-k}

Thus we can choose k so that mid_k is guaranteed to lie in the range [low, high). It suffices to require:

(mid - mid_k) < 2^{-k} <= (high - low)/2
which implies: 
k >= -log2(high - low) + 1 

since then mid_k > mid - (high - low)/2 = low, and mid_k <= mid < high.

In our example, we can thus calculate k and the state as:


import numpy as np

# low ~ 0.2544, high ~ 0.2672
k = int(np.ceil(-np.log2(high - low) + 1))

# get the truncated mid-point
_, code = float_to_bitarrays(mid, max_precision=k)
# >> code = bitarray('01000010')
# state = b0.01000010

Thus, the Arithmetic encoder has encoded the sequence ['B', 'A', 'C', 'B'] as 01000010, which is just 8 bits! The decoder operations should be clear, but we will look at them explicitly in the next section. One more question, however, is whether the Arithmetic coder explained above is any good, i.e., how well does it compress data?
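
For reference, here is a sketch that collects STEP I and STEP II into a single function (the function name and structure are illustrative, not the actual SCL encoder):

import numpy as np
from utils.bitarray_utils import float_to_bitarrays

def arithmetic_encode(data_list, prob):
    # STEP I: shrink [low, high) once per input symbol
    low, high = 0.0, 1.0
    for s in data_list:
        rng = high - low
        low = low + prob.cumulative_prob_dict[s]*rng
        high = low + prob.probability(s)*rng

    # STEP II: truncate the binary expansion of the midpoint to k bits,
    # with k chosen so that the truncated value stays inside [low, high)
    mid = (low + high)/2
    k = int(np.ceil(-np.log2(high - low) + 1))
    _, code = float_to_bitarrays(mid, max_precision=k)
    return code

code = arithmetic_encode(['B', 'A', 'C', 'B'], prob)
# >> code = bitarray('01000010')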

Infinite precision arithmetic decoding
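
The decoder mirrors the encoder: it reconstructs the state from the received bits and then repeatedly finds which symbol's slice of the current [low, high) range contains the state. Below is a minimal sketch of this idea (illustrative, not the SCL implementation; it assumes the decoder knows the number of encoded symbols, e.g., because it is transmitted separately, and that prob.prob_dict exposes the symbol-to-probability dictionary):

def arithmetic_decode(code, prob, num_symbols):
    # reconstruct the state from the truncated binary expansion of mid
    state = sum(bit * 2.0**(-(i + 1)) for i, bit in enumerate(code))

    decoded = []
    low, high = 0.0, 1.0
    for _ in range(num_symbols):
        rng = high - low
        # find the symbol whose slice of [low, high) contains the state
        for s in prob.prob_dict:
            s_low = low + prob.cumulative_prob_dict[s]*rng
            s_high = s_low + prob.probability(s)*rng
            if s_low <= state < s_high:
                decoded.append(s)
                low, high = s_low, s_high
                break
    return decoded

# recovers the original input from the 8-bit code computed above
print(arithmetic_decode(code, prob, num_symbols=4))
# >> ['B', 'A', 'C', 'B']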

References:

NOTE: I found this sequence of lectures on arithmetic coding extremely useful in understanding the intricacies.

  1. Introduction to Arithmetic coding: a good introduction to the key benefits of arithmetic coding
  2. Arithmetic coding examples: Example 1, Example 2: the examples are quite useful to demonstrate what happens "theoretically" in arithmetic coding
  3. Why arithmetic coding intervals need to be contained
  4. Rescaling operation for AEC: explains the core intuition on how the rescaling occurs in arithmetic coding