Question

So, there exists an easy way to calculate the intersection of two sets via set.intersection(). However, I have the following problem:

class Person(Object):                    
    def __init__(self, name, age):                                                      
        self.name = name                                                                
        self.age = age                                                                  

l1 = [Person("Foo", 21), Person("Bar", 22)]                                             
l2 = [Person("Foo", 21), Person("Bar", 24)]                                             

union_list = list(set(l1).union(l2))                                           
# [Person("Foo", 21), Person("Bar", 22), Person("Bar", 24)]

(Object is a base class provided by my ORM that implements basic __hash__ and __eq__ functionality, which essentially includes every member of the class in the hash. In other words, the value returned by __hash__ is a hash of every attribute of the instance.)
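To make the problem concrete, here is a minimal sketch of a base class that behaves as described. It assumes Object compares and hashes all instance attributes; `FieldEqualityBase` is a hypothetical stand-in, not the ORM's real class. It shows why the union keeps both "Bar" entries and why a name-only probe finds nothing.

```python
class FieldEqualityBase(object):
    """Hypothetical stand-in for the ORM's Object base class:
    equality and hashing use every instance attribute."""

    def _fields(self):
        # Sort by attribute name so the tuple is stable across instances.
        return tuple(sorted(vars(self).items()))

    def __eq__(self, other):
        return type(self) is type(other) and self._fields() == other._fields()

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly.
        return not self == other

    def __hash__(self):
        return hash(self._fields())


class Person(FieldEqualityBase):
    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __repr__(self):
        return "Person(%r, %d)" % (self.name, self.age)


l1 = [Person("Foo", 21), Person("Bar", 22)]
l2 = [Person("Foo", 21), Person("Bar", 24)]
union = set(l1).union(l2)
print(len(union))  # 3: the two "Foo" objects are equal, the "Bar" objects are not

# A name-only probe matches nothing, because age takes part in __eq__/__hash__.
print(union.intersection([Person("Bar", -1)]))  # an empty set
```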

At this stage, I would like to run a set intersection by .name only, to find, say, Person('Bar', -1).intersection(union_list) # = [Person("Bar", -1), Person("Bar", 22), Person("Bar", 24)]. The usual .intersection() gives me nothing at this point, and I can't override __hash__ or __eq__ on the Person class, because (I think) that would break the original set union.

What's the best way to do this in Python 2.x?

EDIT: Note that the solution doesn't have to rely on a set. However, I need to find unions and then intersections, so it feels like this is amenable to a set (but I'm willing to accept solutions that use whatever magic you deem worthy, so long as it solves my problem!)


Solution 5

I hate answering my own questions, so I'll hold off on marking this as the 'answer' for a little while yet.

Turns out the way to do this is as follows:

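import types
p = Person("Bar", -1)
new_hash_method = lambda obj: hash(obj.name)
p.__hash__ = types.MethodType(new_hash_method, p)
for item in union_list:
    item.__hash__ = types.MethodType(new_hash_method, item)
# Caveat: for new-style classes (anything deriving from object), hash() looks up
# __hash__ on the class, not the instance, so this patch only affects old-style
# classes. Set membership also still calls __eq__, which compares every field.
# intersection() needs an iterable, so wrap the single probe in a list.
set(union_list).intersection([p])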

It's certainly dirty and it relies on types.MethodType, but it's less intensive than the best solution proposed so far (glglgl's solution), since my actual union_list can potentially contain thousands of items; this saves me from re-creating objects every time I run this intersection procedure.
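A lower-overhead alternative that leaves __hash__ untouched is to collect the wanted names once and filter the list against them. This is a minimal sketch, not the poster's method: `intersect_by` is a hypothetical helper, and Person here is a plain class standing in for the ORM model. It creates no new objects and does one set lookup per item, and it reproduces the output the question asks for.

```python
class Person(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age

    def __repr__(self):
        return "Person(%r, %d)" % (self.name, self.age)


def intersect_by(key, probes, items):
    """Hypothetical helper: return the probes plus every item whose
    key(item) matches the key of some probe."""
    wanted = set(key(p) for p in probes)  # built once, O(1) lookups below
    return list(probes) + [item for item in items if key(item) in wanted]


union_list = [Person("Foo", 21), Person("Bar", 22), Person("Bar", 24)]
result = intersect_by(lambda p: p.name, [Person("Bar", -1)], union_list)
print(result)  # [Person('Bar', -1), Person('Bar', 22), Person('Bar', 24)]
```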

Other suggestions

Sounds like

>>> class Person:
...     def __init__(self, name, age):
...         self.name = name
...         self.age = age
...     def __eq__(self, other):
...         return self.name == other.name
...     def __hash__(self):
...         return hash(self.name)
...     def __str__(self):
...         return self.name
...
>>> l1 = [Person("Foo", 21), Person("Bar", 22)]
>>> l2 = [Person("Foo", 21), Person("Bar", 24)]
>>> union_list = list(set(l1).union(l2))
>>> [str(l) for l in union_list]
['Foo', 'Bar']

is what you want, since name is your unique key?

How about:

# dict comprehensions need Python 2.7+
d1 = {p.name: p for p in l1}
d2 = {p.name: p for p in l2}

intersectnames = set(d1.keys()).intersection(d2.keys())
intersect = [d1[k] for k in intersectnames]

It might be faster to pass intersectnames to your ORM as a query filter, in which case you wouldn't need to build dictionaries at all; just collect the names in lists.
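Note that a dict keyed by name keeps only the last Person for each name. If you want every Person that shares a name (as in the question's expected output), a sketch that groups people with collections.defaultdict could look like this; `group_by_name` is a hypothetical helper and Person is a plain stand-in class.

```python
from collections import defaultdict


class Person(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age


def group_by_name(people):
    """Map each name to the list of every Person carrying it."""
    groups = defaultdict(list)
    for p in people:
        groups[p.name].append(p)
    return groups


l1 = [Person("Foo", 21), Person("Bar", 22)]
l2 = [Person("Foo", 21), Person("Bar", 24)]
g1, g2 = group_by_name(l1), group_by_name(l2)

common_names = set(g1).intersection(g2)  # iterating a dict yields its keys
intersect = [p for name in sorted(common_names) for p in g1[name] + g2[name]]
print([(p.name, p.age) for p in intersect])
# [('Bar', 22), ('Bar', 24), ('Foo', 21), ('Foo', 21)]
```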

This is clunky, but...

(set(p for p in union_list for q in l2
     if p.name == q.name and p.age != q.age) |
 set(p for p in l2 for q in union_list
     if p.name == q.name and p.age != q.age))
# {person(name='Bar', age=22), person(name='Bar', age=24)}
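The printed output suggests `person` was a namedtuple. Under that assumption, a variant that avoids the nested loops groups everyone by name in a single pass and keeps the names that appear with more than one age. This is a sketch of a slightly different rule: it also flags a name that has two ages within l1 alone, which the expression above does not.

```python
from collections import defaultdict, namedtuple

# Assumption: `person` is a namedtuple, as the printed output above suggests.
person = namedtuple("person", ["name", "age"])

l1 = [person("Foo", 21), person("Bar", 22)]
l2 = [person("Foo", 21), person("Bar", 24)]

# Collect the distinct people per name in one pass over both lists;
# equal namedtuples collapse into one set entry.
by_name = defaultdict(set)
for p in l1 + l2:
    by_name[p.name].add(p)

# Keep only the names whose entries disagree (more than one distinct age).
conflicts = set()
for group in by_name.values():
    if len(group) > 1:
        conflicts |= group
print(sorted(conflicts))  # [person(name='Bar', age=22), person(name='Bar', age=24)]
```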

If you want age to be ignored when comparing, you should override __hash__() and __eq__() in Person, even though Object already defines them.

If you need this behaviour only in this (and similar) contexts, you could create a wrapper object that contains the Person and behaves differently, such as

class PersonWrapper(object):
    def __init__(self, person):
        self.person = person
    def __eq__(self, other):
        if hasattr(other, 'person'):
            return self.person.name == other.person.name
        else:
            return self.person.name == other.name
    def __hash__(self):
        return hash(self.person.name)

and then do

union_list = list(set(PersonWrapper(i) for i in l1).union(PersonWrapper(i) for i in l2))
# two wrappers, around Person("Foo", 21) and Person("Bar", 22); l2's "Bar" is a duplicate by name

(untested)
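Here is a tested sketch of the wrapper idea, simplified so that wrappers are only compared with other wrappers. Because equality is by name, a set keeps just one element per name, so intersection() returns a single match; to get every matching Person, compare each wrapper against the probe instead.

```python
class Person(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age


class PersonWrapper(object):
    """Compares and hashes a wrapped Person by name only."""

    def __init__(self, person):
        self.person = person

    def __eq__(self, other):
        return self.person.name == other.person.name

    def __ne__(self, other):  # needed for != on Python 2
        return not self == other

    def __hash__(self):
        return hash(self.person.name)


l1 = [Person("Foo", 21), Person("Bar", 22)]
l2 = [Person("Foo", 21), Person("Bar", 24)]
wrapped = set(PersonWrapper(p) for p in l1).union(PersonWrapper(p) for p in l2)
print(sorted(w.person.name for w in wrapped))  # ['Bar', 'Foo']

probe = PersonWrapper(Person("Bar", -1))
print(len(wrapped.intersection([probe])))  # 1: one element per name survives

# Every Person whose name matches the probe, duplicates by name included.
matches = [p for p in l1 + l2 if PersonWrapper(p) == probe]
print([(p.name, p.age) for p in matches])  # [('Bar', 22), ('Bar', 24)]
```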

You'll have to override __hash__ and the comparison methods if you want to use sets like this.

If you don't, then

Person("Foo", 21) == Person("Foo", 21)

will always be False, because the default comparison checks object identity.
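A quick check of that default behaviour, using a hypothetical class `Plain` that defines neither __eq__ nor __hash__ (the question's ORM Object does define them, so this only applies to classes without such a base):

```python
class Plain(object):
    """No __eq__/__hash__: inherits identity-based comparison from object."""

    def __init__(self, name, age):
        self.name = name
        self.age = age


a = Plain("Foo", 21)
b = Plain("Foo", 21)
print(a == b)       # False: different objects
print(a == a)       # True: same object
print(len({a, b}))  # 2: a set treats them as distinct
```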

If your objects are managed by an ORM, then you'll have to check how it compares objects. Usually it only looks at the object's id, and comparison only works if both objects are managed. If you compare an object you got from the ORM with an instance you created yourself before it was persisted to the db, they are likely to be different. In any case, an ORM shouldn't have a problem with you supplying your own comparison logic.

But if for some reasons you can't override __hash__ and __eq__, then you can't use sets for intersection and union with the original objects. You could:

  • calculate the intersection/union yourself
  • create a wrapper class which is comparable:

    class Person:                    
        def __init__(self, name, age):                                                      
            self.name = name                                                                
            self.age = age                                                                  
    
    l1 = [Person("Foo", 21), Person("Bar", 22)]                                             
    l2 = [Person("Foo", 21), Person("Bar", 24)]                                             
    
    class ComparablePerson:
        def __init__(self, person):
            self.person = person
    
        def __hash__(self):
            return hash(self.person.name) + 31*hash(self.person.age)
    
        def __eq__(self, other):
            return (self.person.name == other.person.name and
                    self.person.age == other.person.age)
        def __repr__(self):
            return "<%s - %d>" % (self.person.name, self.person.age)
    
    c1 = set(ComparablePerson(p) for p in l1)
    c2 = set(ComparablePerson(p) for p in l2)
    
    print c1
    print c2
    print c1.union(c2)
    print c2.intersection(c1)
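For the first option, calculating the union and intersection yourself, a minimal sketch takes a key function so each operation can use its own notion of equality. The helpers `union_by` and `intersection_by` are hypothetical; the full key assumes the ORM's Object compares all attributes, which for this Person are name and age.

```python
from operator import attrgetter


class Person(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age


def union_by(key, *lists):
    """Order-preserving union: keep the first item seen for each key."""
    seen = set()
    result = []
    for items in lists:
        for item in items:
            k = key(item)
            if k not in seen:
                seen.add(k)
                result.append(item)
    return result


def intersection_by(key, first, second):
    """Items of `first` whose key also occurs in `second`."""
    keys = set(key(item) for item in second)
    return [item for item in first if key(item) in keys]


full = attrgetter("name", "age")  # returns a (name, age) tuple
by_name = attrgetter("name")      # the name-only comparison from the question

l1 = [Person("Foo", 21), Person("Bar", 22)]
l2 = [Person("Foo", 21), Person("Bar", 24)]
union_list = union_by(full, l1, l2)
print([full(p) for p in union_list])  # [('Foo', 21), ('Bar', 22), ('Bar', 24)]
bars = intersection_by(by_name, union_list, [Person("Bar", -1)])
print([full(p) for p in bars])        # [('Bar', 22), ('Bar', 24)]
```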
    
Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow