diff --git a/scripts/run_experiments.jl b/scripts/run_experiments.jl
index 292b5fc485323806c0b01624dd85b76ee6284d56..87ca31fd90da581991148cc21f96f8f105088ae5 100644
--- a/scripts/run_experiments.jl
+++ b/scripts/run_experiments.jl
@@ -981,7 +981,7 @@ end
 
 
 function experiment_image_mesh_interpolation(ctx)
-    imgf = loadimg(joinpath(ctx.indir, "input.png"))
+    imgf = from_img(loadimg(joinpath(ctx.indir, "input.png")))
 
     df_psnr = DataFrame()
     df_ssim = DataFrame()
@@ -1008,7 +1008,7 @@ function experiment_image_mesh_interpolation(ctx)
             save_csv(joinpath(ctx.outdir, "$(mesh_size)_$(method).csv"), u)
 
             imgu = sample(u)
-            saveimg(joinpath(ctx.outdir, "$(mesh_size)_$(method).png"), imgu)
+            saveimg(joinpath(ctx.outdir, "$(mesh_size)_$(method).png"), to_img(imgu))
 
             return method => (
                 psnr = assess_psnr(imgu, imgf),
diff --git a/src/function.jl b/src/function.jl
index 6216bd9dc551dd62dcf12b23c3ad36e30f42ac07..17cb43be649fa848dae11687b324d4fdb52e7e46 100644
--- a/src/function.jl
+++ b/src/function.jl
@@ -1,7 +1,7 @@
 using Statistics: mean
 using StaticArrays: SA, SArray, MVector, MMatrix, SUnitRange
 using DataFrames, CSV
-export FeSpace, Mapper, FeFunction, ImageFunction, P1, DP0, DP1
+export FeSpace, Mapper, FeFunction, P1, DP0, DP1
 export interpolate!, sample, bind!, evaluate, integrate, nabla, save_csv
 
 # Finite Elements
diff --git a/src/image.jl b/src/image.jl
index 9d9da61505f2df903d21feed031c19d34e872c59..cfb3c9fb2c9afadcc268b31dcbefced541f242ea 100644
--- a/src/image.jl
+++ b/src/image.jl
@@ -63,23 +63,21 @@ function halve(img)
     return res
 end
 
-# ImageFunction mesh function wrapper
+# ArrayFunction mesh function wrapper
 
 # TODO: inherit from some abstract mesh function type
 # TODO: make mesh type a parameter for performance
-struct ImageFunction{Img}
+struct ArrayFunction{Img}
     mesh::Mesh
     img::Img
     cell::Base.RefValue{Int}
 end
 
-ImageFunction(mesh, img) = ImageFunction(mesh, img, Ref(0))
+ArrayFunction(mesh, img) = ArrayFunction(mesh, img, Ref(0))
 
-bind!(f::ImageFunction, cell) = f.cell[] = cell
-# transform coordinates to image/matrix indexing space
-img_coord(img, x) = (size(img, 1) - x[2] + 1, x[1])
-evaluate(f::ImageFunction, xloc) = evaluate_bilinear(f.img,
-    img_coord(f.img, elmap(f.mesh, f.cell[])(xloc)))
+bind!(f::ArrayFunction, cell) = f.cell[] = cell
+evaluate(f::ArrayFunction, xloc) = evaluate_bilinear(f.img,
+    elmap(f.mesh, f.cell[])(xloc))
 
 # Sampling
 
@@ -127,12 +125,7 @@ function _sample(f::FeFunction)
         end
     end
 
-    # convert to image indexing
-    d = length(f.space.size)
-    reverse!(out, dims = d + 2) # reverse y
-    out2 = permutedims(out, (ntuple(identity, d)..., d + 2, d + 1)) # flip xy
-
-    return out2
+    return out
 end
 
 
@@ -147,7 +140,7 @@ the default interpolation operator for the discrete function u.
 """
 # TODO: should be called "interpolate_nodal"
 function interpolate!(u::FeFunction, img::AbstractArray)
-    f = ImageFunction(u.space.mesh, img)
+    f = ArrayFunction(u.space.mesh, img)
     interpolate!(u, @inline (x; f) -> f; f)
 end
 
@@ -163,7 +156,7 @@ function project_l2_lagrange!(u::FeFunction, img::AbstractArray)
     d = 2 # domain dimension
     space = u.space
     mesh = space.mesh
-    f = ImageFunction(mesh, img)
+    f = ArrayFunction(mesh, img)
     opparams = (; f)
 
     nrdims = prod(space.size)
@@ -256,7 +249,7 @@ function project_qi_lagrange!(u::FeFunction, img)
     d = 2 # domain dimension
     space = u.space
     mesh = space.mesh
-    f = ImageFunction(mesh, img)
+    f = ArrayFunction(mesh, img)
     # count contributions to respective dof for subsequent averaging
     gdofcount = zeros(Int, size(u.data))
     area_refel = 0.5
@@ -351,7 +344,7 @@ function project_l2_pixel!(u::FeFunction, img)
     d = 2 # domain dimension
     space = u.space
     mesh = space.mesh
-    f = ImageFunction(mesh, img)
+    f = ArrayFunction(mesh, img)
     opparams = (; f)
 
     nrdims = prod(space.size)
@@ -360,7 +353,7 @@ function project_l2_pixel!(u::FeFunction, img)
 
     # first loop to count number of triangles intersecting pixel evaluation
     # points
-    ncells = zeros(axes(img, 2), axes(img, 1))
+    ncells = zeros(axes(img))
     for cell in cells(mesh)
         pixels = PixelIterator(mesh, cell)
         for I in pixels
diff --git a/src/mesh.jl b/src/mesh.jl
index 1638735290ef2e00b472fae9ab24637269a3939e..6b0ef3a847514d8b5059d0ad39c381ba5d8e0922 100644
--- a/src/mesh.jl
+++ b/src/mesh.jl
@@ -166,15 +166,15 @@ function init_grid(m::Int, n::Int = m, v0 = (0., 0.), v1 = (1., 1.))
     return Mesh(vertices, cells)
 end
 
-function init_grid(img::Array{<:Any, 2}; type=:vertex)
-    s = (size(img, 2), size(img, 1))
+function init_grid(a::Array{<:Any, 2}; type=:vertex)
+    s = size(a)
     type == :vertex ?
         init_grid((s .- 1)..., (1.0, 1.0), s) :
 	init_grid(s..., (0.5, 0.5), s .- (0.5, 0.5))
 end
 
-function init_grid(img::Array{<:Any, 2}, m::Int, n::Int = m; type=:vertex)
-    s = (size(img, 2), size(img, 1))
+function init_grid(a::Array{<:Any, 2}, m::Int, n::Int = m; type=:vertex)
+    s = size(a)
     type == :vertex ?
         init_grid(((m, n) .- 1)..., (1.0, 1.0), s) :
         init_grid((m, n)..., (0.5, 0.5), s .- (0.5, 0.5))