
Linear Prediction#

The importance of linear prediction and coding#

Linear prediction and coding (LPC) play a crucial role in digital signal processing (DSP) and cognitive science. By making predictions about the future based on past experiences and sensory input, our brains are able to anticipate and respond to changes in the environment in a more efficient and adaptive manner. This process of prediction allows the brain to generate expectations about what will happen next, which helps to reduce surprise and uncertainty. Linear coding, on the other hand, involves representing sensory input in a way that can be efficiently processed and used for making predictions. By encoding sensory information in a linear fashion, a DSP system (or the brain) can extract relevant features and patterns from complex input and build a more accurate model of the world.

In computational cognitive science, LPC helps to simulate the optimization of the brain’s internal models and improve its ability to make accurate predictions about the environment. This, in turn, allows for more efficient action selection and decision-making, ultimately enhancing the brain’s ability to adapt and survive in a constantly changing world.

In DSP, LPC is used for signal modeling. It also provides an excellent introduction to parametric spectral estimation. For more background on LPC in DSP, please refer to {cite}`makhoul1975linear`.

Below we provide a short introduction to differentiable LPC.

Speech Decomposition with Source Filter Model#

In this example, we decompose a speech signal into its source \(e[n]\) and filter components \(a_k\), following the LPC model

\[ s[n] = e[n] + \sum_{k=1}^M a_k s[n-k] \]
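To make the recursion concrete, here is a minimal sketch (not part of the original notebook) that generates \(s[n]\) sample by sample from an excitation signal and a set of coefficients; the helper name lpc_synthesise and the toy one-pole example are our own.

import numpy as np

def lpc_synthesise(e: np.ndarray, a: np.ndarray) -> np.ndarray:
    """Compute s[n] = e[n] + sum_k a[k] * s[n-k] with zero initial conditions."""
    M = len(a)
    s = np.zeros_like(e, dtype=float)
    for n in range(len(e)):
        s[n] = e[n] + sum(a[k] * s[n - 1 - k] for k in range(min(M, n)))
    return s

# an impulse through a one-pole filter (a_1 = 0.5) decays geometrically
print(lpc_synthesise(np.array([1.0, 0, 0, 0, 0]), np.array([0.5])))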

We’ll first use the traditional method to estimate the LPC filter, and then we’ll use our differentiable LPC to do end-to-end decomposition.

Again, let’s first import the necessary packages and define some helper functions.

import torch
import torch.nn.functional as F
import torchaudio
import math
import numpy as np
from torchaudio.functional import lfilter
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List, Union
from IPython.display import Audio
# !pip install diffsptk --quiet --force-reinstall
import diffsptk
def plot_t(
    title: str,
    ys: List[np.ndarray],
    labels: List[str] = None,
    scatter: bool = False,
    axhline: bool = False,
    x_label: str = "Samples",
    y_label: str = "Amplitude",
):
    for y, label in (
        zip(ys, labels) if labels is not None else zip(ys, [None] * len(ys))
    ):
        if scatter:
            plt.scatter(np.arange(len(y)) + 1, y, label=label)
        else:
            plt.plot(y, label=label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    if labels is not None:
        plt.legend()
    if axhline:
        plt.axhline(y=0, color="r", linestyle="dashed", alpha=0.5)


def plot_f(
    ys: List[np.ndarray] = None,
    paired_ys: List[Tuple[np.ndarray, np.ndarray]] = None,
    ys_labels: List[str] = None,
    paired_ys_labels: List[str] = None,
    sr: int = None,
):
    if ys is not None:
        for y, label in (
            zip(ys, ys_labels) if ys_labels is not None else zip(ys, [None] * len(ys))
        ):
            plt.magnitude_spectrum(
                y, Fs=sr, scale="dB", window=np.hanning(len(y)), label=label
            )
    if paired_ys is not None:
        for (f, y), label in (
            zip(paired_ys, paired_ys_labels)
            if paired_ys_labels is not None
            else zip(paired_ys, [None] * len(paired_ys))
        ):
            plt.plot(f, 20 * np.log10(np.abs(y)), label=label)
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude (dB)")
    plt.xlim(20, sr // 2)
    plt.title("Frequency spectrum")
    if ys_labels is not None or paired_ys_labels is not None:
        plt.legend()

We’re going to use a speech sample from the CMU Arctic speech synthesis database.

!wget "http://festvox.org/cmu_arctic/cmu_arctic/cmu_us_awb_arctic/wav/arctic_a0007.wav"
y, sr = torchaudio.load("arctic_a0007.wav")
y = y.squeeze()

plt.plot(np.arange(y.shape[0]) / sr, y.numpy())
plt.xlabel("Time [s]")
plt.ylabel("Amplitude")
plt.show()

Audio(y.numpy(), rate=sr)

Let’s pick a short segment of the speech with relatively static pitch and formants, so that a stationary model is appropriate.

target = y[10000:11024]

fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t("Target signal", [target.numpy()])
plt.subplot(1, 2, 2)
plot_f([target.numpy()], sr=sr)
plt.show()

Classic LPC Estimation#

The common way to estimate the LPC filter is to assume that the current sample \(s[n]\) can be approximated from past samples alone, which amounts to minimising the prediction error \(e[n]\):

\[ \min_{a_k} \left( s[n] - \sum_{k=1}^M a_k s[n-k] \right)^2 = \min_{a_k} e[n]^2 \]

Its least squares solution can be computed from the autocorrelation of the signal {cite}`makhoul1975linear`. We’ll use the diffsptk package to compute this.
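For reference, here is a rough numpy sketch of the autocorrelation method (a plain Levinson-Durbin recursion; the function name and the sign convention, in which the predictor is \(\hat{s}[n] = \sum_k a_k s[n-k]\), are our own assumptions and may differ from diffsptk's conventions):

def lpc_autocorrelation(x: np.ndarray, order: int):
    # biased autocorrelation r[0..order]
    r = np.array([np.dot(x[: len(x) - k], x[k:]) for k in range(order + 1)])
    a = np.zeros(order)
    err = r[0]
    for i in range(1, order + 1):
        # reflection coefficient for the current order
        k = (r[i] - np.dot(a[: i - 1], r[i - 1 : 0 : -1])) / err
        a_prev = a.copy()
        a[i - 1] = k
        a[: i - 1] = a_prev[: i - 1] - k * a_prev[: i - 1][::-1]
        err *= 1.0 - k * k
    # predictor coefficients and the residual gain
    return a, np.sqrt(err)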

lpc_order = 18
frame_length = 1024

lpc = diffsptk.LPC(frame_length, lpc_order)
gain, coeffs = lpc(target).split([1, lpc_order], dim=-1)
print(f"Gain: {gain.item()}")
Gain: 0.23642899096012115
freq_response = (
    gain
    / torch.fft.rfft(torch.cat([coeffs.new_ones(1), coeffs]), n=frame_length)
    / frame_length
)

fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t("LPC Coefficients", [coeffs.numpy()], scatter=True, axhline=True, x_label="LPC order")
plt.ylim(-2, 2)
plt.subplot(1, 2, 2)
plot_f(
    ys=[target.numpy()],
    ys_labels=["target signal"],
    paired_ys=[
        (
            np.arange(frame_length // 2 + 1) / frame_length * sr,
            freq_response.numpy(),
        )
    ],
    paired_ys_labels=["filter response"],
    sr=sr,
)
plt.show()

We can get the source (or residual) \(e[n]\) by inverse filtering the signal with the LPC coefficients, which is equivalent to filtering the signal with the FIR filter \([1, -a_1, -a_2, \dots, -a_M]\).

\[ e[n] = s[n] - \sum_{k=1}^M a_k s[n-k] \]
e = (
    target
    + F.conv1d(
        F.pad(target[None, None, :-1], (lpc_order, 0)), coeffs.flip(0)[None, None, :]
    ).squeeze()
)
e = e / gain
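As a quick sanity check (our addition, using the lfilter function imported from torchaudio above), running the scaled residual back through the corresponding all-pole filter should reconstruct the segment up to numerical error, since the classic decomposition is lossless:

a = torch.cat([coeffs.new_ones(1), coeffs])          # denominator [1, a_1, ..., a_M]
b = torch.cat([gain, coeffs.new_zeros(lpc_order)])   # numerator: just the gain
reconstructed = lfilter(e, a, b, clamp=False)
print(torch.allclose(reconstructed, target, atol=1e-4))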

After cancelling the spectral envelope, the frequency response of the residual becomes much flatter, with roughly equal energy across the spectrum. This is a consequence of the least squares optimisation, which assumes that the prediction error is white noise.
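One way to quantify this flattening (our addition, not in the original notebook) is the spectral flatness, the ratio of the geometric mean to the arithmetic mean of the power spectrum; values close to 1 indicate a white, flat spectrum, so the residual should score noticeably higher than the raw segment.

def spectral_flatness(x: torch.Tensor) -> float:
    power = torch.fft.rfft(x).abs() ** 2 + 1e-12
    return (power.log().mean().exp() / power.mean()).item()

print(f"target flatness:   {spectral_flatness(target):.3f}")
print(f"residual flatness: {spectral_flatness(e):.3f}")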

fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t("Residual", [e.numpy()])
plt.subplot(1, 2, 2)
plot_f([e.numpy()], sr=sr)
plt.show()

Decomposing Speech with Differentiable LPC and a Glottal Flow Model#

In the above example, we made very few assumptions about the source \(e[n]\); we only assumed that it is white-noise-like. In the next example, we’re going to incorporate a glottal flow model to place more constraints on the source.

The model we’re going to use is the transformed-LF {cite}`fant1995lf` model, which models the periodic vibration of the vocal folds. Specifically, we’re using the time derivative of the glottal flow, which folds lip radiation into the glottal flow by assuming that lip radiation acts as a first-order differentiator. This model has only one parameter \(R_d\), which is strongly correlated with the perceived vocal effort. Although the model is differentiable, for computational efficiency we’re going to approximate it with a pre-computed lookup table.

def transformed_lf(Rd: torch.Tensor, points: int = 1024):
    # the implementation is adapted from https://github.com/dsuedholt/vocal-tract-grad/blob/main/glottis.py
    # Ra, Rk, and Rg are called R parameters in glottal flow modeling
    # We can infer the values of Ra, Rk, and Rg from Rd
    Rd = torch.as_tensor(Rd).view(-1, 1)
    Ra = -0.01 + 0.048 * Rd
    Rk = 0.224 + 0.118 * Rd
    Rg = (Rk / 4) * (0.5 + 1.2 * Rk) / (0.11 * Rd - Ra * (0.5 + 1.2 * Rk))

    # convert R parameters to Ta, Tp, and Te
    # Ta: The return phase duration
    # Tp: Time of the maximum of the pulse
    # Te: Time of the minimum of the time-derivative of the pulse
    Ta = Ra
    Tp = 1 / (2 * Rg)
    Te = Tp + Tp * Rk

    epsilon = 1 / Ta
    shift = torch.exp(-epsilon * (1 - Te))
    delta = 1 - shift

    rhs_integral = (1 / epsilon) * (shift - 1) + (1 - Te) * shift
    rhs_integral /= delta

    lower_integral = -(Te - Tp) / 2 + rhs_integral
    upper_integral = -lower_integral

    omega = torch.pi / Tp
    s = torch.sin(omega * Te)
    y = -torch.pi * s * upper_integral / (Tp * 2)
    z = torch.log(y)
    alpha = z / (Tp / 2 - Te)
    EO = -1 / (s * torch.exp(alpha * Te))

    t = torch.linspace(0, 1, points + 1)[None, :-1]
    before = EO * torch.exp(alpha * t) * torch.sin(omega * t)
    after = (-torch.exp(-epsilon * (t - Te)) + shift) / delta
    return torch.where(t < Te, before, after).squeeze()
t = torch.linspace(0, 1, 1024)
plt.plot(t, transformed_lf(0.3).numpy(), label="Rd = 0.3")
plt.plot(t, transformed_lf(0.5).numpy(), label="Rd = 0.5")
plt.plot(t, transformed_lf(0.8).numpy(), label="Rd = 0.8")
plt.plot(t, transformed_lf(2.7).numpy(), label="Rd = 2.7")
plt.title("Transformed LF")
plt.legend()
plt.xlabel("T (period)")
plt.ylabel("Amplitude")
plt.show()
# 0.3 <= Rd <= 2.7 is a reasonable range for Rd
# we sampled them logarithmically for better resolution at lower values
table = transformed_lf(torch.exp(torch.linspace(math.log(0.3), math.log(2.7), 100)))

# align the peaks of the transformed LF for better optimisation
peaks = table.argmin(dim=-1)
shifts = peaks.max() - peaks
aligned_table = torch.stack(
    [torch.roll(table[i], shifts[i].item(), 0) for i in range(table.shape[0])]
)

plt.title("Transformed LF wavetables")
plt.imshow(aligned_table, aspect="auto", origin="lower")
plt.xlabel("T (samples)")
plt.ylabel("Table index")
plt.colorbar()
plt.show()

The full model we’re going to use is:

\[ s[n] = g \cdot w\left((\frac{n f_0}{f_s} + \phi) \mod 1; R_d \right) + \sum_{k=1}^M a_k s[n-k]. \]

We replace source \(e[n]\) with the following parameters: gain \(g\), fundamental frequency \(f_0\), phase offset \(\phi\), and \(R_d\). \(w\) is the pre-computed glottal flow model, and \(f_s\) is the sampling rate. Let’s define this model in code.
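First, a bare-bones sketch (our simplification, not the SourceFilter class that follows) of the core wavetable idea: a phase advancing at \(f_0\) is wrapped to \([0, 1)\) and used to linearly interpolate a single stored period; the class additionally interpolates between tables along the \(R_d\) axis.

def wavetable_lookup(table: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
    # phase in [0, 1); table holds one period of the waveform
    idx = phase * table.shape[-1]
    lo = idx.long()
    hi = (lo + 1) % table.shape[-1]  # wrap around at the end of the period
    frac = idx - lo
    return table[lo] * (1 - frac) + table[hi] * frac

# e.g. a 130 Hz glottal pulse train over 1024 samples at the file's sampling rate
phase = (torch.arange(1024) / sr * 130.0) % 1
pulse_train = wavetable_lookup(transformed_lf(1.0), phase)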

class SourceFilter(torch.nn.Module):
    def __init__(
        self,
        lpc_order: int,
        sr: int,
        table_points=1024,
        num_tables=100,
        init_f0: float = 100.0,
        init_offset: float = 0.0,
        init_log_gain: float = 0.0,
    ):
        super().__init__()

        Rd_sampled = torch.exp(torch.linspace(math.log(0.3), math.log(2.7), num_tables))
        table = transformed_lf(Rd_sampled, points=table_points)
        peaks = table.argmin(dim=-1)
        shifts = peaks.max() - peaks
        aligned_table = torch.stack(
            [torch.roll(table[i], shifts[i].item(), 0) for i in range(table.shape[0])]
        )
        self.register_buffer("table", aligned_table)
        self.register_buffer("Rd_sampled", Rd_sampled)

        self.f0 = torch.nn.Parameter(torch.tensor(init_f0))
        self.offset = torch.nn.Parameter(torch.tensor(init_offset))
        self.Rd_index_logits = torch.nn.Parameter(torch.zeros(1))
        self.log_gain = torch.nn.Parameter(torch.tensor(init_log_gain))

        # we use the reflection coefficients parameterisation for stable optimisation
        self.log_area_ratios = torch.nn.Parameter(torch.zeros(lpc_order))
        self.logits2lpc = torch.nn.Sequential(
            diffsptk.LogAreaRatioToParcorCoefficients(lpc_order),
            diffsptk.ParcorCoefficientsToLinearPredictiveCoefficients(lpc_order),
        )

        self.lpc_order = lpc_order
        self.table_points = table_points
        self.num_tables = num_tables
        self.sr = sr

    @property
    def Rd_index(self):
        return torch.sigmoid(self.Rd_index_logits) * (self.num_tables - 1)

    @property
    def Rd(self):
        return self.Rd_sampled[torch.round(self.Rd_index).long().item()]

    @property
    def gain(self):
        return torch.exp(self.log_gain)

    @property
    def filter_coeffs(self):
        return self.logits2lpc(
            torch.cat([self.log_gain.view(1), self.log_area_ratios])
        ).split([1, self.lpc_order])

    def source(self, steps):
        """
        Generate the glottal pulse source signal
        """

        # select the wavetable using linear interpolation
        select_index_floor = self.Rd_index.long().item()
        p = self.Rd_index - select_index_floor
        selected_table = (
            self.table[select_index_floor] * (1 - p)
            + self.table[select_index_floor + 1] * p
        )

        # generate the source signal by interpolating the wavetable
        phase = (
            torch.arange(
                steps, device=selected_table.device, dtype=selected_table.dtype
            )
            / self.sr
            * self.f0
            + self.offset
        ) % 1
        phase_index = phase * self.table_points
        # append the first sample to the end for easier interpolation
        padded_table = torch.cat([selected_table, selected_table[:1]])
        phase_index_floor = phase_index.long()
        phase_index_ceil = phase_index_floor + 1
        p = phase_index - phase_index_floor
        glottal_pulse = (
            padded_table[phase_index_floor] * (1 - p)
            + padded_table[phase_index_ceil] * p
        )
        return glottal_pulse

    def forward_filt(self, e):
        """
        Apply the LPC filter to the input signal
        """
        # get filter coefficients
        log_gain, lpc_coeffs = self.filter_coeffs

        # IIR filtering
        b = log_gain.new_zeros(1 + lpc_coeffs.shape[-1])
        b[0] = torch.exp(log_gain)
        a = torch.cat([lpc_coeffs.new_ones(1), lpc_coeffs])
        return lfilter(e, a, b, clamp=False)

    def forward(self, steps):
        """
        Generate the speech signal
        """
        return self.forward_filt(self.source(steps))

    def inverse_filt(self, s):
        """
        Inverse filtering
        """
        # get filter coefficients
        _, lpc_coeffs = self.filter_coeffs

        e = (
            s
            + F.conv1d(
                F.pad(s[None, None, :-1], (self.lpc_order, 0)),
                lpc_coeffs.flip(0)[None, None, :],
            ).squeeze()
        )
        e = e / self.gain
        return e

Proper initialisation of the parameters plays an important role in the optimisation. We’re going to use the following initialisation.

model = SourceFilter(lpc_order, sr, init_f0=130.0, init_offset=0.0, init_log_gain=-1.3)
print(f"Gain: {model.gain.item()}")
print(f"Rd: {model.Rd.item()}")
print(f"f0: {model.f0.item()}")
print(f"Offset: {model.offset.item() % 1}")
Gain: 0.27253180742263794
Rd: 0.9100430011749268
f0: 130.0
Offset: 0.0
with torch.no_grad():
    output = model(1024)


fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t("Initial prediction", [output.numpy(), target.numpy()], labels=["predict (initial)", "target"])
plt.subplot(1, 2, 2)
plot_f(
    ys=[output.numpy(), target.numpy()],
    ys_labels=["predict (initial)", "target"],
    sr=sr,
)
plt.show()

Let’s optimise the parameters with gradient descent. We’re going to use the famous Adam optimiser with a learning rate of 0.001 and run it for 2000 iterations. The loss function we’re going to use is the L1 loss between the original signal and the modelled signal.

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

losses = []
for _ in range(2000):
    optimizer.zero_grad()
    output = model(1024)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

plt.plot(losses)
plt.xlabel("Iterations")
plt.ylabel("Loss")
plt.show()
with torch.no_grad():
    final_output = model(1024)

print(f"Gain: {model.gain.item()}")
print(f"Rd: {model.Rd.item()}")
print(f"f0: {model.f0.item()}")
print(f"Offset: {model.offset.item() % 1}")

fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t("Final prediction", [final_output.numpy(), target.numpy()], labels=["predict (optimised)", "target"])
plt.subplot(1, 2, 2)
plot_f(
    ys=[final_output.numpy(), target.numpy()],
    ys_labels=["predict (optimised)", "target"],
    sr=sr,
)
plt.show()
Gain: 0.18439224362373352
Rd: 1.5502203702926636
f0: 131.0269775390625
Offset: 0.9482236728072166

Wow, this is pretty good! We can see that the model reconstructs the original signal quite well, with very similar waveforms. Moreover, the model tells us the optimal parameters for constructing the source signal. Let’s see what the source signal looks like.

with torch.no_grad():
    e = model.source(1024)
    s = model.forward_filt(e)

fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t("Waveform", [e.numpy() / 4, s.numpy()], labels=["e[n]", "s[n]"])
plt.subplot(1, 2, 2)
plot_f(
    ys=[e.numpy() / 4, s.numpy()],
    ys_labels=["e[n]", "s[n]"],
    sr=sr,
)
plt.show()

Let’s compare the spectrum of the two filters.

_, lpc_coeffs = model.filter_coeffs
with torch.no_grad():
    freq_response_opt = (
        model.gain
        / torch.fft.rfft(
            torch.cat([lpc_coeffs.new_ones(1), lpc_coeffs]), n=frame_length
        )
        / frame_length
    )


fig = plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plot_t(
    "LPC Coefficients",
    [coeffs.numpy(), lpc_coeffs.detach().numpy()],
    labels=["least squares LPC", "differentiable LPC"],
    scatter=True,
    axhline=True,
    x_label="LPC order",
)
plt.ylim(-2, 2)
plt.subplot(1, 2, 2)
freqs = np.arange(frame_length // 2 + 1) / frame_length * sr
plot_f(
    paired_ys=[
        (
            freqs,
            freq_response.numpy(),
        ),
        (
            freqs,
            freq_response_opt.numpy(),
        ),
    ],
    paired_ys_labels=["least squares LPC", "differentiable LPC"],
    sr=sr,
)
plt.show()

Conclusion#

Interestingly, the two filters look very different. The main reason is that we restricted the source signal to specific shapes. The gradient method also cannot achieve a lossless decomposition, while the classic LPC method can. However, the source signal we get from the gradient method is much more interpretable. In fact, the latter method is a simplified version of the synthesiser used in the GOLF vocoder proposed by Yu and Fazekas (2023).

References#

Yu, Chin-Yun, and György Fazekas. 2023. “Singing Voice Synthesis Using Differentiable LPC and Glottal-Flow-Inspired Wavetables.” arXiv. https://doi.org/10.48550/arXiv.2306.17252.

Makhoul, John. 1975. “Linear Prediction: A Tutorial Review.” Proceedings of the IEEE 63 (4): 561–580.