diff --git a/dtw_minwarplag.ipynb b/dtw_minwarplag.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4cd3a1606eaf40faf586856e3f106a62af7e0346 --- /dev/null +++ b/dtw_minwarplag.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "<center>\n", + "\n", + "*******************************************************************************************\n", + " \n", + "### DYNAMIC TIME WARPING, \n", + "\n", + "### MINIMUM-WARP OPTIMAL PATH, AND POINTWISE LAG \n", + " \n", + "<br>\n", + " \n", + "##### 30 JULY 2023 \n", + "\n", + "##### Juan Ignacio Mendoza Garay \n", + "##### doctoral student \n", + "##### Department of Music, Art and Culture Studies \n", + "##### University of Jyväskylä \n", + "\n", + "*******************************************************************************************\n", + "\n", + "</center>\n", + "\n", + "#### INFORMATION:\n", + "\n", + "\n", + "* Description:\n", + "\n", + " Demonstrates the most simple (as in \"easy to understand\", not necessarily faster)\n", + " algorithm for Dynamic Time Warping (also called \"classical\" version), an algorithm \n", + " to traceback the optimal path (also \"classical\"), and novel algorithms (as far as I am aware)\n", + " to traceback the optimal path with minimum time-warping and its pointwise lag.\n", + "\n", + "* Instructions:\n", + "\n", + " Edit the values indicated with an arrow like this: <--- \n", + " Comment/uncomment or change values as suggested by the comments. \n", + " Run the program, close your eyes and hope for the best. \n", + "\n", + "*******************************************************************************************\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tslearn.metrics import dtw_path\n", + "import matplotlib.pyplot as plt\n", + "from scipy.spatial.distance import cdist\n", + "from time import time" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*******************************************************************************************\n", + "#### TEST SIGNALS:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# EXAMPLE SET 1:\n", + "if True: # <---\n", + " \n", + " x_length = 40 # <---\n", + "\n", + " x = np.arange(0,x_length)\n", + " y = np.zeros(x.size)\n", + " y_1 = y.copy()\n", + " y_2 = y.copy()\n", + " \n", + " y_1[10:20] = np.arange(0,10,1) # <---\n", + " #y_2 = y_1 # <--- same shape, no lag\n", + " y_2[15:25] = np.arange(0,10,1) # <--- same shape, lag forwards (delay)\n", + " #y_2[5:15] = np.arange(0,10,1) # <--- same shape, lag backwards (anticipation)\n", + " #y_2[12:27] = np.arange(0,15,1) # <--- different shape, lag forwards (delay)\n", + " #y_2[3:18] = np.arange(0,15,1) # <--- different shape, lag backwards (anticipation)\n", + "\n", + "# EXAMPLE SET 2:\n", + "# Uses x, y_1, and y_2 of example set 1.\n", + "if False: # <--- \n", + "\n", + " y_1[10:20] = [1,2,3,4,5,5,4,3,2,1] # <--- \n", + " y_2[15:25] = [1,2,3,4,5,5,4,3,2,1] # <--- \n", + " \n", + "# EXAMPLE SET 3:\n", + "# Uses x of example set 1.\n", + "if False: # <--- \n", + " \n", + " y_1_period = 3 # <--- \n", + " y_2_period = 2 # <--- \n", + " y_2_phase_shift = 0.5 # <--- \n", + " \n", + " y_1 = np.sin( x / y_1_period )\n", + " y_2 = np.sin( ( (x / y_2_period ) + (np.pi * y_2_phase_shift ) ))\n", + " \n", + " \n", + "# EXAMPLE SET 4:\n", + "if False: # <--- \n", + "\n", + " y_1 = np.array([8,7,6,5,5,6,7,8]) # <--- \n", + " y_2 = np.array([5,4,4,5,6,7,8,9]) # <--- \n", + " \n", + " x = np.arange(0,y_1.size)\n", + "\n", + " \n", + "y_1_rs = y_1.reshape(-1,1)\n", + "y_2_rs = y_2.reshape(-1,1)\n", + " \n", + "plt.plot(y_1) \n", + "plt.plot(y_2); # the semicolon prevents printing the output of the last command" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*******************************************************************************************\n", + "#### METHOD 1:\n", + "Using the 'tslearn' library. Simple, fast, easy. Gets the job done with no hassle.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dm_1 = cdist( y_1_rs, y_2_rs , 'cityblock' )\n", + "\n", + "tic = time()\n", + "optimal_path_1, dtw_score_1 = dtw_path(y_1, y_2)\n", + "print('computation time of DTW 1 = '+str(time()-tic))\n", + "\n", + "op_1_x = [col[0] for col in optimal_path_1]\n", + "op_1_y = [col[1] for col in optimal_path_1]\n", + "\n", + "plt.imshow(dm_1.T, origin='lower')\n", + "plt.plot(op_1_x, op_1_y, ':w',linewidth=2)\n", + "plt.title('DTW 1 = '+str(round(dtw_score_1,3)));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The figure above shows a heat-map of the distance matrix and the classical DTW optimal path." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*******************************************************************************************\n", + "#### METHOD 2:\n", + "Using hand-made, home-brewed, simple, raw and no-BS code written with love by your humble servant.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# REFERENCES: \n", + "# https://tslearn.readthedocs.io/en/stable/user_guide/dtw.html\n", + "# https://en.wikipedia.org/wiki/Dynamic_time_warping\n", + "# Müller, M. (2007). Dynamic time warping. Information retrieval for music and motion, 69-84.\n", + "\n", + "tic = time()\n", + "dm_2 = cdist( y_1_rs, y_2_rs , 'cityblock' )\n", + "\n", + "# initialisation:\n", + "C = np.empty((dm_2.shape[0]+1,dm_2.shape[1]+1))\n", + "C[:] = np.inf\n", + "C[0,0] = 0\n", + "\n", + "# main loop:\n", + "for i in range(1,C.shape[0]):\n", + " for j in range(1,C.shape[1]):\n", + " C[i,j] = dm_2[i-1,j-1] + min(C[i-1, j], C[i, j-1], C[i-1, j-1]) # goes barebones\n", + " # C[i,j] = dm_2[i-1,j-1]**2 + min(C[i-1, j], C[i, j-1], C[i-1, j-1]) # gets cocky\n", + "\n", + "dtw_score_2 = C[i,j] # barebones\n", + "#dtw_score_2 = np.sqrt(C[i,j]) # cocky\n", + "\n", + "\n", + "# traceback optimal path:\n", + "tic_1 = time()\n", + "i, j = np.array(C.shape)-2\n", + "optimal_path_2 = [[i,j]]\n", + "while (i > 0) or (j > 0):\n", + " \n", + " tb = np.argmin((C[i, j], C[i, j + 1], C[i + 1, j]))\n", + " \n", + " if tb == 0:\n", + " i -= 1\n", + " j -= 1\n", + " elif tb == 1:\n", + " i -= 1\n", + " else: # (tb == 2):\n", + " j -= 1\n", + " \n", + " optimal_path_2.append([i,j])\n", + "\n", + "print('computation time of DTW 2 = '+str(time()-tic))\n", + "print('computation time of classical optimal path = '+str(time()-tic_1))\n", + "\n", + "op_2_x = [col[0] for col in optimal_path_2]\n", + "op_2_y = [col[1] for col in optimal_path_2]\n", + "\n", + "plt.imshow(dm_2.T, origin='lower')\n", + "plt.plot(op_2_x, op_2_y, ':w',linewidth=2)\n", + "plt.title('DTW 2 = '+str(round(dtw_score_2,3)));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that the 'barebones' (default) version of method 2 returns a DTW based on absolute differences, not on the Euclidean distance as in the 'cocky' version used by the tslearn library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# traceback minimum-warp optimal path:\n", + "\n", + "tic = time()\n", + "i, j = np.array(C.shape)-2\n", + "optimal_path_3 = [[i,j]]\n", + "while (i > 0) or (j > 0):\n", + " \n", + " this_query = (C[i, j], C[i, j + 1], C[i + 1, j])\n", + " this_min = np.min(this_query)\n", + " these_argmins = np.where(this_query==this_min)\n", + "\n", + " if these_argmins[0][0] == 0:\n", + " \n", + " if (np.any(these_argmins[0][:] == 1)) and (i > j):\n", + " \n", + " i -= 1\n", + " \n", + " elif (np.any(these_argmins[0][:] == 2)) and (j > i):\n", + " \n", + " j -= 1\n", + " \n", + " else:\n", + " \n", + " i -= 1\n", + " j -= 1\n", + " \n", + " elif these_argmins[0][0] == 1:\n", + " \n", + " i -= 1\n", + " \n", + " elif these_argmins[0][0] == 2:\n", + " \n", + " j -= 1\n", + " \n", + " optimal_path_3.append([i,j])\n", + "\n", + "print('computation time of minimum-warp optimal path = '+str(time()-tic))\n", + "\n", + "op_3_x = [col[0] for col in optimal_path_3]\n", + "op_3_y = [col[1] for col in optimal_path_3]\n", + "\n", + "plt.imshow(dm_2.T, origin='lower')\n", + "plt.plot(op_2_x, op_2_y, ':w',linewidth=2)\n", + "plt.plot(op_3_x, op_3_y, ':r', linewidth=2)\n", + "plt.title('DTW 2 = '+str(round(dtw_score_2,3)));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The algorithm for the minimum-warp optimal path (in red) tries to get closer to the diagonal." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pm = np.zeros( dm_2.shape )\n", + "\n", + "for i_op in range(0,len(optimal_path_3)):\n", + " pm[op_3_x[i_op],op_3_y[i_op]] = 1 \n", + "\n", + "ax = plt.gca() \n", + "ax.imshow(pm.T, origin='lower')\n", + "ax.plot( range(0,dm_2.shape[0]) , ':w',linewidth=1);\n", + "xy_grid = np.arange(-0.5,dm_2.shape[0]+0.5)\n", + "ax.grid(1,linestyle=':')\n", + "ax.set_xticks(xy_grid)\n", + "ax.set_yticks(xy_grid)\n", + "ax.set_xticklabels('')\n", + "ax.set_yticklabels('')\n", + "ax.set_title('minimum-warp optimal path');" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# pointwise lag of minimum-warp optimal path:\n", + "\n", + "pm_r = np.fliplr(pm).T # rotate to get anti-diagonals (lag domain)\n", + "lags = np.zeros(pm_r.shape[0])\n", + "i_lags = 0\n", + "i_adiag = -pm_r.shape[0] + 1\n", + "\n", + "while i_adiag <= pm_r.shape[0]: # iterate through anti-adiagonals\n", + " \n", + " this_adiag = pm_r.diagonal(i_adiag)\n", + "\n", + " #print('this_adiag = %s'%this_adiag)\n", + " \n", + " if any(this_adiag):\n", + "\n", + " i_center = int( np.floor(this_adiag.size / 2) )\n", + "\n", + " if not this_adiag[int(i_center)] : # value where the anti-adiagonal intersects the main adiagonal\n", + "\n", + " i_match = int(np.where(this_adiag)[0]) # where the path intersects this anti-diagonal\n", + "\n", + " lags[i_lags] = i_center - i_match \n", + " \n", + " i_adiag += 2\n", + " i_lags += 1\n", + " \n", + " else: # see if there's anything in the next diagonal\n", + " \n", + " i_adiag += 1\n", + " \n", + "plt.plot(lags)\n", + "plt.title('pointwise lag');" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}