Question

I have a list of 9 elements. The first three represent positions, the next three velocities, and the last three forces.

Sometimes I require forces from the array, other times velocities, and other times positions.

So I wrote a function as follows:

import numpy as np

def Extractor(alist,quantity):

    if quantity=='positions':
        x = alist[0]
        y = alist[1]
        z = alist[2]
        return (x,y,z)

    elif quantity=='velocities':
        vx = alist[3]
        vy = alist[4]
        vz = alist[5]
        tot_v = np.sqrt(vx**2 + vy**2 + vz**2)
        return (vx,vy,vz,tot_v)

    elif quantity=='forces':
        fx = alist[6]
        fy = alist[7]
        fz = alist[8]
        tot_f = np.sqrt(fx**2 + fy**2 + fz**2)
        return (fx,fy,fz,tot_f)
    else:
        print("Do not recognise quantity: not one of positions, velocities, forces")

However, this seems like a massive code smell to me due to duplicate code. Is there a nicer, more pythonic way to do this? I'm quite new to OOP, but could I use some kind of class inheritance utilising polymorphism?


Solution

Your method violates the Single Responsibility Principle. Consider splitting it like this:

import numpy as np

def positionExtractor(alist):
    return tuple(alist[0:3])

def velocityExtractor(alist):
    velocity = tuple(alist[3:6])
    return velocity + (np.sqrt(sum(x**2 for x in velocity)),)

def forcesExtractor(alist):
    forces = tuple(alist[6:9])
    return forces + (np.sqrt(sum(x**2 for x in forces)),)

You can put them in a dictionary:

extractors = {
    'position' : positionExtractor,
    'velocity' : velocityExtractor,
    'forces' : forcesExtractor}

and use:

result = extractors[quantity](alist)
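Here is a minimal sketch of that dispatch that gives a clearer error for unknown keys. The wrapper name `extract` is hypothetical. It also uses `math.sqrt` from the standard library in place of `np.sqrt`, which gives the same result for scalar inputs, so the sketch runs without NumPy:

```python
import math

def positionExtractor(alist):
    return tuple(alist[0:3])

def velocityExtractor(alist):
    velocity = tuple(alist[3:6])
    return velocity + (math.sqrt(sum(x**2 for x in velocity)),)

def forcesExtractor(alist):
    forces = tuple(alist[6:9])
    return forces + (math.sqrt(sum(x**2 for x in forces)),)

extractors = {
    'position': positionExtractor,
    'velocity': velocityExtractor,
    'forces': forcesExtractor,
}

def extract(alist, quantity):
    """Hypothetical wrapper: look up the extractor, then call it."""
    try:
        extractor = extractors[quantity]
    except KeyError:
        # Turn the bare KeyError into a message listing the valid keys
        raise ValueError("Do not recognise quantity {!r}: not one of {}".format(
            quantity, ', '.join(extractors))) from None
    return extractor(alist)
```

With this, `extract([1, 2, 3, 4, 5, 6, 7, 8, 9], 'position')` returns `(1, 2, 3)`, and an unknown key such as `'mass'` raises a `ValueError` that names the accepted keys.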

Here is an example using inheritance, although it seems like over-engineering for such a simple task:

import numpy as np

class Extractor:
    def extract(self, alist):
        raise NotImplementedError()

class IndexRangeExtractor(Extractor):
    def __init__(self, fromIndex, toIndex):
        self.fromIndex = fromIndex
        self.toIndex = toIndex

    def extract(self, alist):
        return tuple(alist[self.fromIndex:self.toIndex])

class EuclideanDistanceExtractorDecorator(Extractor):
    def __init__(self, innerExtractor):
        self.innerExtractor = innerExtractor

    def extract(self, alist):
        innerResult = self.innerExtractor.extract(alist)
        distance = np.sqrt(sum(x**2 for x in innerResult))

        return innerResult + (distance,)

#... 

class ExtractorFactory:
    def __init__(self):
        self.extractors = {
            'position':IndexRangeExtractor(0, 3),
            'velocity':EuclideanDistanceExtractorDecorator(
                IndexRangeExtractor(3, 6)),
            'forces':EuclideanDistanceExtractorDecorator(
                IndexRangeExtractor(6, 9))}

    def createExtractor(self, quantity):
        return self.extractors[quantity]


alist = [1,2,3,4,5,6,7,8,9]
ef = ExtractorFactory()
e1 = ef.createExtractor('position')
e2 = ef.createExtractor('velocity')
e3 = ef.createExtractor('forces')

print(e1.extract(alist))
print(e2.extract(alist))
print(e3.extract(alist))
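If you do take the inheritance route, one variation is to make `Extractor` an abstract base class with Python 3's `abc` module. This is an assumption on my part and not part of the answer above. With it, a subclass that forgets to implement `extract` fails as soon as it is instantiated, instead of raising `NotImplementedError` only when `extract` is called. A sketch, again with `math.sqrt` standing in for `np.sqrt`:

```python
import math
from abc import ABC, abstractmethod

class Extractor(ABC):
    """Abstract base: concrete subclasses must implement extract()."""
    @abstractmethod
    def extract(self, alist):
        ...

class IndexRangeExtractor(Extractor):
    def __init__(self, fromIndex, toIndex):
        self.fromIndex = fromIndex
        self.toIndex = toIndex

    def extract(self, alist):
        # Components in [fromIndex, toIndex) as a tuple
        return tuple(alist[self.fromIndex:self.toIndex])

class EuclideanDistanceExtractorDecorator(Extractor):
    def __init__(self, innerExtractor):
        self.innerExtractor = innerExtractor

    def extract(self, alist):
        # Append the Euclidean magnitude of whatever the inner extractor returns
        innerResult = self.innerExtractor.extract(alist)
        return innerResult + (math.sqrt(sum(x**2 for x in innerResult)),)
```

Calling `Extractor()` directly now raises `TypeError`, while the concrete classes behave as in the version above.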

Other suggestions

You can start by using slices to pick out the elements; all but positions also need a formula applied:

import numpy as np

_slices = {'positions': slice(3), 'velocities': slice(3, 6), 'forces': slice(6, 9)}

def Extractor(alist, quantity):
    try:
        a, b, c = alist[_slices[quantity]]
        tot = np.sqrt(a**2 + b**2 + c**2)
        return a, b, c, tot
    except KeyError:
        raise ValueError(
            "Do not recognise quantity: "
            "not one of {}".format(', '.join(_slices)))

This returns a consistent number of values. If a total makes no sense for positions, I'd return 0.0 for it instead:

tot = np.sqrt(a**2 + b**2 + c**2) if quantity != 'positions' else 0.0
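Putting that line into the function gives something like the sketch below. It also narrows the `try` to the dictionary lookup, so only an unknown quantity gets reported as one. `math.sqrt` replaces `np.sqrt` here so the sketch has no NumPy dependency:

```python
import math

_slices = {'positions': slice(3), 'velocities': slice(3, 6), 'forces': slice(6, 9)}

def Extractor(alist, quantity):
    try:
        s = _slices[quantity]  # only the lookup can raise KeyError here
    except KeyError:
        raise ValueError(
            "Do not recognise quantity: "
            "not one of {}".format(', '.join(_slices))) from None
    a, b, c = alist[s]
    # Positions get 0.0 in the fourth slot, so every call returns a 4-tuple
    tot = math.sqrt(a**2 + b**2 + c**2) if quantity != 'positions' else 0.0
    return a, b, c, tot
```

For example, `Extractor([1, 2, 3, 4, 5, 6, 7, 8, 9], 'positions')` returns `(1, 2, 3, 0.0)`.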

Why not try something like this:

import numpy as np

def extract_position(x,y,z):
    return (x, y, z)

def extract_velocities(x,y,z):
    return (x, y, z, np.sqrt(x**2 + y**2 + z**2))

def extract_forces(x,y,z):
    return (x, y, z, np.sqrt(x**2 + y**2 + z**2))

extractor = { 'positions': extract_position,
              'velocities': extract_velocities,
              'forces': extract_forces }

try:
    print(extractor['positions'](1,2,3))
    # 'unknown' is not a key, so this raises KeyError
    print(extractor['unknown'](4,5,6))
except KeyError:
    print("Do not recognise quantity: not one of positions, velocities, forces")

I prefer to use function references (Python functions are first-class objects) to tie data to arbitrary calculations. The dictionary also replaces the switch-style if/elif chain, so this at least feels similar to what you were looking for.

Also, velocities and forces use the same calculation, so you could condense those as well.
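One way to condense them is to point both dictionary entries at the same function. The shared helper `with_magnitude` is a hypothetical name, and `math.sqrt` stands in for `np.sqrt`:

```python
import math

def extract_position(x, y, z):
    return (x, y, z)

def with_magnitude(x, y, z):
    """Shared by velocities and forces: components plus Euclidean magnitude."""
    return (x, y, z, math.sqrt(x**2 + y**2 + z**2))

extractor = {'positions': extract_position,
             'velocities': with_magnitude,
             'forces': with_magnitude}
```

For example, `extractor['velocities'](3, 4, 12)` returns `(3, 4, 12, 13.0)`.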

Maybe something like:

import numpy as np
from operator import itemgetter

def extract(sequence, quantity):
    try:
        a, b, c = {
            'positions': itemgetter(0, 1, 2),
            'velocities': itemgetter(3, 4, 5),
            'forces': itemgetter(6, 7, 8)
        }[quantity](sequence)
        return a, b, c, np.sqrt(a**2 + b**2 + c**2)
    except KeyError as e:
        pass # handle no suitable quantity found here

Note that the magnitude is always calculated, even for positions. This keeps the return value consistent as a 4-tuple, and unless the calculation is really expensive, it shouldn't be an issue.
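Here is a small variation on the same idea, offered as a sketch. It builds the `itemgetter` table once at module level instead of on every call. It also reports an unknown quantity with a `ValueError` instead of silently returning `None`. The name `_getters` is hypothetical, and `math.sqrt` replaces `np.sqrt`:

```python
import math
from operator import itemgetter

# Built once at import time; itemgetter(i, j, k)(seq) returns (seq[i], seq[j], seq[k])
_getters = {
    'positions': itemgetter(0, 1, 2),
    'velocities': itemgetter(3, 4, 5),
    'forces': itemgetter(6, 7, 8),
}

def extract(sequence, quantity):
    try:
        getter = _getters[quantity]
    except KeyError:
        raise ValueError('Do not recognise quantity: {!r}'.format(quantity)) from None
    a, b, c = getter(sequence)
    # Always compute the magnitude, so every result is a 4-tuple
    return a, b, c, math.sqrt(a**2 + b**2 + c**2)
```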

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow