SPIS Workshop
Audio Processing from a Machine Learning Perspective
Differentiable Digital Signal Processing (DDSP)
Anders R. Bargum (PhD Student), Cumhur Erkut, Monday 7th of April, 2025
Welcome to the main notebook of the workshop “Audio Processing from a Machine Learning Perspective”, created for the “Signal Processing for Interactive Systems” course at Aalborg University Copenhagen. During this workshop we will cover how machine learning principles can be applied to audio, specifically to audio effects and audio processing. Some of you may be familiar with “neural networks”, which are models that transform an input through successive layers of functions, additions and multiplications in order to predict a specific target. This workshop is NOT a walkthrough of neural networks. Instead, we will look at machine learning principles from a signal-based perspective, known as “Differentiable Digital Signal Processing” (DDSP).

This notebook will cover the following:
Introduction: What is DDSP and how can we use it?
PyTorch and differentiability: A quick recap
Toy problem: A differentiable gain control
Loss functions: L1, MSE, Spectral Loss
Other use-cases: Filter design, waveshaping etc.
Optimizing physical models: Applicable to SMC physical modelling class
Much of this notebook is based on the workshop “Introduction to DDSP for Audio Synthesis” by Ben Hayes, Jordie Shier, Chin-Yun Yu, David Südholt, Rodrigo Diaz (https://intro2ddsp.github.io/intro.html#).
I additionally refer you to the following work for more information:
DDSP, Differentiable Digital Signal Processing: https://magenta.tensorflow.org/ddsp (2019)
Kuznetsov et al.: Differentiable IIR Filters for Machine Learning Applications (2020)
Hayes et al.: A Review of Differentiable Digital Signal Processing for Music & Speech Synthesis (2023)
Steinmetz et al.: Style Transfer of Audio Effects with Differentiable Signal Processing https://csteinmetz1.github.io/DeepAFx-ST/ (2022)
Bargum et al.: Differentiable Allpass Filters for Phase Response Estimation and Automatic Signal Alignment (2023)
Let’s start by downloading the helper script and sound files (when running in Google Colab), and then installing the required packages:
try:
    import google.colab  # only importable inside Google Colab
    IN_COLAB = True
    !wget https://raw.githubusercontent.com/SMC-AAU-CPH/SPIS/refs/heads/main/08-Workshop/utils.py
    !mkdir -p sound-files
    !wget https://raw.githubusercontent.com/SMC-AAU-CPH/SPIS/refs/heads/main/08-Workshop/sound-files/guitar.wav -P sound-files
    !wget https://raw.githubusercontent.com/SMC-AAU-CPH/SPIS/refs/heads/main/08-Workshop/sound-files/guitar-nsynth.wav -P sound-files
except ImportError:
    IN_COLAB = False
!pip install ipython
!pip install torch
!pip install numpy
!pip install matplotlib
!pip install librosa
!pip install torchaudio
Import packages
from utils import plot_graph, get_sine, plot_tf, DIIRDataSet, DIIR_WRAPPER
import IPython.display as ipd
import torch
import numpy as np
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
from matplotlib import pyplot as plt
from torch.nn import Module, Parameter
from torch import FloatTensor
from numpy.random import uniform
from torch.utils.data import DataLoader
import json
import librosa
import torchaudio
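try:
    import google.colab  # only importable inside Google Colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
# Pick the fastest available device (CUDA GPU, Apple-silicon GPU, or CPU)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)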
Using device: cuda
Introduction: Differentiable Digital Signal Processing (DDSP)
From a mathematical perspective, a differentiable function is a function whose derivative exists at every point in its domain, i.e. we can take the derivative of \(f(x)\) no matter what value \(x\) takes. When we differentiate, we find the rate of change of \(f(x)\) with respect to its input \(x\) (the rise in \(y\) relative to the rise in \(x\)). Put differently, the derivative at a point \(x\) is the slope of the tangent line to \(f\) at that point. Where the slope of the tangent is 0 we have a critical point, which can be a minimum, a maximum, or a saddle point. In many machine learning problems we want to find the minimum of a function.
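The idea of a derivative as “the slope of the tangent” can be checked numerically. The following is a minimal sketch in plain Python. It approximates the derivative with a central finite difference and compares the result to the analytical derivative of \(f(x) = x^2\). The helper name numerical_derivative and the step size h are our own choices, not part of the workshop code.

```python
def numerical_derivative(f, x, h=1e-5):
    """Approximate df/dx at x with the central difference (f(x+h) - f(x-h)) / (2h)."""
    return (f(x + h) - f(x - h)) / (2 * h)

def f(x):
    return x ** 2  # example function; its analytical derivative is 2x

for x in [-1.0, 0.0, 2.0]:
    slope = numerical_derivative(f, x)
    print(f"x = {x:+.1f}: numerical slope = {slope:+.4f}, analytical 2x = {2 * x:+.4f}")
```

At \(x = 0\) the slope is 0. This is the critical point of \(x^2\), and here it is a minimum. Automatic differentiation frameworks do not use finite differences. They compute exact derivatives from the graph of operations, as shown further down in the notebook.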

In a machine learning context, I like to think of “differentiable” as “updatable”.
In signal processing terms, this means we can take a signal processor with adjustable parameters (a function) and estimate the parameter values that produce a given behaviour. We do this by implementing the processor and automatically updating its parameters until a loss function reaches a minimum, which is only possible if the signal processing function itself is differentiable. With this in mind, we can implement a signal processor, or a chain of signal processors, and automatically update its internal parameters using an automatic differentiation framework such as TensorFlow, PyTorch, or JAX. The framework does all the tedious differentiation work for you.
Let’s look at an example:
A simple Linear Gain Effect
A linear gain effect is a simple and clear example of how differentiable signal processing works. It is a differentiable system (we can take the derivative of \(f(x)\)) with an obvious parameter (the gain factor \(g\)). Normally, it would be very easy for us as developers to simply change the value of \(g\) until we get the desired gain.
But what if the gain were part of a larger system? Then it would not be as easy. Or what if we wanted to create a specific frequency response using a filter with coefficients a1, a2, a3, b1, b2, b3? Tuning those coefficients by hand would also be difficult.
In the case of the linear gain effect, we can find the gain value \(g\) using differentiability.
We need:
An input, which could be anything from a sine wave to a complex instrument signal
An output (the target): the same input signal processed by the system we want to approximate (let’s say at half the amplitude of the input)
The system we want to approximate, implemented so that it can be differentiated (using nn.Module in PyTorch)
A loss function that measures how far our prediction is from the target (this is the function we want to minimize)
Gradients of the loss with respect to the parameters, telling us in which direction (and how strongly) to change each parameter to reduce the loss
Following this pipeline, we can then iteratively update the gain parameter \(g\) until the output of the gain system matches the target.
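Before handing the work over to PyTorch, it can help to carry out the five ingredients by hand for the gain example. This is a minimal NumPy sketch under our own assumptions: a 200 Hz sine as input, a half-amplitude target, and a mean squared error (MSE) loss. We chose MSE because its gradient is easy to derive by hand, \(\partial L/\partial g = \text{mean}(2(gx - y)x)\). Later in the notebook, PyTorch computes this gradient automatically.

```python
import numpy as np

# Hypothetical toy data: one second of a 200 Hz sine at 16 kHz (input) and the same sine at half amplitude (target)
sr = 16000
t = np.arange(sr) / sr
x = np.sin(2 * np.pi * 200 * t)
y_target = 0.5 * x

g = 1.0    # initial guess for the gain parameter
lr = 0.5   # learning rate (step size)
for step in range(100):
    y_pred = g * x                                  # 1. run the system (forward pass)
    loss = np.mean((y_pred - y_target) ** 2)        # 2. compare the prediction with the target (MSE loss)
    grad_g = np.mean(2 * (y_pred - y_target) * x)   # 3. gradient dLoss/dg, derived by hand
    g = g - lr * grad_g                             # 4. move g against the gradient

print(f"estimated gain: {g:.4f}, loss before the last update: {loss:.2e}")
```

Each iteration follows the list above: forward pass, loss, gradient, update. The only part that automatic differentiation replaces is step 3.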

PyTorch and Differentiability
To implement the above system we can use PyTorch. PyTorch provides many utilities for neural networks and deep learning, but at its very core it consists of two main features: GPU-accelerated linear algebra operations and automatic differentiation.
Let’s quickly recap the basics of PyTorch.
Tensors
Tensors are the basic data structure in PyTorch. They are similar to numpy arrays, but offer support for the two main features mentioned above.
# Create a scalar (0D tensor)
scalar = torch.tensor(42, dtype=torch.float32) # explicitly enforce float type
print("a =", scalar)
# Create a vector (1D tensor)
vector = torch.tensor([1., 2., 3.])
print("v =", vector)
# Create a matrix (2D tensor)
matrix = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
print("M =", matrix)
a = tensor(42.)
v = tensor([1., 2., 3.])
M = tensor([[1., 2., 3.],
        [4., 5., 6.]])
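This notebook uses NumPy (for example librosa and mod_tanh) alongside PyTorch, so converting between arrays and tensors comes up often. The sketch below shows the two common routes. Note that torch.from_numpy shares memory with the array and keeps its dtype, while torch.tensor makes a copy. The variable names here are ours.

```python
import numpy as np
import torch

arr = np.linspace(-1.0, 1.0, 5)               # NumPy array (float64 by default)
t = torch.from_numpy(arr)                     # shares memory with arr and keeps dtype float64
t32 = torch.tensor(arr, dtype=torch.float32)  # copies the data and converts to float32

arr[0] = 10.0          # modify the original array in place
print(t[0].item())     # the change is visible through the shared memory
print(t32[0].item())   # the copy still holds the old value

back = t32.numpy()     # tensor -> NumPy; works for CPU tensors that do not require grad
print(type(back), back.dtype)
```

For a tensor that requires gradients, call .detach() before .numpy(). For a GPU tensor, also call .cpu() first.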
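Basic arithmetic operations are applied element-wise, with broadcasting when the shapes differ. Matrix multiplication uses torch.matmul or the @ operator: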
print("v + v =", vector + vector) #addition of two vectors
print("v * v =", vector * vector) #multiplication of two vectors (element-wise)
print("M * v =", matrix * vector) #broadcasting (expanding dimensions automatically)
print("M x v =", torch.matmul(matrix, vector)) #matrix-vector product: (2x3) @ (3,) -> (2,)
print("M @ v =", matrix @ vector) #the same matrix-vector product using the @ operator
print("v^2 =", vector ** 2) #take the power of 2 element-wise
print("exp(v) =", torch.exp(vector)) #take the exponent element-wise
print("sin(v) =", torch.sin(vector)) #take the sine element-wise
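The difference between broadcasting (M * v) and matrix multiplication (M @ v) is easiest to see from the resulting shapes and values. The following small sketch reuses the same M and v as above:

```python
import torch

vector = torch.tensor([1., 2., 3.])
matrix = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

elementwise = matrix * vector  # v is broadcast to shape (2, 3); each row is scaled element-wise
product = matrix @ vector      # (2, 3) @ (3,) -> (2,); each entry is the dot product of a row with v

print(elementwise.shape, elementwise)  # shape (2, 3): [[1, 4, 9], [4, 10, 18]]
print(product.shape, product)          # shape (2,): [14, 32]
```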
Gradients and Auto-Differentiation
As in the linear gain example, we are interested in the parameter value that minimizes the difference between the system output and the target (also called the loss or the error). In essence, the loss is an objective function that we try to minimize. To do this we can use automatic differentiation. If we know the gradient of a function with respect to its inputs (the slope of the tangent in each input direction), then adjusting the inputs in the opposite direction of the gradient will decrease the value of the function and move it towards a minimum. This is called gradient descent optimization.
To let PyTorch know that we want to compute the gradient with respect to a certain tensor, we need to set the requires_grad flag. Let’s take a look at what happens when we do this and perform operations on the tensor to calculate y = g * x.
#We initialise the input at an arbitrary value
x = torch.tensor(0.1, requires_grad=True)
#We initialise g at an arbitrary value
g = torch.tensor(0.8, requires_grad=True)
#We define the function we want to optimise
y = g * x
#We see that every tensor now carries an attribute grad_fn that describes how to compute the gradient of the operation that it resulted from.
print("g:", g)
print("x:", x)
print("y:", y)
# In the backward pass, this graph is used to compute the gradient of the final output with respect to the initial inputs.
# The backward pass can be triggered by calling the backward() method on the final output.
y.backward()
print("dy/dx evaluated at g=0.8:", x.grad.numpy())
print("\nThis makes sense as the derivative of our function g * x with respect to x equals g")
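Because g was also created with requires_grad=True, the same backward() call also fills in dy/dg. In the gain example, this gradient (with respect to the parameter g rather than the input x) is the one that drives the update. A minimal sketch:

```python
import torch

x = torch.tensor(0.1, requires_grad=True)
g = torch.tensor(0.8, requires_grad=True)
y = g * x
y.backward()  # one backward pass fills .grad for every leaf tensor with requires_grad=True

print("dy/dx =", x.grad.item())  # equals g (0.8, up to float32 rounding)
print("dy/dg =", g.grad.item())  # equals x (0.1, up to float32 rounding): the gradient used to update g
```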
Let’s look at a slightly more “complicated” example, with the function:
\(z = \sin\left((w + 1)^3\right)\)
We can split the function into the sub-computations \(x = w + 1\), \(y = x^3\) and \(z = \sin(y)\). Using the chain rule from calculus, the gradient of the final output with respect to the input is the product of the gradients of the sub-computations:
\(\dfrac{dz}{dw} = \dfrac{dz}{dy} \cdot \dfrac{dy}{dx} \cdot \dfrac{dx}{dw}\)
Or, we can do it directly on the function.
#sub computations
w = torch.tensor(1., requires_grad=True)
x = w + 1
y = x ** 3
z = torch.sin(y)
z.backward()
print("dz/dw evaluated at w=1:", w.grad)
#directly
w = torch.tensor(1., requires_grad=True)
z = torch.sin(torch.pow((w+1),3))
z.backward()
print("dz/dw evaluated at w=1:", w.grad)
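We can check the autograd result against the chain rule worked out by hand. With \(x = w + 1\) and \(y = x^3\), we have \(\dfrac{dz}{dw} = \cos\left((w+1)^3\right) \cdot 3(w+1)^2 \cdot 1\). At \(w = 1\) this gives \(12\cos(8) \approx -1.746\). A short sketch comparing the two:

```python
import math
import torch

w = torch.tensor(1., requires_grad=True)
z = torch.sin((w + 1) ** 3)
z.backward()

# Chain rule by hand: dz/dy = cos(y), dy/dx = 3x^2, dx/dw = 1, with x = w + 1 and y = x^3
w_val = 1.0
analytic = math.cos((w_val + 1) ** 3) * 3 * (w_val + 1) ** 2 * 1

print("autograd:", w.grad.item())
print("by hand: ", analytic)
```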
This is the essence of auto-differentiation. In the forward pass, PyTorch builds a computational graph of operations that know how to compute their gradients locally (as in the sub-computations). In the backward pass, this graph is used to compute the gradient of the final output with respect to the initial inputs.
Optimizers
Now that we know how to compute gradients, we can use them to find the parameters that minimize some objective function. In the most basic version of gradient descent, we update our estimate of parameters of a function according to the following rule:
\(x \leftarrow x - \lambda \nabla_x f(x)\)
Where \(x\) is the parameter we want to update, \(\nabla_x f(x)\) is the gradient of the function \(f\) with respect to \(x\), and \(\lambda\) is a small scalar defining how much we update the parameter based on the gradient. It is usually small, because we do not want to overshoot the minimum. \(\lambda\) is also called the “learning rate”.
This update rule is the basis of gradient descent, and each individual update is called a “step”. PyTorch implements it in torch.optim.SGD (Stochastic Gradient Descent, so called because in deep learning the gradient is usually estimated from random mini-batches of data), where optim.step() performs one update.
As seen earlier, the gradient points in the direction of steepest ascent (increase), so we subtract it from \(x\) to move in the direction of steepest descent (decrease), i.e. towards a minimum of \(f\). In our audio examples, the loss reaches its minimum of 0 when the processed input exactly matches the target.
Let’s look at a simple example of finding the minimum of a function.
We use the function \(f(x) = x^2 - 4x + 2x - 1\)
Analytically we can find the minimum of \(f(x)\) by taking its derivative, setting it to zero and solving for x:
\(\frac{df}{dx} = 2x - 4 + 2\)
\(0 = 2x - 4 + 2 \rightarrow x = 1\)
Let’s try to find the same minimum using gradient descent and automatic differentiation. More specifically, we iteratively update \(x\) based on the gradient of \(f\) with respect to \(x\). As a default we choose 500 iterations. You can also experiment with different learning rates and see how they affect the updated parameter.
# Initial estimate
x = torch.tensor(0., requires_grad=True)
# Initialize optimizer with parameters to be optimized and learning rate "lambda"
optim = torch.optim.SGD([x], lr=0.01)
# Number of iterations
n_iter = 500
#create list to track value of x
xs = []
#ys = []
# Gradient descent loop
for i in range(n_iter):
    # Logging
    xs.append(x.item())
    f = x**2 - 4*x + 2*x - 1 # forward pass
    f.backward() # backward pass
    #ys.append(f.item())
    optim.step() # perform the gradient descent step
    optim.zero_grad() # reset the gradients
# Plot how the estimate for x converged
plot_graph('Iteration', 'Estimate of x', [-0.5, 2.0], xs)
#plot_graph('Iteration', 'Estimate of x', [-3.5, 2.0], ys)
We clearly see that, through differentiation and automatic updates, we converge to the value of x that minimizes our function \(f(x)\).
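With the default settings of torch.optim.SGD (no momentum, no weight decay), optim.step() applies exactly the update rule \(x \leftarrow x - \lambda \nabla_x f(x)\). The sketch below does the same update by hand so the optimizer is no longer a black box. It is meant as an illustration, not a replacement for torch.optim.

```python
import torch

x = torch.tensor(0., requires_grad=True)
lr = 0.01  # the learning rate lambda

for i in range(500):
    f = x**2 - 4*x + 2*x - 1  # forward pass
    f.backward()              # x.grad now holds df/dx = 2x - 2
    with torch.no_grad():     # perform the update outside the autograd graph
        x -= lr * x.grad      # x <- x - lambda * df/dx
    x.grad.zero_()            # same role as optim.zero_grad()

print(x.item())  # close to the analytical minimum x = 1
```

Each step multiplies the distance to the minimum by \(1 - 2\lambda = 0.98\). This is why a larger learning rate converges faster here, up to the point where it starts to overshoot.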
Now we will implement this in an actual example. Note that from now on we compare the processed output with a target using a specific loss function. This means that we are minimizing the loss function, not the DSP function itself. However, the DSP function still needs to be differentiable, because the gradient of the loss has to flow back through it to its parameters when we call backward().
All loss functions provided in PyTorch are differentiable too (almost everywhere; the L1 loss, for example, is not differentiable where the error is exactly 0).
Let’s take these components and combine them into a differentiable linear gain model that can predict the gain value for a specific input/output pair.
A simple Differentiable Linear Gain Effect
We will now take a pair of dry and wet audio, where the wet audio was produced by applying a gain factor to the dry audio. We will use gradient descent to estimate the gain that was applied to the dry signal to obtain the processed one. We subclass PyTorch’s nn.Module and use nn.Parameter to define the behaviour and the trainable parameters of our gain control.
We start by creating our train and target signals.
The gain value we want to predict is 0.2
#Create input and target
sr = 16000
freq = 200
target_gain = 0.2
# generate a 200 Hz sine wave at full amplitude (input) and at the target gain (target)
input_audio = get_sine(1.0, freq, sr)
target_audio = get_sine(target_gain, freq, sr)
plot_graph('Sample', 'Amplitude', [-1.2, 1.2], input_audio[:200], target_audio[:200], ["Original", "Target"])
ipd.display(ipd.Audio(input_audio, rate=sr, normalize=False))
ipd.display(ipd.Audio(target_audio, rate=sr, normalize=False))
We define our model by subclassing nn.Module and implementing its forward() method, and then train it. Calling the module runs forward(). Because the gain is stored as an nn.Parameter (which has requires_grad=True), PyTorch records the operations in forward(), so that we can later call backward() to compute the gradients used for gradient descent.
#create linear gain function
class LinearGain(torch.nn.Module):
    def __init__(self, gain=1.0):
        super().__init__()
        self.gain = torch.nn.Parameter(torch.tensor(gain))

    def forward(self, x):
        return self.gain * x
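Assigning an nn.Parameter as a module attribute registers it automatically. That registration is what lets diff_gain.parameters() hand the gain over to the optimizer in the next cell. The following sketch inspects the registered parameters of the LinearGain class defined above (repeated here so the snippet runs on its own):

```python
import torch

class LinearGain(torch.nn.Module):
    def __init__(self, gain=1.0):
        super().__init__()
        # Wrapping the tensor in nn.Parameter registers it with the module and sets requires_grad=True
        self.gain = torch.nn.Parameter(torch.tensor(gain))

    def forward(self, x):
        return self.gain * x

model = LinearGain()
for name, param in model.named_parameters():
    print(name, param.item(), param.requires_grad)

out = model(torch.ones(4))      # calling the module runs forward()
print(out.grad_fn is not None)  # the output is connected to the autograd graph
```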
We initialise the LinearGain class, a loss function to compare the processed input with the target, and the SGD optimiser, and train for 300 iterations. Note that we call optim.zero_grad() in every iteration. PyTorch accumulates gradients across backward() calls, so the gradients must be reset before each parameter update.
diff_gain = LinearGain() # initialise module
l1_loss = torch.nn.L1Loss() # mean absolute error (MAE) between the elements of the prediction and the target
optim = torch.optim.SGD(diff_gain.parameters(), lr=0.01) #initialise optimizer
n_iter = 300
gains = []
losses = []
for i in range(n_iter):
    #logging
    gains.append(diff_gain.gain.item())
    optim.zero_grad()
    estim_audio = diff_gain(input_audio) # forward pass
    loss = l1_loss(estim_audio, target_audio)
    losses.append(loss.item())
    loss.backward() #calculate gradients based on loss
    optim.step() #update the parameter
# Plot how the loss evolved during training
plot_graph('Iteration', 'Loss', [-0.2, 0.8], losses)
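Why is optim.zero_grad() needed in every iteration? Calling backward() adds the new gradients to whatever is already stored in .grad instead of replacing it. The small sketch below (toy values of our own choosing) makes this accumulation visible:

```python
import torch

g = torch.nn.Parameter(torch.tensor(1.0))
x = torch.tensor(2.0)
grads = []

(g * x).backward()
grads.append(g.grad.item())  # 2.0: d(g*x)/dg = x

(g * x).backward()           # a second backward pass without resetting ...
grads.append(g.grad.item())  # 4.0: ... adds to the stored gradient

g.grad = None                # what optimizer.zero_grad() does by default in PyTorch 2.x
(g * x).backward()
grads.append(g.grad.item())  # 2.0 again

print(grads)
```

Without the reset, each update in the training loop would use the sum of all previous gradients, and the gain would overshoot badly.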
Let’s look at how the processed signal changes over the iterations compared to the target.
# Animate the fitting process
def get_gain_loss_animation(org, tgt, gains, losses):
    fig, ax = plt.subplots(1, 2, figsize=(10, 4))
    # Plot target and estimate
    ax[0].plot(tgt[:200])
    line, = ax[0].plot([], [])
    ax[0].set_xlabel("Time (samples)")
    ax[0].set_ylabel("Amplitude")
    ax[0].set_ylim(-1, 1)
    ax[0].legend(["Target", "Estimate"], loc=1)
    # Plot losses animation
    ax[1].set_xlim(0, len(losses))
    ax[1].set_ylim(min(losses), max(losses))
    line_loss, = ax[1].plot([], [], lw=2)
    ax[1].set_xlabel("Iteration")
    ax[1].set_ylabel("Loss")

    def init():
        line.set_data([], [])
        line_loss.set_data([], [])
        return line, line_loss,

    def animate(i):
        # Update estimate plot
        line.set_data(np.arange(200), org[:200] * gains[i * 5])
        ax[0].set_title(f"Estimated signal after {i * 5} iterations")
        ax[1].set_title(f"Loss after {i * 5} iterations")
        # Update losses plot
        line_loss.set_data(np.arange((i*5)+1), losses[:(i*5)+1])
        return line, line_loss,

    # Create the animation
    anim = FuncAnimation(fig, animate, init_func=init, frames=len(gains) // 5, interval=50, blit=True)
    plt.close(fig)
    return anim

display(HTML(get_gain_loss_animation(input_audio, target_audio, gains, losses).to_html5_video()))
We see that, over time, the model learns the gain value that produces the target. We also see that the loss decreases and drops to (nearly) 0 around the iteration where the amplitudes of the two sines align. Here the loss is the mean absolute difference between the processed input and the target.
However, the problem we are solving is very easy: it is linear, the input and target are aligned in time, and there is only one parameter.
So let’s look at something a bit more complex.
A simple Differentiable Waveshaper
A waveshaper shapes a sound by adding harmonics to it. This often results in a warm or harsh character, familiar from saturation and overdrive effects (especially prominent in guitar pedals). If we want to model the characteristics of analog distortion, and tube distortion in particular, we can use a modified tanh() function, as it allows us to shape the positive and negative halves of the input differently. The modified tanh function is given by:
\(\begin{aligned}\tanh_{\mathrm{mod}}(x) = \frac{e^{x(a+G)} - e^{x(b-G)}}{e^{xG} + e^{-xG}}\end{aligned}\)
Here the distortion amount \(G\) defines the overall shape/drive, whereas \(a\) and \(b\) are small offsets added to the positive and negative parts of the input signal respectively.
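Two properties of the equation can be checked numerically: with \(a = b = 0\) it reduces to an ordinary \(\tanh(Gx)\), and nonzero offsets make the curve asymmetric, which adds even harmonics that a plain tanh cannot produce. The following is a minimal numpy sketch, assuming a 300 Hz sine at 48 kHz lasting exactly 30 periods so that every harmonic falls on an exact FFT bin; `harmonic_level` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def mod_tanh(x, a, b, g):
    # modified tanh: (e^{x(a+G)} - e^{x(b-G)}) / (e^{xG} + e^{-xG})
    numerator = np.exp(x * (a + g)) - np.exp(x * (b - g))
    denominator = np.exp(x * g) + np.exp(-x * g)
    return numerator / denominator

sr, freq = 48000, 300
n = sr // 10                                   # 0.1 s = exactly 30 periods of 300 Hz
x = np.sin(2 * np.pi * freq * np.arange(n) / sr)

# with a = b = 0 the function reduces to tanh(G x)
sym = mod_tanh(x, 0.0, 0.0, 4.5)
print(np.allclose(sym, np.tanh(4.5 * x)))     # True

def harmonic_level(y, k):
    # hypothetical helper: level of the k-th harmonic (bin 30 is the 300 Hz fundamental)
    spectrum = np.abs(np.fft.rfft(y)) / len(y)
    return spectrum[30 * k]

# the offsets used for the target below make the curve asymmetric
asym = mod_tanh(x, 0.6, 0.4, 4.5)
for name, y in [("symmetric", sym), ("asymmetric", asym)]:
    print(name, [float(round(harmonic_level(y, k), 4)) for k in (1, 2, 3)])
```

The symmetric curve only produces odd harmonics (the second-harmonic level is numerically zero), while the asymmetric one also produces a clear second harmonic.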
Let's create an input signal and a target signal using the modified tanh function with fixed parameters. In practice, the target could also be the output of an analog tube distortion effect whose inner workings you do not know; using DDSP, you could try to model it.
def mod_tanh(x, a, b, g):
    numerator = np.exp(x*(a+g)) - np.exp(x*(b-g))
    denominator = np.exp(x*g) + np.exp(x*(-g))
    return numerator/denominator
#Create input and target
sr = 48000
freq = 300
target_a = 0.6
target_b = 0.4
target_g = 4.5
# generate half a second of sine wave at 300 Hz
input_audio = get_sine(1.0, freq, sr)
target_audio = mod_tanh(input_audio, target_a, target_b, target_g)
plot_graph('Sample', 'Amplitude', [-1.2, 3.2], input_audio[:200], target_audio[:200], ["Original", "Target"])
ipd.display(ipd.Audio(input_audio, rate=sr, normalize=True))
ipd.display(ipd.Audio(target_audio, rate=sr, normalize=True))
guitar, _ = librosa.load('sound-files/guitar.wav', sr=sr, mono=True)
guitar_dist = mod_tanh(guitar, target_a, target_b, target_g)
ipd.display(ipd.Audio(guitar, rate=sr, normalize=True))
ipd.display(ipd.Audio(guitar_dist, rate=sr, normalize=True))
We implement it in PyTorch so that its parameters can be learned through automatic differentiation:
class Modified_Tanh(torch.nn.Module):
    def __init__(self, a=0.0, b=0.0, g=0.0):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(a))
        self.b = torch.nn.Parameter(torch.tensor(b))
        self.g = torch.nn.Parameter(torch.tensor(g))
    def forward(self, x):
        numerator = torch.exp(x*(self.a+self.g)) - torch.exp(x*(self.b-self.g))
        denominator = torch.exp(x*self.g) + torch.exp(x*(-self.g))
        return numerator/denominator
Again, we train using the SGD optimizer to find the waveshaping parameters that were applied to the target.
diff_effect = Modified_Tanh() # initialise module
l1_loss = torch.nn.L1Loss() # measures the mean absolute error (MAE) between each element in the input x and target
optim = torch.optim.SGD(diff_effect.parameters(), lr=0.01) #initialise optimizer
n_iter = 5000
a = []
b = []
g = []
losses = []
for i in range(n_iter):
    #logging
    a.append(diff_effect.a.item())
    b.append(diff_effect.b.item())
    g.append(diff_effect.g.item())
    optim.zero_grad()
    estim_audio = diff_effect(input_audio) # forward pass
    loss = l1_loss(estim_audio, target_audio)
    losses.append(loss.item())
    loss.backward() #calculate gradients based on loss
    optim.step() #update the parameters
# Plot how the loss converged
plot_graph('Iteration', 'Loss', [-0.2, 0.8], losses)
We can now track the parameters to see how they change over time as we update them through gradient descent.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
# Animate the fitting process
def get_loss_animation(losses_list):
    num_iterations = len(losses_list[0])
    num_plots = len(losses_list)
    fig, ax = plt.subplots(figsize=(6, 3))
    lines_loss = []
    annotations = []
    for i in range(num_plots):
        if i == 0:
            label = "a"
        elif i == 1:
            label = "b"
        else:
            label = "g"
        line_loss, = ax.plot([], [], lw=2, label=label)
        lines_loss.append(line_loss)
        annotations.append(ax.text(0.95, 0.9-i*0.1, '', transform=ax.transAxes, ha='right', va='center'))
    ax.set_xlim(0, num_iterations)
    ax.set_ylim(0, max([max(losses) for losses in losses_list]))
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Param Val")
    ax.legend(loc='upper left')
    def init():
        for line, annotation in zip(lines_loss, annotations):
            line.set_data([], [])
            annotation.set_text('')
        return lines_loss + annotations
    def animate(iter_idx):
        for i, (line_loss, annotation) in enumerate(zip(lines_loss, annotations)):
            if i == 0:
                label = "a"
                target = target_a
            elif i == 1:
                label = "b"
                target = target_b
            else:
                label = "g"
                target = target_g
            line_loss.set_data(np.arange((iter_idx*30)+1), losses_list[i][:(iter_idx*30)+1])
            annotation.set_text(f"{label}: {losses_list[i][iter_idx*30]:.2f} - target: {target}")
            annotation.set_position((0.95, 0.9-i*0.1))
        ax.set_title(f"Parameters after {iter_idx * 30} iterations")
        return lines_loss + annotations
    # Create the animation
    anim = FuncAnimation(fig, animate, init_func=init, frames=num_iterations // 30, interval=50, blit=True)
    plt.close(fig)
    return anim
coeffs = [a, b, g]
# Example usage:
display(HTML(get_loss_animation(coeffs).to_html5_video()))
We are pretty close! Let's listen to the result on a guitar recording.
diff_effect.eval()
input_to_process = torch.tensor(guitar)
with torch.no_grad():
    processed = diff_effect(input_to_process)
processed = processed.reshape(-1).cpu().numpy()
print("Original")
ipd.display(ipd.Audio(guitar, rate=sr, normalize=True))
print("Target")
ipd.display(ipd.Audio(guitar_dist, rate=sr, normalize=True))
print("Predicted")
ipd.display(ipd.Audio(processed, rate=sr, normalize=True))
The art of choosing the right Loss function
There exist many different loss functions, each operating in a different way. Until now we have only used the L1 loss (also called the mean absolute error, MAE), which measures the average absolute difference between our output and the target:
\(\begin{aligned}L1=\frac{1}{n}\sum_{i=1}^n\left|y_{\text{true},i}-y_{\text{predicted},i}\right|\end{aligned}\)
Many other loss functions exist, each with its own behaviour. Below we see different loss functions plotted in 2D. As we see, the L1/MAE loss is not differentiable at its minimum.
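The practical difference between L1 and MSE shows up in their slopes. The sketch below, a minimal numpy example using finite differences on a single sample (the helper `slope` and the chosen values are illustrative assumptions), shows that the L1 slope keeps the same size however close we are to the target, whereas the MSE slope shrinks with the error, and that the one-sided L1 slopes disagree exactly at the minimum.

```python
import numpy as np

def l1(pred, target):
    # mean absolute error
    return np.mean(np.abs(pred - target))

def mse(pred, target):
    # mean squared error
    return np.mean((pred - target) ** 2)

def slope(loss, pred, target, h=1e-6):
    # central finite-difference estimate of d(loss)/d(pred) for a one-sample prediction
    return (loss(np.array([pred + h]), target) - loss(np.array([pred - h]), target)) / (2 * h)

target = np.array([0.5])
slopes = {}
for pred in (1.5, 0.6, 0.51):
    slopes[pred] = (slope(l1, pred, target), slope(mse, pred, target))
    print(f"error = {pred - 0.5:.2f}: dL1 = {slopes[pred][0]:.3f}, dMSE = {slopes[pred][1]:.3f}")

# at the minimum the one-sided slopes of the L1 loss are -1 and +1: no derivative exists there
h = 1e-6
left = (l1(np.array([0.5]), target) - l1(np.array([0.5 - h]), target)) / h
right = (l1(np.array([0.5 + h]), target) - l1(np.array([0.5]), target)) / h
print(f"L1 slope left of the minimum: {left:.3f}, right of it: {right:.3f}")
```

In practice this means that plain gradient descent on L1 keeps taking steps of similar size near the optimum, which is one reason why the learning rate matters.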

Until now we've made the tasks a bit easy for ourselves. What if the signal we're trying to match differs from the processed signal in more ways than just shape or gain? What happens if we phase-shift the target signal by 90 degrees, i.e., turn the sine into a cosine?
Using the L1 loss, as we have done until now, will most likely run into trouble. Although a constant phase shift of a steady sine changes nothing about how we perceive the sound, the loss compares the signals sample by sample and therefore penalises the misalignment. The loss will thus no longer deliver gradients that point us in the correct direction.
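We can make this concrete by evaluating the L1 loss over a small grid of candidate gains. The sketch below is a minimal numpy version under the assumption that the signals are built directly with `np.sin`/`np.cos` (half a second at 16 kHz, 150 full periods) rather than with the notebook's `get_sine` helper. With an in-phase target the loss is lowest at the true gain of 0.15; with the phase-shifted target it is lowest at a gain of 0, so gradient descent would shrink the gain instead of finding 0.15.

```python
import numpy as np

sr, freq, true_gain = 16000, 300, 0.15
t = np.arange(sr // 2) / sr                                  # half a second: 150 periods of 300 Hz
x = np.sin(2 * np.pi * freq * t)                             # input sine
target_in_phase = true_gain * np.sin(2 * np.pi * freq * t)
target_shifted = true_gain * np.cos(2 * np.pi * freq * t)    # 90 degree phase shift

gains = np.array([0.0, 0.05, 0.10, 0.15, 0.20, 0.25])

def l1_curve(target):
    # L1 loss between gain * input and the target, for every candidate gain
    return np.array([np.mean(np.abs(g * x - target)) for g in gains])

best_in_phase = gains[np.argmin(l1_curve(target_in_phase))]
best_shifted = gains[np.argmin(l1_curve(target_shifted))]
print(best_in_phase, best_shifted)                           # 0.15 0.0
```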
sr = 16000
freq = 300
true_gain = 0.15
# generate half a second of sine wave at 300 Hz
input_audio = get_sine(1.0, freq, sr)
target_sine = torch.cos(torch.linspace(0, 2 * torch.pi * freq, sr // 2))
target_audio = true_gain * target_sine
plot_graph('Sample', 'Amplitude', [-1.2, 1.2], input_audio[:200], target_audio[:200], ["Original", "Target"])
The choice of loss function can play a crucial role when optimizing the parameters of audio controls. In this case, we want the loss to be invariant to phase shifts. For a gain parameter, we can for example compute the magnitude spectrum of each signal with the FFT and compare the magnitudes of the frequency bins, discarding the phase.
We can write a custom loss module to do just that:
class SpectralLoss(torch.nn.Module):
    def __init__(self, power=1):
        super().__init__()
        self.power = power
    def forward(self, x, y):
        x_mags = torch.fft.rfft(x).abs() ** self.power
        y_mags = torch.fft.rfft(y).abs() ** self.power
        return torch.nn.functional.l1_loss(x_mags, y_mags)
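To see why comparing magnitudes helps, here is a numpy mirror of `SpectralLoss` (the name `spectral_l1` is illustrative, and the signals are again assumed to be built directly with `np.sin`/`np.cos` so that the 300 Hz tone lands exactly on one FFT bin). For the cosine target, the spectral loss is minimised at the true gain of 0.15, while the sample-by-sample L1 loss is minimised at 0.

```python
import numpy as np

sr, freq, true_gain = 16000, 300, 0.15
t = np.arange(sr // 2) / sr
x = np.sin(2 * np.pi * freq * t)
target = true_gain * np.cos(2 * np.pi * freq * t)            # phase-shifted target

def spectral_l1(pred, tgt, power=1):
    # L1 distance between |rfft|^power of both signals; the phase is discarded
    pred_mags = np.abs(np.fft.rfft(pred)) ** power
    tgt_mags = np.abs(np.fft.rfft(tgt)) ** power
    return np.mean(np.abs(pred_mags - tgt_mags))

def waveform_l1(pred, tgt):
    # sample-by-sample L1 loss, as used before
    return np.mean(np.abs(pred - tgt))

gains = np.array([0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30])
best_spec = gains[np.argmin([spectral_l1(g * x, target) for g in gains])]
best_wave = gains[np.argmin([waveform_l1(g * x, target) for g in gains])]
print(best_spec, best_wave)                                  # 0.15 0.0
```

Discarding the phase is exactly what makes this loss suitable here, but it also means it cannot be used when the phase or timing of the output matters.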
def get_gain_animation(org, tgt, gains):
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.plot(tgt[:200])
    line, = ax.plot([], [])
    ax.set_xlabel("Time (samples)")
    ax.set_ylabel("Amplitude")
    ax.set_ylim(-1, 1)
    ax.legend(["Target", "Estimate"], loc=1)
    def init():
        line.set_data([], [])
        return line,
    def animate(i):
        line.set_data(np.arange(200), org[:200] * gains[i * 5])
        ax.set_title(f"Estimated signal after {i * 5} iterations")
        return line,
    # Create the animation
    anim = FuncAnimation(fig, animate, init_func=init, frames=len(gains) // 5, interval=50, blit=True)
    plt.close(fig)
    return anim
model = LinearGain()
spectral_loss = SpectralLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
n_iter = 300
gains = []
for i in range(n_iter):
    gains.append(model.gain.item())
    optim.zero_grad()
    estim_audio = model(input_audio)
    loss = spectral_loss(estim_audio, target_audio)
    loss.backward()
    optim.step()
display(HTML(get_gain_animation(input_audio, target_audio, gains).to_html5_video()))
More loss functions can be found in the PyTorch documentation https://pytorch.org/docs/stable/nn.html#loss-functions, while perceptual loss functions like spectral losses can be found in https://github.com/csteinmetz1/auraloss
A simple Differentiable IIR Filter Effect
A system with an infinite impulse response (IIR) is called an IIR filter. Here, each output sample depends on the current and past input samples as well as on past output samples, each scaled by a coefficient. It is also called a recursive system, because each output sample is computed recursively from previous output samples.
The z-domain transfer function of a general second-order IIR filter with 2 poles and 2 zeros (the zeros are the roots of the numerator and the poles are the roots of the denominator of the transfer function) is given by:
\(\begin{aligned}H(z) = \frac{b_0 + b_1z^{-1} + b_2z^{-2}}{1 + a_1z^{-1} + a_2z^{-2}}\end{aligned}\)
Since this transfer function is the ratio of two quadratic functions, it is commonly referred to as a biquad filter, which is used for many musical purposes.
As before, we can train (automatically update) the coefficients of this filter so that it matches a specific frequency response. We use the Transposed Direct Form II (TDF-II) to derive the difference equations from the transfer function above:
\(y[n] = b_0x[n] + h_1[n-1]\)
\(h_1[n] = b_1x[n] - a_1y[n] + h_2[n-1]\)
\(h_2[n] = b_2x[n] - a_2y[n]\)
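The three difference equations translate line by line into a sample loop. The following is a minimal numpy sketch (not the PyTorch module used below; the function names and the example coefficients are illustrative assumptions) that checks the TDF-II recursion against the direct-form equation \(y[n] = b_0x[n] + b_1x[n-1] + b_2x[n-2] - a_1y[n-1] - a_2y[n-2]\), which follows from the same transfer function.

```python
import numpy as np

def biquad_tdf2(x, b, a):
    """Filter x with a biquad in Transposed Direct Form II (a0 assumed to be 1).

    b = (b0, b1, b2) and a = (a1, a2) are the coefficients of H(z).
    """
    b0, b1, b2 = b
    a1, a2 = a
    h1, h2 = 0.0, 0.0                      # the two state variables, initially zero
    y = np.zeros(len(x))
    for n, xn in enumerate(x):
        y[n] = b0 * xn + h1                # y[n]  = b0 x[n] + h1[n-1]
        h1 = b1 * xn - a1 * y[n] + h2      # h1[n] = b1 x[n] - a1 y[n] + h2[n-1]
        h2 = b2 * xn - a2 * y[n]           # h2[n] = b2 x[n] - a2 y[n]
    return y

def biquad_direct_form(x, b, a):
    """Reference: y[n] = b0 x[n] + b1 x[n-1] + b2 x[n-2] - a1 y[n-1] - a2 y[n-2]."""
    y = np.zeros(len(x))
    for n in range(len(x)):
        y[n] = b[0] * x[n]
        if n >= 1:
            y[n] += b[1] * x[n - 1] - a[0] * y[n - 1]
        if n >= 2:
            y[n] += b[2] * x[n - 2] - a[1] * y[n - 2]
    return y

x = np.random.default_rng(0).standard_normal(1000)
b, a = (0.2, 0.4, 0.2), (-0.3, 0.1)        # arbitrary, stable example coefficients
print(np.allclose(biquad_tdf2(x, b, a), biquad_direct_form(x, b, a)))   # True
```

Note that each output sample needs the states from the previous sample, so the loop cannot be vectorised over time; this is what makes long signals slow to process.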
We again implement these difference equations in PyTorch by inheriting from nn.Module.
Notice how we store the two state variables h1 and h2 as the columns of a single matrix for convenience.
class DTDFII(Module):
    def __init__(self):
        super(DTDFII, self).__init__()
        self.b0 = Parameter(FloatTensor([uniform(-1, 1)]))
        self.b1 = Parameter(FloatTensor([uniform(-1, 1)]))
        self.b2 = Parameter(FloatTensor([uniform(-1, 1)]))
        self.a1 = Parameter(FloatTensor([uniform(-0.5, 0.5)]))
        self.a2 = Parameter(FloatTensor([uniform(-0.5, 0.5)]))
    def _cat(self, vectors):
        return torch.cat([v_.unsqueeze(-1) for v_ in vectors], dim=-1)
    def forward(self, input, h):
        output = input * self.b0 + h[:, 0]
        h1 = input * self.b1 + h[:, 1] - output * self.a1
        h2 = input * self.b2 - output * self.a2
        h = self._cat([h1, h2])
        return output, h
    def init_states(self, size):
        h = torch.zeros(size, 2).to(next(self.parameters()).device)
        return h
We define our input as a logarithmic chirp (sine sweep) going from 1 Hz to 20 kHz in 10 seconds, with a little Gaussian noise added. The target is the same signal filtered by a second-order Butterworth high-pass filter with its cutoff at 18 kHz, designed with scipy. By iteratively comparing the processed input to the actual filtered output, we try to adjust our differentiable IIR filter to match the frequency response of the original filter.
from scipy import signal
import numpy as np
fs = 48000
sec = 10
T = int(fs * sec)
start_freq = 1
end_freq = 20000
t = np.linspace(0, sec, sec*fs)
train_input = signal.chirp(t=t, f0=start_freq, t1=sec, f1=end_freq, method='logarithmic') + np.random.normal(scale=5e-2, size=len(t))
fc = 18000 #Hz
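# pass fs so that Wn is given in Hz; without fs, scipy interprets Wn relative to the Nyquist frequency (fs/2)
b, a = signal.butter(N=2, Wn=fc, btype='high', fs=fs)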
print("The filter has the following coefficients:")
print("Coeffs b:", b, ", coeffs a:", a)
sos = signal.tf2sos(b, a)
train_target = signal.sosfilt(sos, train_input)
impulse = np.zeros(1000)
impulse[0] = 1.0
imp_filter = signal.sosfilt(sos, impulse)
plot_tf("Filtered chirp signal", fs, np.arange(T) / fs, [train_target], [imp_filter])
ipd.display(ipd.Audio(train_input, rate=fs, normalize=True))
ipd.display(ipd.Audio(train_target, rate=fs, normalize=True))
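A detail worth checking in the filter design is how `scipy.signal.butter` interprets `Wn`. Without the `fs` argument, `Wn` is a fraction of the Nyquist frequency, so `Wn=fc/fs` actually requests a cutoff of half of `fc`. The sketch below assumes a scipy version that accepts the `fs` keyword (1.2 or newer) and uses `scipy.signal.freqz` to confirm that the designed filter is 3 dB down at its nominal cutoff.

```python
import numpy as np
from scipy import signal

fs = 48000
fc = 18000

# Without fs, Wn is relative to Nyquist (fs / 2): Wn = fc / fs means a 9 kHz cutoff here
b_frac, a_frac = signal.butter(N=2, Wn=fc / fs, btype='high')
b_9k, a_9k = signal.butter(N=2, Wn=9000, btype='high', fs=fs)
same_filter = np.allclose(b_frac, b_9k) and np.allclose(a_frac, a_9k)
print(same_filter)                                  # True

# With fs, Wn is given directly in Hz
b_18k, a_18k = signal.butter(N=2, Wn=fc, btype='high', fs=fs)

# A Butterworth filter is about 3 dB down at its cutoff frequency
_, h = signal.freqz(b_18k, a_18k, worN=[fc], fs=fs)
gain_at_cutoff_db = 20 * np.log10(np.abs(h[0]))
print(round(gain_at_cutoff_db, 2))
```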
The training signal above is 10 seconds long. At 48 kHz, that is 480000 samples (a lot of data!). Because the filter is recursive, the forward pass has to step through the signal one sample at a time, so running it over all 480000 samples for every gradient calculation, and thus for every parameter update, is slow. To make this more efficient, we split our signals into sequences and group them into batches, so that the filter operations are applied in parallel to a tensor of shape (batch_size, sequence_length, audio_channels).
How this is done is not important here; we use a utility function that you can check out in the utils.py script.
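The idea behind the batching can be sketched in a few lines of numpy. The helper `split_into_sequences` below is hypothetical and simply zero-pads the signal and reshapes it, mirroring the padding used later when the trained model is evaluated; the notebook's `DIIRDataSet` may handle the incomplete last sequence differently (for example by dropping it, which would give 937 sequences instead of 938).

```python
import numpy as np

def split_into_sequences(x, sequence_length):
    """Hypothetical helper: zero-pad x and reshape it to (n_sequences, sequence_length, 1)."""
    n_sequences = int(np.ceil(len(x) / sequence_length))
    padding = n_sequences * sequence_length - len(x)
    x = np.pad(x, (0, padding))                    # zeros at the end
    return x.reshape(n_sequences, sequence_length, 1)

fs, sec, sequence_length = 48000, 10, 512
x = np.random.default_rng(0).standard_normal(fs * sec)
seqs = split_into_sequences(x, sequence_length)
print(seqs.shape)                                  # (938, 512, 1)
```

The recursion now runs for 512 time steps over 938 sequences at once instead of 480000 sequential steps. The trade-off is that each sequence typically starts from fresh (zero) filter states, so the filter's memory does not carry over between sequences.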
batch_size = 1024
sequence_length = 512
loader = DataLoader(dataset=DIIRDataSet(train_input, train_target, sequence_length), batch_size=batch_size, shuffle=True, drop_last=False)
print("Batch dim:", next(iter(loader))['input'].size())
print("Sequences available in dataset:", int(len(train_input)/sequence_length))
print("Batches available in dataset:", int(np.ceil(int(len(train_input)/sequence_length) / batch_size)))
Above we see that the 480000-sample training signal yields 937 sequences of 512 samples. With a batch_size of 1024, all of them fit into a single batch; with a batch_size of 128, for example, we would instead get 8 batches: 7 consisting of 128 sequences and 1 consisting of 41 sequences.
We define our model, loss function and optimizer (this time we use Adam rather than SGD).
from torch.optim import Adam
n_epochs = 5000
filter_function = DTDFII()
model = DIIR_WRAPPER(filter_function).to(device)
optimizer = Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
criterion = torch.nn.MSELoss()
Define training loop
We define a training loop. Here we loop through each batch, calculate the loss, update the parameters and return the average loss over all batches for visualisation purposes.
def train(criterion, model, loader, optimizer):
    model.train()
    device = next(model.parameters()).device
    total_loss = 0
    for batch in loader:
        input_seq_batch = batch['input'].to(device)
        target_seq_batch = batch['target'].to(device)
        optimizer.zero_grad()
        predicted_output = model(input_seq_batch)
        loss = criterion(predicted_output, target_seq_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    total_loss /= len(loader)
    return total_loss
Train!
BE AWARE - ON A CPU, TRAINING MAY TAKE SEVERAL HOURS (TOOK 45 min ON A GPU)
losses = []
b0, b1, b2 = [], [], []
a1, a2 = [], []
for epoch in range(n_epochs):
    loss = train(criterion, model, loader, optimizer)
    losses.append(loss)
    b0.append(model.cell.b0.item())
    b1.append(model.cell.b1.item())
    b2.append(model.cell.b2.item())
    a1.append(model.cell.a1.item())
    a2.append(model.cell.a2.item())
    if epoch %200 == 0:
        print("Epoch {} -- Loss {:3E}".format(epoch, loss))
coeffs = [b0, b1, b2, a1, a2]
plt.figure(figsize=(6, 3))
plt.plot(losses)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.show()
Let's look at how each coefficient changes over the epochs.
We see that they move towards a configuration that minimises our objective function; this is the magic of DDSP.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
# Animate the fitting process
def get_loss_animation(losses_list):
    num_iterations = len(losses_list[0])
    num_plots = len(losses_list)
    fig, ax = plt.subplots(figsize=(6, 3))
    lines_loss = []
    for i in range(num_plots):
        # coeffs is ordered [b0, b1, b2, a1, a2]
        if i < 3:
            label = f"b{i}"
        else:
            label = f"a{i-2}"
        line_loss, = ax.plot([], [], lw=2, label=label)
        lines_loss.append(line_loss)
    ax.set_xlim(0, num_iterations)
    ax.set_ylim(min(-1, min(min(c) for c in losses_list)), max([max(losses) for losses in losses_list]))
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Coefficient value")
    ax.legend()
    def init():
        for line in lines_loss:
            line.set_data([], [])
        return lines_loss
    def animate(iter_idx):
        for i, line_loss in enumerate(lines_loss):
            line_loss.set_data(np.arange((iter_idx*30)+1), losses_list[i][:(iter_idx*30)+1])
        ax.set_title(f"Params after {iter_idx*30} epochs")
        return lines_loss
    # Create the animation
    anim = FuncAnimation(fig, animate, init_func=init, frames=num_iterations // 30, interval=50, blit=True)
    plt.close(fig)
    return anim
# Example usage:
display(HTML(get_loss_animation(coeffs).to_html5_video()))
Let's also check whether we have approached the target frequency response.
model.eval()
impulse = np.zeros(sequence_length)
impulse[0] = 1.0
impulse = torch.tensor(impulse, dtype=torch.float32).to(device)
input_to_process = train_input
padding = int(np.ceil((len(input_to_process) / sequence_length)) * sequence_length) - len(input_to_process)
batched_input = torch.nn.functional.pad(torch.tensor(input_to_process, dtype=torch.float32), (0, padding)).reshape(-1, sequence_length, 1)
processed = torch.zeros(batched_input.shape)
with torch.no_grad():
    processed = model(batched_input.to(device))
    imp_model = model(impulse.unsqueeze(0).unsqueeze(-1))
processed = processed.reshape(-1).cpu().numpy()
imp_model = imp_model.reshape(-1).cpu().numpy()
plot_tf(
"Filtered Chirp signal",
fs,
np.arange(len(train_target)) / fs,
[train_target, processed[:len(train_target)]],
[imp_filter, imp_model],
["scipy.signal.butter", "diff_iir"]
)
We see that we are approaching the target frequency response.
TASKS for Further Experimentation
Below are two tasks.
a) First, we create another differentiable filter that can predict frequency responses more freely. Your task is to experiment with this filter: create a target signal, add the filter model to the wrapper and, lastly, try to train the filter (I recommend doing this in a notebook).
b) Second, we provide an implementation and training scheme for the wave equation in a differentiable manner. Many of you know the wave equation from other courses (e.g., through the Karplus-Strong algorithm). You will see how we can use differentiable signal processing to approximate damping coefficients and other physical parameters to obtain a target sound.
a) State Variable Filter (SVF)
The filter problem above was tailored to work. As seen in the implementation of the DTDFII module, we clamp the coefficients for stability reasons. This means that not all coefficient configurations are possible, and thus not all second-order frequency responses can be obtained. A high-pass filter with a cutoff at 18 kHz was therefore chosen deliberately, as I knew the DTDFII would be able to find the respective coefficients.
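Why do the coefficients need to be restricted at all? A biquad is stable only if both poles, the roots of \(z^2 + a_1z + a_2\), lie strictly inside the unit circle. For a second-order denominator this is equivalent to the so-called stability triangle \(|a_2| < 1\) and \(|a_1| < 1 + a_2\). The numpy sketch below (function names are illustrative) checks that both tests agree on random coefficient pairs.

```python
import numpy as np

def is_stable_by_roots(a1, a2):
    # poles are the roots of z^2 + a1 z + a2; stable if all lie strictly inside the unit circle
    return bool(np.all(np.abs(np.roots([1.0, a1, a2])) < 1.0))

def is_stable_by_triangle(a1, a2):
    # equivalent "stability triangle" test for a second-order denominator
    return abs(a2) < 1.0 and abs(a1) < 1.0 + a2

rng = np.random.default_rng(0)
coeffs = rng.uniform(-2.5, 2.5, size=(2000, 2))
agree = all(is_stable_by_roots(a1, a2) == is_stable_by_triangle(a1, a2) for a1, a2 in coeffs)
fraction_stable = np.mean([is_stable_by_triangle(a1, a2) for a1, a2 in coeffs])
print(agree, round(float(fraction_stable), 2))
```

Only a triangular region of the \((a_1, a_2)\) plane is stable, so an unconstrained gradient step can easily push the coefficients into an unstable filter.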
To take account of this, and to be able to predict 2nd order filter frequency responses more freely, we can use the State-Variable Filter (SVF). The SVF can produce any second-order transfer function, whilst still having easily interpretable parameters. Its difference equation is given by:
\(\begin{aligned} y_{\mathrm{BP}}[n] & =\frac{g\left(x[n]-h_2[n-1]\right)+h_1[n-1]}{1+g(g+2 R)} \\ y_{\mathrm{LP}}[n] & =g y_{\mathrm{BP}}[n]+h_2[n-1] \\ y_{\mathrm{HP}}[n] & =x[n]-y_{\mathrm{LP}}[n]-2 R y_{\mathrm{BP}}[n] \\ h_1[n] & =2 y_{\mathrm{BP}}[n]-h_1[n-1] \\ h_2[n] & =2 y_{\mathrm{LP}}[n]-h_2[n-1] \\ y[n] & =c_{\mathrm{HP}} y_{\mathrm{HP}}[n]+c_{\mathrm{BP}} y_{\mathrm{BP}}[n]+c_{\mathrm{LP}} y_{\mathrm{LP}}[n]\end{aligned}\)
With the parameters being:
cHP = high-pass mixing coefficient
cBP = band-pass mixing coefficient
cLP = low-pass mixing coefficient
R = damping/resonance
g = cutoff-frequency coefficient (for a cutoff \(f_c\) at sample rate \(f_s\), commonly \(g = \tan(\pi f_c / f_s)\))
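The difference equations above can be run directly in a sample loop. This numpy sketch (not the `DSVF` module used below) assumes the usual mapping \(g = \tan(\pi f_c / f_s)\) and \(R = 1/\sqrt{2}\) for a Butterworth-like response. It shows two properties that follow from the equations: a constant (DC) input ends up only in the low-pass output, and \(y_{\mathrm{HP}} + 2Ry_{\mathrm{BP}} + y_{\mathrm{LP}} = x\) holds at every sample.

```python
import numpy as np

def svf(x, fc, fs, R, c_hp=0.0, c_bp=0.0, c_lp=1.0):
    """Sample-by-sample SVF following the difference equations above.

    Assumes g = tan(pi * fc / fs). Returns the mixed output and the HP, BP and LP outputs.
    """
    g = np.tan(np.pi * fc / fs)
    h1, h2 = 0.0, 0.0
    y, hp, bp, lp = (np.zeros(len(x)) for _ in range(4))
    for n, xn in enumerate(x):
        bp[n] = (g * (xn - h2) + h1) / (1.0 + g * (g + 2.0 * R))
        lp[n] = g * bp[n] + h2
        hp[n] = xn - lp[n] - 2.0 * R * bp[n]
        h1 = 2.0 * bp[n] - h1              # h1[n] = 2 y_BP[n] - h1[n-1]
        h2 = 2.0 * lp[n] - h2              # h2[n] = 2 y_LP[n] - h2[n-1]
        y[n] = c_hp * hp[n] + c_bp * bp[n] + c_lp * lp[n]
    return y, hp, bp, lp

R = 1 / np.sqrt(2)
step = np.ones(4800)                       # a constant (DC) input, 0.1 s at 48 kHz
y, hp, bp, lp = svf(step, fc=1000, fs=48000, R=R)
print(round(lp[-1], 6), abs(round(hp[-1], 6)), abs(round(bp[-1], 6)))
```

Because the three outputs are mixed by \(c_{\mathrm{HP}}\), \(c_{\mathrm{BP}}\) and \(c_{\mathrm{LP}}\), changing only these mixing coefficients turns the same structure into a low-pass, band-pass or high-pass filter.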
We will not go into the technical details of the SVF, and you do not need to understand the math behind it. However, it is good to be aware of each parameter's role. Anyone interested in further details can read more about the SVF here: https://www.dafx14.fau.de/papers/dafx14_aaron_wishnick_time_varying_filters_for_.pdf
In the following section, you will be tasked to implement the SVF in the differentiable framework and train it to match a specific frequency response. The SVF implementation and most of the needed code are provided; you are asked to fill in the blanks.
class DSVF(Module):
    def __init__(self, G=0.5, twoR=1, hp_gain=0.0, bp_gain=0.0, lp_gain=1.0):
        args = locals()
        del args['self']
        del args['__class__']
        super(DSVF, self).__init__()
        for key in args:
            setattr(self, key, Parameter(FloatTensor([args[key]])))
        self.master_gain = Parameter(FloatTensor([1.0]))
    def forward(self, x, v):
        coeff0, coeff1 = self.calc_coeffs()
        input_minus_v1 = x - v[:, 1]
        bp_out = coeff1 * input_minus_v1 + coeff0 * v[:, 0]
        lp_out = self.G * bp_out + v[:, 1]
        hp_out = x - lp_out - self.twoR * bp_out
        v = torch.cat([(2 * bp_out).unsqueeze(-1), (2 * lp_out).unsqueeze(-1)], dim=-1) - v
        y = self.master_gain * (self.hp_gain * hp_out + self.bp_gain * self.twoR * bp_out + self.lp_gain * lp_out)
        return y, v
    def init_states(self, size):
        v = torch.zeros(size, 2).to(next(self.parameters()).device)
        return v
    def calc_coeffs(self):
        self.G.data = torch.clamp(self.G, min=1e-8)
        self.twoR.data = torch.clamp(self.twoR, min=0)
        self.bp_gain.data = torch.clamp(self.bp_gain, min=-1)
        self.hp_gain.data = torch.clamp(self.hp_gain, min=-1, max=1)
        self.lp_gain.data = torch.clamp(self.lp_gain, min=-1, max=1)
        coeff0 = 1.0 / (1.0 + self.G * (self.G + self.twoR))
        coeff1 = self.G * coeff0
        return coeff0, coeff1
Create input and target training signal
fs = 48000
sec = 2
T = int(fs * sec)
start_freq = 1
end_freq = 20000
t = np.linspace(0, sec, sec*fs)
train_input = signal.chirp(t=t, f0=start_freq, t1=sec, f1=end_freq, method='logarithmic') + np.random.normal(scale=5e-2, size=len(t))
fc = #choose filter cutoff
filter_type = #choose filter type
b, a = signal.butter(N=2, Wn=fc, btype=filter_type, fs=fs)  # with fs, Wn is the cutoff in Hz
print("The filter has the following coefficients:")
print("Coeffs b:", b, ", coeffs a:", a)
sos = signal.tf2sos(b, a)
train_target = signal.sosfilt(sos, train_input)
impulse = np.zeros(1000)
impulse[0] = 1.0
imp_filter = signal.sosfilt(sos, impulse)
plot_tf("Filtered chirp signal", fs, np.arange(T) / fs, [train_target], [imp_filter])
ipd.display(ipd.Audio(train_input, rate=fs, normalize=True))
ipd.display(ipd.Audio(train_target, rate=fs, normalize=True))
batch_size = #choose batch size
sequence_length = #choose sequence length
loader = DataLoader(dataset=DIIRDataSet(train_input, train_target, sequence_length), batch_size=batch_size, shuffle=True, drop_last=False)
print("Batch dim:", next(iter(loader))['input'].size())
print("Sequences available in dataset:", int(len(train_input)/sequence_length))
print("Batches available in dataset:", int(np.ceil(int(len(train_input)/sequence_length) / batch_size)))
n_epochs = 1500
filter_function = #initialise differentiable filter
model = #add to wrapper
optimizer = #initialise optimizer
criterion = #initialise loss
#train the model
losses = []  # start a fresh loss history for this model
for epoch in range(n_epochs):
    loss = #train
    losses.append(loss)
    if epoch %200 == 0:
        print("Epoch {} -- Loss {:3E}".format(epoch, loss))
model.eval()
impulse = np.zeros(sequence_length)
impulse[0] = 1.0
impulse = torch.tensor(impulse, dtype=torch.float32).to(device)
input_to_process = train_input
padding = int(np.ceil((len(input_to_process) / sequence_length)) * sequence_length) - len(input_to_process)
batched_input = torch.nn.functional.pad(torch.tensor(input_to_process, dtype=torch.float32), (0, padding)).reshape(-1, sequence_length, 1)
processed = torch.zeros(batched_input.shape)
with torch.no_grad():
    processed = model(batched_input.to(device))
    imp_model = model(impulse.unsqueeze(0).unsqueeze(-1))
processed = processed.reshape(-1).cpu().numpy()
imp_model = imp_model.reshape(-1).cpu().numpy()
plot_tf(
"Filtered Chirp signal",
fs,
np.arange(len(train_target)) / fs,
[train_target, processed[:len(train_target)]],
[imp_filter, imp_model],
["scipy.signal.butter", "diff_iir"]
)
b) The Wave Equation
In this section, we’ll look at physical sound synthesis and model a string sound from the wave equation. Thereafter we will use gradient descent to find the parameters of the model that best fit a given sound.
In particular, we will focus on digital waveguide synthesis (DWG). DWGs are based on d'Alembert's travelling-wave solution to the wave equation, in which the solution is given by two waves travelling in opposite directions:
\(\begin{aligned}y(x, t) = F(x + ct) + G(x - ct)\end{aligned}\)
Here \(F(x + ct)\) represents a wave travelling to the left and \(G(x - ct)\) represents a wave travelling to the right, both with speed \(c\).
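We can check numerically that any such sum of travelling waves satisfies the one-dimensional wave equation \(\partial^2 y/\partial t^2 = c^2\, \partial^2 y/\partial x^2\). The sketch below assumes two arbitrary Gaussian pulse shapes for \(F\) and \(G\) and an arbitrary wave speed, and compares central second differences in time and in space.

```python
import numpy as np

c = 2.0                                       # wave speed, arbitrary for this check
F = lambda u: np.exp(-u ** 2)                 # assumed left-travelling pulse shape
G = lambda u: 0.5 * np.exp(-(u - 1.0) ** 2)   # assumed right-travelling pulse shape

def displacement(x, t):
    # d'Alembert: sum of a left- and a right-travelling wave
    return F(x + c * t) + G(x - c * t)

def second_diff(f, h):
    # central second difference of a one-argument function around 0
    return (f(h) - 2.0 * f(0.0) + f(-h)) / h ** 2

x0, t0, h = 0.3, 0.1, 1e-3
y_tt = second_diff(lambda d: displacement(x0, t0 + d), h)
y_xx = second_diff(lambda d: displacement(x0 + d, t0), h)
print(y_tt, c ** 2 * y_xx)                    # approximately equal
```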
In DWGs, the propagation of the travelling waves is simulated using delay lines. Losses occur at every sample step, but if they are linear and time-invariant, they can be commuted out of the individual samples and applied all at once to the output of the delay line.
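The commutation argument can be illustrated with a short numpy sketch. It assumes a per-sample loss given by a slightly lossy two-point average (an illustrative choice) and compares a delay line with the loss applied at every stage to one pure delay followed by a single combined loss filter. Since every operation is a linear, time-invariant convolution, the order does not matter and both outputs are identical.

```python
import numpy as np

def apply_fir(x, taps):
    # causal FIR filtering, output truncated to the input length
    return np.convolve(x, taps)[: len(x)]

loss = np.array([0.495, 0.495])        # assumed per-sample loss: a lossy two-point average
unit_delay = np.array([0.0, 1.0])

def distributed(x, n_stages):
    # the loss is applied at every one of the n_stages delay elements
    for _ in range(n_stages):
        x = apply_fir(apply_fir(x, unit_delay), loss)
    return x

def lumped(x, n_stages):
    # the whole delay first, then all losses combined into one filter at the output
    combined_loss = np.array([1.0])
    for _ in range(n_stages):
        combined_loss = np.convolve(combined_loss, loss)
    delayed = apply_fir(x, np.concatenate((np.zeros(n_stages), [1.0])))
    return apply_fir(delayed, combined_loss)

x = np.random.default_rng(0).standard_normal(256)
print(np.allclose(distributed(x, 20), lumped(x, 20)))   # True
```

The lumped version needs only one filter instead of one per delay element, which is what makes waveguide models cheap to run.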
The loss model should be frequency-dependent. With the simplest possible loss filter, a two-point average, we obtain a simulation diagram that looks like this:

This might look familiar as the basic structure of the Karplus-Strong algorithm for plucked string synthesis. In fact, the Karplus-Strong algorithm can be seen as a simple DWG. We’ll look at applying the same methods as before to find the parameters of this model that best fit a given sound using gradient descent.
The transfer function of the basic Karplus-Strong algorithm shown above is
\(\begin{aligned}H(z) = \frac{1}{1 - \frac{g}{2}\left(z^{-N} + z^{-(N+1)}\right)}\end{aligned}\)
where \(N\) is the length of the delay line corresponding to the modeled string and controls pitch, and \(g\) is the feedback gain, which controls the decay time of the sound.
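In the time domain, this transfer function corresponds to the recursion \(y[n] = x[n] + \frac{g}{2}\left(y[n-N] + y[n-N-1]\right)\). The numpy sketch below (small, arbitrary values of \(N\) and \(g\)) computes its impulse response: after every round trip through the delay line the pulse is delayed by \(N\) samples, smeared over neighbouring samples by the two-point average and scaled by \(g\), so \(g\) sets how quickly the sound decays.

```python
import numpy as np

def karplus_strong_impulse_response(n_delay, g, n_samples):
    # y[n] = x[n] + (g / 2) * (y[n - N] + y[n - N - 1]), driven by a unit impulse
    x = np.zeros(n_samples)
    x[0] = 1.0
    y = np.zeros(n_samples)
    for n in range(n_samples):
        y[n] = x[n]
        if n >= n_delay:
            y[n] += 0.5 * g * y[n - n_delay]
        if n >= n_delay + 1:
            y[n] += 0.5 * g * y[n - n_delay - 1]
    return y

N, g = 10, 0.9
h = karplus_strong_impulse_response(N, g, 40)
print(h[[0, N, N + 1, 2 * N, 2 * N + 1, 2 * N + 2]])   # [1. 0.45 0.45 0.2025 0.405 0.2025]
```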
We’ll implement this transfer function in the frequency domain for more efficient estimation, and in the time domain for the final result.
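Sampling \(H(z)\) at \(z = e^{j\omega}\) on the rfft bin frequencies gives the same result as running the recursion and taking the FFT of the impulse response, as long as the response has decayed within the FFT length. The numpy sketch below (parameters chosen arbitrarily) checks this. This equivalence is what allows the model to be fitted in the frequency domain without a sample-by-sample loop.

```python
import numpy as np

def ks_freq_response(n_delay, g, n_fft):
    # evaluate H(z) on the unit circle at the n_fft // 2 + 1 rfft bin frequencies
    z = np.exp(1j * np.linspace(0, np.pi, n_fft // 2 + 1))
    return 1.0 / (1.0 - g * 0.5 * (z ** (-n_delay) + z ** (-n_delay - 1)))

def ks_impulse_response(n_delay, g, n_samples):
    # time-domain recursion y[n] = x[n] + (g / 2) * (y[n - N] + y[n - N - 1])
    y = np.zeros(n_samples)
    for n in range(n_samples):
        y[n] = 1.0 if n == 0 else 0.0
        if n >= n_delay:
            y[n] += 0.5 * g * y[n - n_delay]
        if n >= n_delay + 1:
            y[n] += 0.5 * g * y[n - n_delay - 1]
    return y

n_delay, g, n_fft = 20, 0.9, 4096
H_sampled = ks_freq_response(n_delay, g, n_fft)
H_from_time = np.fft.rfft(ks_impulse_response(n_delay, g, n_fft))
max_error = np.max(np.abs(H_sampled - H_from_time))
print(max_error)                      # tiny: only the truncated tail is missing
```

If \(g\) were much closer to 1, the impulse response would not have decayed within `n_fft` samples and the two results would start to differ (time aliasing).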
class KarplusStrong(torch.nn.Module):
    def __init__(self, delay_len, n_fft=2048):
        super().__init__()
        self.delay_gain = torch.nn.Parameter(torch.tensor(0.0))
        self.delay_len = delay_len
        # for frequency sampling
        self.z = torch.exp(1j * torch.linspace(0, torch.pi, n_fft // 2 + 1))
        # random excitation
        self.exc = torch.zeros(n_fft)
        self.exc[:delay_len] = torch.rand(delay_len) - 0.5
        self.exc_fft = torch.fft.rfft(self.exc)
    # scale delay gain to [0.9, 1.0]
    def scaled_gain(self):
        return torch.sigmoid(self.delay_gain) * 0.1 + 0.9
    # forward pass: synthesis in the frequency domain
    def forward(self):
        z = self.z
        delay_gain = self.scaled_gain()
        # sample transfer function
        numer = 1.
        denom = (1 - delay_gain * (0.5 * z ** (-self.delay_len) + 0.5 * z ** (-self.delay_len - 1)))
        # filter excitation in frequency domain
        return self.exc_fft * numer / denom
    # also provide method for time domain synthesis
    def time_domain_synth(self, n_samples):
        delay_gain = self.scaled_gain()
        # populate filter coefficients for IIR filter
        a_coeffs = torch.zeros(self.delay_len + 2)
        a_coeffs[0] = 1
        a_coeffs[self.delay_len] = -delay_gain * 0.5
        a_coeffs[self.delay_len + 1] = -delay_gain * 0.5
        b_coeffs = torch.zeros(self.delay_len + 2)
        b_coeffs[0] = 1
        # pad or truncate self.exc to n_samples
        if self.exc.shape[0] < n_samples:
            audio = torch.cat([self.exc, torch.zeros(n_samples - self.exc.shape[0])])
        else:
            audio = self.exc[:n_samples]
        audio = torchaudio.functional.lfilter(audio, a_coeffs, b_coeffs, clamp=False)
        return audio
# let's have a listen
synth = KarplusStrong(80)
audio = synth.time_domain_synth(32000)
ipd.Audio(audio.detach(), rate=16000)
Let's now load an acoustic guitar sound file from the NSynth dataset. We'll try to have our Karplus-Strong model mimic this sound. Since it is a very simple model, we won't get a very close match, but we should be able to tune the decay time.
As mentioned before, pitch estimation with gradient descent can be tricky, so we'll infer the length of the delay line from the pitch of the recording: MIDI note 51 corresponds to about 155.56 Hz. With a sample rate of 16000 Hz, this corresponds to a delay of about 102.85 samples, which we round to 103 samples. More accuracy could be achieved by using fractional delays, but we'll keep it simple here.
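The arithmetic can be reproduced with the standard equal-temperament formula \(f = 440 \cdot 2^{(m - 69)/12}\) Hz for MIDI note \(m\). The last line is an extra observation, not part of the original model setup: the two-point average in the loop adds about half a sample of delay, so a delay line of \(N\) samples gives a fundamental of roughly \(f_s / (N + 0.5)\).

```python
def midi_to_hz(note):
    # equal temperament with A4 = MIDI note 69 = 440 Hz
    return 440.0 * 2.0 ** ((note - 69) / 12)

sr = 16000
f0 = midi_to_hz(51)
delay = sr / f0
print(round(f0, 2), round(delay, 2), round(delay))   # 155.56 102.85 103

# the two-point average adds roughly half a sample of loop delay
print(round(sr / (103 + 0.5), 2))                     # 154.59
```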
sr = 16000
audio, sr = librosa.load("sound-files/guitar-nsynth.wav", sr=sr, mono=True)
# how many points used in sampling the transfer function
nfft = 4096
# fix random excitation
torch.manual_seed(0)
karplus_model = KarplusStrong(delay_len=103, n_fft=nfft)
print("Original:")
ipd.display(ipd.Audio(audio, rate=sr))
print("Synthesized:")
ipd.display(ipd.Audio(karplus_model.time_domain_synth(sr * 4).detach(), rate=sr))
This doesn’t sound close at all. Let’s see if we can once again use gradient descent to find a better value for \(g\) and match the decay time. We’ll define our own loss function using L1 loss on the normalized log magnitudes of the spectrum:
def to_log_mag(freq_response, rel_to_max=True, eps=1e-7):
    mag = torch.abs(freq_response)
    if rel_to_max:
        div = torch.max(mag)
    else:
        div = 1.0
    # 20 * log10 converts an amplitude (magnitude) spectrum to dB
    return 20 * torch.log10(mag / div + eps)
def loss_fn(y, y_hat):
    y_mags = to_log_mag(y)
    y_hat_mags = to_log_mag(y_hat)
    return torch.mean((y_mags - y_hat_mags).abs())
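Normalising each spectrum by its maximum and discarding the phase makes this loss ignore both the overall level and the phase of the signals, so only the spectral shape (for example, how sharp and how tall the harmonic peaks are) is compared. Here is a small numpy mirror of the two functions (the constant in front of `log10` only scales the loss), applied to a random test spectrum:

```python
import numpy as np

def to_log_mag(freq_response, rel_to_max=True, eps=1e-7):
    # magnitude in dB, optionally relative to the largest bin
    mag = np.abs(freq_response)
    div = np.max(mag) if rel_to_max else 1.0
    return 20 * np.log10(mag / div + eps)

def loss_fn(y, y_hat):
    return np.mean(np.abs(to_log_mag(y) - to_log_mag(y_hat)))

rng = np.random.default_rng(0)
spectrum = np.fft.rfft(rng.standard_normal(1024))
random_phase = np.exp(1j * rng.uniform(0, 2 * np.pi, spectrum.shape))
other = np.fft.rfft(rng.standard_normal(1024))

level_loss = loss_fn(spectrum, 0.1 * spectrum)          # ~0: overall level is ignored
phase_loss = loss_fn(spectrum, spectrum * random_phase) # ~0: phase is ignored
shape_loss = loss_fn(spectrum, other)                   # clearly > 0: different spectral shape
print(level_loss, phase_loss, shape_loss)
```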
We’re all set for optimization!
# calculate truncated fft
target = torch.fft.rfft(torch.tensor(audio), n=nfft)
fftfreqs = torch.fft.rfftfreq(nfft, 1 / sr)
plt.plot(fftfreqs, to_log_mag(target.detach()), label="target")
plt.plot(fftfreqs, to_log_mag(karplus_model().detach()), label="initial synthesis")
optim = torch.optim.Adam(karplus_model.parameters(), lr=1e-2)
for i in range(1000):
    optim.zero_grad()
    loss = loss_fn(target, karplus_model())
    loss.backward()
    optim.step()
plt.plot(fftfreqs, to_log_mag(karplus_model().detach()), label="optimized synthesis")
plt.legend()
plt.ylabel("Magnitude (dB)")
plt.xlabel("Frequency (Hz)")
plt.show()
print("Audio after optimization:")
td_out = karplus_model.time_domain_synth(audio.shape[0]).detach()
ipd.display(ipd.Audio(td_out, rate=sr))