diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl
index d5c164d333ff2be18ae095b95febc8d47d2503e6..e78258cadd0974d27ec19a9c408415913027da13 100644
--- a/scripts/run_experiments.jl
+++ b/scripts/run_experiments.jl
@@ -27,6 +27,8 @@ using IterativeSolvers
 
 include("util.jl")
 isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "util.jl"))
+include("wavelet.jl")
+isdefined(Main, :Revise) && Revise.track(joinpath(@__DIR__, "wavelet.jl"))
 
 using .Util
 
@@ -452,106 +454,6 @@ function experiment_convergence_opticalflow(ctx)
         width=size(f0, 2), height=size(f0, 1))
 end
 
-function dwt!(y, x)
-    a = sqrt(1 / 2)
-    n = length(x)
-    n <= 1 && return (y .= x)
-    # use second half of y as scratch space
-    for i in 1:n÷2
-        j = 2*(i-1)+1
-        y[n÷2+i] = a * x[j] + a * x[j+1]
-    end
-    dwt!(view(y, 1:n÷2), view(y, n÷2+1:2*(n÷2)))
-    for i in 1:n÷2
-        j = 2*(i-1)+1
-        y[n÷2+i] = a * x[j] - a * x[j+1]
-    end
-    isodd(n) && (y[n] = x[n])
-    return y
-end
-
-function idwt!(y, x)
-    a = sqrt(1 / 2)
-    n = length(x)
-    n <= 1 && return (y .= x)
-    isodd(n) && (y[n] = x[n])
-    # use second half of y as scratch space
-    idwt!(view(y, n÷2+1:2*(n÷2)), view(x, 1:n÷2))
-    for i in 1:n÷2
-        j = 2*(i-1)+1
-        y[j]   = a * y[n÷2+i] + a * x[n÷2+i]
-        y[j+1] = a * y[n÷2+i] - a * x[n÷2+i]
-    end
-    return y
-end
-
-# TODO: non-allocating
-function mdwt!(y, x)
-    d = ndims(x)
-    s = size(x)
-    any(s .<= 1) && return (y .= x)
-
-    # transformation matrix
-    M = [sqrt(1/2^d)*(-1)^count_ones(i&j) for i in 0:2^d-1, j in 0:2^d-1]
-
-    offset = s.÷2
-    firsthalf = Base.OneTo.(offset)
-    for i in Iterators.product(firsthalf...)
-        j = @. 2*(i-1)+1
-
-        src_ix = ntuple(k->j[k]:1+j[k], d)
-        dst_ix = ntuple(k->i[k]:offset[k]:offset[k]+i[k], d)
-
-        src = vec(view(x, src_ix...))
-        dst = vec(view(y, dst_ix...))
-        mul!(dst, M, src)
-    end
-
-    mdwt!(view(y, firsthalf...), y[firsthalf...])
-
-    for k in 1:d
-        iseven(s[k]) && continue
-        selectdim(y, k, s[k]) .= selectdim(x, k, s[k])
-    end
-
-    return y
-end
-
-# TODO: non-allocating
-function imdwt!(y, x)
-    d = ndims(x)
-    s = size(x)
-    any(s .<= 1) && return (y .= x)
-
-    offset = s .÷ 2
-    firsthalf = Base.OneTo.(offset)
-
-    # transformation matrix
-    M = [sqrt(1 / 2^d)*(-1)^count_ones(i&j) for i in 0:2^d-1, j in 0:2^d-1]
-
-    for k in 1:d
-        iseven(s[k]) && continue
-        selectdim(y, k, s[k]) .= selectdim(x, k, s[k])
-    end
-
-    xt = copy(x)
-    imdwt!(view(xt, firsthalf...), x[firsthalf...])
-
-    for i in Iterators.product(firsthalf...)
-        j = @. 2*(i-1)+1
-
-        src_ix = ntuple(k->i[k]:offset[k]:offset[k]+i[k], d)
-        dst_ix = ntuple(k->j[k]:1+j[k], d)
-
-        src = vec(view(xt, src_ix...))
-        dst = vec(view(y, dst_ix...))
-        mul!(dst, M, src)
-    end
-
-    return y
-end
-mdwt(x) = mdwt!(similar(x), x)
-imdwt(x) = imdwt!(similar(x), x)
 
 function experiment_global_basic(ctx)
     f = loadimg(joinpath(ctx.indir, "input_original.png"))
diff --git a/scripts/wavelet.jl b/scripts/wavelet.jl
new file mode 100644
index 0000000000000000000000000000000000000000..71dbe78be34966058e1a7ac2cb2232a87cbf1fbb
--- /dev/null
+++ b/scripts/wavelet.jl
@@ -0,0 +1,104 @@
+# one-dimensional discrete wavelet transform
+
+function dwt!(y, x)
+    a = sqrt(1 / 2)
+    n = length(x)
+    n <= 1 && return (y .= x)
+    # use second half of y as scratch space
+    for i in 1:n÷2
+        j = 2*(i-1)+1
+        y[n÷2+i] = a * x[j] + a * x[j+1]
+    end
+    dwt!(view(y, 1:n÷2), view(y, n÷2+1:2*(n÷2)))
+    for i in 1:n÷2
+        j = 2*(i-1)+1
+        y[n÷2+i] = a * x[j] - a * x[j+1]
+    end
+    isodd(n) && (y[n] = x[n])
+    return y
+end
+
+function idwt!(y, x)
+    a = sqrt(1 / 2)
+    n = length(x)
+    n <= 1 && return (y .= x)
+    isodd(n) && (y[n] = x[n])
+    # use second half of y as scratch space
+    idwt!(view(y, n÷2+1:2*(n÷2)), view(x, 1:n÷2))
+    for i in 1:n÷2
+        j = 2*(i-1)+1
+        y[j]   = a * y[n÷2+i] + a * x[n÷2+i]
+        y[j+1] = a * y[n÷2+i] - a * x[n÷2+i]
+    end
+    return y
+end
+
+# multi-dimensional discrete wavelet transform
+
+# TODO: non-allocating
+function mdwt!(y, x)
+    d = ndims(x)
+    s = size(x)
+    any(s .<= 1) && return (y .= x)
+
+    # transformation matrix
+    M = [sqrt(1/2^d)*(-1)^count_ones(i&j) for i in 0:2^d-1, j in 0:2^d-1]
+
+    offset = s.÷2
+    firsthalf = Base.OneTo.(offset)
+    for i in Iterators.product(firsthalf...)
+        j = @. 2*(i-1)+1
+
+        src_ix = ntuple(k->j[k]:1+j[k], d)
+        dst_ix = ntuple(k->i[k]:offset[k]:offset[k]+i[k], d)
+
+        src = vec(view(x, src_ix...))
+        dst = vec(view(y, dst_ix...))
+        mul!(dst, M, src)
+    end
+
+    mdwt!(view(y, firsthalf...), y[firsthalf...])
+
+    for k in 1:d
+        iseven(s[k]) && continue
+        selectdim(y, k, s[k]) .= selectdim(x, k, s[k])
+    end
+
+    return y
+end
+
+# TODO: non-allocating
+function imdwt!(y, x)
+    d = ndims(x)
+    s = size(x)
+    any(s .<= 1) && return (y .= x)
+
+    offset = s .÷ 2
+    firsthalf = Base.OneTo.(offset)
+
+    # transformation matrix
+    M = [sqrt(1 / 2^d)*(-1)^count_ones(i&j) for i in 0:2^d-1, j in 0:2^d-1]
+
+    for k in 1:d
+        iseven(s[k]) && continue
+        selectdim(y, k, s[k]) .= selectdim(x, k, s[k])
+    end
+
+    xt = copy(x)
+    imdwt!(view(xt, firsthalf...), x[firsthalf...])
+
+    for i in Iterators.product(firsthalf...)
+        j = @. 2*(i-1)+1
+
+        src_ix = ntuple(k->i[k]:offset[k]:offset[k]+i[k], d)
+        dst_ix = ntuple(k->j[k]:1+j[k], d)
+
+        src = vec(view(xt, src_ix...))
+        dst = vec(view(y, dst_ix...))
+        mul!(dst, M, src)
+    end
+
+    return y
+end
+mdwt(x) = mdwt!(similar(x), x)
+imdwt(x) = imdwt!(similar(x), x)