Question

My problem is that I have a list of 2D parametric splines, and I need a more efficient way of rendering them onto an image grid. Each spline is determined by a series of points, a line radius / thickness (in pixels), and an opacity.

The original implementation I had in mind is similar to the question discussed here, which iterates through every single pixel on the image, finds the minimum distance to the curve, and then marks the pixel if the minimum distance is below the desired radius.
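Written out, the test applied to each pixel $(x_i, y_j)$, with curve samples $(s_x(t_k), s_y(t_k))$ for $k = 1, \dots, n$ and line radius $r$, is:

$$
\text{mark}(x_i, y_j) \iff \min_{1 \le k \le n} \sqrt{\big(s_x(t_k) - x_i\big)^2 + \big(s_y(t_k) - y_j\big)^2} < r
$$

For the 256 x 256 image and n = 1000 samples used below, that is 256 · 256 · 1000 = 65,536,000 squared distances per spline, and the pixel loop issues them as 65,536 separate calls to `min_distance`.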

import math
import matplotlib.pyplot as plt
import numpy as np
import scipy.interpolate
import time

from PIL import Image

class GenePainter(object):
    def __init__(self, source):
        self.source = source

    def render(self):
        output = np.zeros(self.source.shape, dtype=np.float32)
        Ny, Nx = output.shape[0], output.shape[1]

        #x = np.array([5, 10, 15, 20, 5, 5])
        #y = np.array([5, 5, 20, 15, 10, 30])

        x = np.array(np.random.random(4) * 128, dtype=np.float32)
        y = np.array(np.random.random(4) * 128, dtype=np.float32)

        sx, sy = spline(x, y, 1000)

        t = time.time()
        for yi in range(Ny):
            for xi in range(Nx):
                d = min_distance(sx, sy, xi, yi)
                if d < 10.:  # radius
                    output[yi, xi, :] = np.array([1, 1, 0, 0.5])
        print(time.time() - t)

        # t = time.time()
        # for _ in range(100):
        #   plt.plot(sx, sy, label='spline', linewidth=10, aa=False, solid_capstyle="round")
        # print(time.time() - t)

        plt.imshow(output, interpolation='none')
        plt.show()

    def score(self, image):
        return np.linalg.norm(self.source - image, 2)

def spline(x, y, n):
    if x.ndim != 1 or y.ndim != 1 or x.size != y.size:
        raise ValueError("x and y must be 1-D arrays of the same length")

    t = np.linspace(0, 1, x.size)

    sx = scipy.interpolate.interp1d(t, x, kind='cubic')
    sy = scipy.interpolate.interp1d(t, y, kind='cubic')

    st = np.linspace(0, 1, n)

    return sx(st), sy(st)

def min_distance(sx, sy, px, py):
    dx = sx - px
    dy = sy - py
    d = dx ** 2 + dy ** 2
    return math.sqrt(np.amin(d))

def read_image(file):
    image_raw = Image.open(file)
    image_raw.load()
    # return np.array(image_raw, dtype=np.float32)
    image_rgb = Image.new('RGB', image_raw.size)
    image_rgb.paste(image_raw, None)
    return np.array(image_rgb, dtype=np.float32)

if __name__ == "__main__":

    # source = read_image('ML129.png')
    source = np.zeros((256, 256, 4), dtype=np.float32)

    p = GenePainter(source)
    p.render()

The problem is that drawing each spline on a 256 x 256 RGBA image takes ~1.5 seconds because of the unoptimized iteration through every pixel, which is too slow for my purposes. I plan to have up to ~250 of these splines on a single image, process up to ~100 images per job, and run up to ~1000 jobs in total, so I'm looking for any optimization that will cut down my computation time.
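The double loop calls `min_distance` once per pixel, i.e. 65,536 separate NumPy calls per spline, each on a 1000-element array. A minimal sketch of the same test with the loop moved into NumPy broadcasting: it processes a block of image rows at a time so the temporary (rows x Nx x n) array stays bounded (about 33 MB for 16 rows, 256 columns and 1000 float64 samples), and it compares squared distances against radius² to skip the square root. `render_mask_vectorized` and `rows_per_chunk` are hypothetical names, not part of any library.

```python
import numpy as np


def render_mask_vectorized(sx, sy, shape, radius, rows_per_chunk=16):
    """Boolean (Ny, Nx) mask of pixels closer than `radius` to any curve sample.

    Same criterion as the per-pixel loop (min distance < radius), but evaluated
    with broadcasting over a block of image rows at a time.
    """
    ny, nx = shape
    sx = np.asarray(sx, dtype=np.float64)
    sy = np.asarray(sy, dtype=np.float64)
    xs = np.arange(nx, dtype=np.float64)
    # Squared x-distances from every column to every sample: shape (1, Nx, n).
    dx2 = ((xs[:, None] - sx[None, :]) ** 2)[None, :, :]
    r2 = radius * radius
    mask = np.zeros((ny, nx), dtype=bool)
    for y0 in range(0, ny, rows_per_chunk):
        ys = np.arange(y0, min(y0 + rows_per_chunk, ny), dtype=np.float64)
        # Squared y-distances for this block of rows: shape (rows, 1, n).
        dy2 = ((ys[:, None] - sy[None, :]) ** 2)[:, None, :]
        # (rows, Nx, n) squared distances; keep the minimum over the samples.
        d2min = (dx2 + dy2).min(axis=-1)
        mask[y0:y0 + len(ys)] = d2min < r2
    return mask
```

In `render()`, the double loop could then be replaced by `output[render_mask_vectorized(sx, sy, (Ny, Nx), 10.0)] = [1, 1, 0, 0.5]`. The arithmetic is unchanged, so this still does 65,536,000 distance evaluations per spline; what it removes is the per-pixel Python overhead.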

An alternative that I've looked into is to draw all the splines onto a PyPlot plot and then dump the final image to a NumPy array that I can use for other calculations. This runs much faster: about 0.15 seconds to draw 100 splines.

plt.plot(sx, sy, label='spline', linewidth=10, aa=False, solid_capstyle="round")
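To put the two timings in perspective, here is the arithmetic implied by the figures quoted above, assuming everything runs sequentially in one process and counting drawing time only:

```python
# Rough time budget from the figures quoted in the question.
splines_per_image = 250
images_per_job = 100
jobs = 1000
total_splines = splines_per_image * images_per_job * jobs  # 25,000,000

per_pixel_loop_s = 1.5      # ~1.5 s per spline with the 256 x 256 pixel loop
pyplot_s = 0.15 / 100       # ~0.15 s per 100 splines with plt.plot

loop_days = total_splines * per_pixel_loop_s / 86400
pyplot_hours = total_splines * pyplot_s / 3600

print(f"{total_splines:,} splines")           # 25,000,000 splines
print(f"per-pixel loop: {loop_days:,.0f} days")  # per-pixel loop: 434 days
print(f"pyplot: {pyplot_hours:,.1f} hours")      # pyplot: 10.4 hours
```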

The problem is that the linewidth parameter seems to correspond to pixels on my screen rather than to pixels of the image (the 256 x 256 grid): when I resize the window, the curve scales with the window, but its line width stays the same. I would like the curve width to correspond to pixels on the 256 x 256 grid instead.
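Matplotlib's `linewidth` is measured in points (1/72 inch), not in data units, which is why it ignores the axis scaling. On a figure rendered at `dpi` dots per inch, w points cover w · dpi / 72 pixels. So if the figure is exactly 256 x 256 pixels (for example `plt.figure(figsize=(256 / dpi, 256 / dpi), dpi=dpi)` with `fig.add_axes([0, 0, 1, 1])` and the axis limits set to the grid size) and it is rendered or saved at that same dpi, a fixed linewidth maps to a fixed number of grid pixels. A small sketch of the conversion; the helper name is hypothetical.

```python
def radius_px_to_linewidth_pt(radius_px, dpi):
    """Matplotlib linewidth (in points) for a stroke 2 * radius_px pixels wide.

    Assumes the figure is rendered at `dpi` and that one data unit maps to one
    output pixel (e.g. a 256 x 256 pixel figure whose axes fill the whole figure).
    """
    width_px = 2.0 * radius_px       # linewidth is the full stroke width, not the radius
    return width_px * 72.0 / dpi     # 1 pt = 1/72 inch, and 1 inch = dpi pixels


print(radius_px_to_linewidth_pt(10, 72))   # 20.0
print(radius_px_to_linewidth_pt(10, 100))  # 14.4
```

Note that a 10-pixel radius needs a 20-pixel-wide stroke, whereas the `plt.plot` call above uses `linewidth=10`. The conversion only holds for an offscreen render at a fixed size and dpi; an interactive window that is resized changes the figure's pixel size.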

I would prefer to solve the issue by greatly optimizing the first, numerical implementation rather than the PyPlot drawing. I've also looked into downsampling the image (computing distances for only a subset of pixels rather than every pixel), but even using 10% of the pixels, 0.15 seconds per spline is still too slow.
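Downsampling keeps the every-pixel-against-every-sample structure and only drops pixels. A different technique is to turn the loop around: for each curve sample, test only the small square window of pixels around it, since no pixel outside that window can be within the radius. A minimal sketch; `render_mask_stamped` is a hypothetical name, and it should produce the same mask as the per-pixel test against the same samples. For r = 10 the window is 23 x 23 pixels, so 1000 samples cost 1000 · 529 = 529,000 distance evaluations instead of 65,536,000.

```python
import numpy as np


def render_mask_stamped(sx, sy, shape, radius):
    """Boolean (Ny, Nx) mask of pixels closer than `radius` to any curve sample.

    Each sample is tested only against a (2R+1) x (2R+1) window of pixels
    around it (a "stamp"), instead of against the whole image.
    """
    ny, nx = shape
    sx = np.asarray(sx, dtype=np.float64)
    sy = np.asarray(sy, dtype=np.float64)
    R = int(np.ceil(radius)) + 1          # half-size large enough to contain the disk
    off = np.arange(-R, R + 1)
    # Candidate pixel coordinates per sample: xi is (n, 1, w), yi is (n, w, 1).
    xi = np.floor(sx).astype(np.int64)[:, None, None] + off[None, None, :]
    yi = np.floor(sy).astype(np.int64)[:, None, None] + off[None, :, None]
    # Exact squared distance from each candidate pixel to its sample: (n, w, w).
    d2 = (xi - sx[:, None, None]) ** 2 + (yi - sy[:, None, None]) ** 2
    xi, yi = np.broadcast_arrays(xi, yi)
    # Keep candidates inside the radius and inside the image bounds.
    keep = (d2 < radius * radius) & (xi >= 0) & (xi < nx) & (yi >= 0) & (yi < ny)
    mask = np.zeros((ny, nx), dtype=bool)
    mask[yi[keep], xi[keep]] = True       # overlapping stamps just set True twice
    return mask
```

The mask can be used exactly like the loop's result, e.g. `output[mask] = [1, 1, 0, 0.5]`. Like the original, it measures distance to the sampled points rather than to the continuous curve, so the samples must be dense relative to the radius.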

Thank you in advance for any help or advice!


Solution

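You can use matplotlib's Agg renderer to do the drawing directly into a NumPy array. Here is an example:

I create a RendererAgg and an ndarray that shares the same memory with it. Then I create the Line2D artist and call its draw() method with the RendererAgg object. Because the renderer is created at 72 dpi, one point equals one pixel, so linewidth is measured in pixels of the 256 x 256 grid (it is the full width of the line, i.e. twice the radius).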

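import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_agg import RendererAgg
from matplotlib.lines import Line2D

w, h = 256, 256
# At 72 dpi, 1 point = 1 pixel, so linewidth below is in grid pixels.
r = RendererAgg(w, h, 72)
# View of the renderer's RGBA buffer (no copy): drawing into r updates arr.
arr = np.frombuffer(r.buffer_rgba(), np.uint8).reshape(r.height, r.width, -1)

# Test curve in pixel coordinates (Line2D's default transform is the identity;
# Agg puts y = 0 at the bottom row of arr).
t = np.linspace(0, 2*np.pi, 100)
x = np.sin(2*t) * w*0.45 + w*0.5
y = np.cos(3*t) * h*0.45 + h*0.5

line = Line2D(x, y, linewidth=5, color=(1.0, 0.0, 0.0), alpha=0.3)
line.draw(r)
plt.imsave("test.png", arr)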
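Agg blends each semi-transparent line over the pixels already in the buffer, so drawing several Line2D objects into the same renderer accumulates their opacities. If you prefer to stay with boolean per-pixel masks (the `d < radius` test from the question, evaluated for every pixel), the same per-spline opacity can be applied in NumPy. A minimal sketch, assuming an opaque RGB canvas with float values in [0, 1]; `composite_over` is a hypothetical helper, and the hard-edged masks are not anti-aliased.

```python
import numpy as np


def composite_over(canvas, mask, rgb, alpha):
    """Blend color `rgb` with opacity `alpha` into `canvas` where `mask` is True.

    canvas: float array of shape (Ny, Nx, 3), values in [0, 1], updated in place
            (assumed opaque, so only the color channels are stored).
    mask:   boolean array of shape (Ny, Nx), one spline's stroke.
    Uses the "over" rule for an opaque background: out = a * c + (1 - a) * out.
    """
    color = np.asarray(rgb, dtype=canvas.dtype)
    canvas[mask] = alpha * color + (1.0 - alpha) * canvas[mask]
    return canvas


# Two overlapping half-transparent strokes on a black 2 x 2 canvas.
canvas = np.zeros((2, 2, 3))
stroke = np.array([[True, False], [False, False]])
composite_over(canvas, stroke, (1.0, 0.0, 0.0), 0.5)   # red at 50 %
composite_over(canvas, stroke, (0.0, 0.0, 1.0), 0.5)   # blue at 50 % on top
print(canvas[0, 0].tolist())   # [0.25, 0.0, 0.5]
```

Applying the strokes in order matters: the later stroke ends up dominant, just as a line drawn later with Agg sits on top.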


Licensed under: CC-BY-SA with attribution