Question

I have a question about writing output to a file in Python.

I have the following 3 files as input data:

File A

abc with-1-rosette-n    2
abc with-1-tyre-n   1
abc with-1-weight-n 2

File B

def with-1-rosette-n 1
def with-1-tyre-n   2
def about-bit-n 1

File C

ghi with-1-rosette-n  2
ghi as+n-produce-v   1
ghi then-damage-v  1

I first wrote a script that sums the values (Col 3) of the features in the intersection of Col 2.

This works fine: it outputs all lines properly.

I then modified the script to take the mean of the Col 3 values of the intersection of Col 2 instead, and this is where I run into trouble.

Basically, the script does not output the lines of the intersection.

Script A

def sumVectors(classA_infile, classB_infile, outfile):

        class_dictA = {}

        with open(classA_infile, "rb") as opened_infile_A:
                for line in opened_infile_A:
                        items = line.split()
                        classA, feat, valuesA = items[:3]
                        class_dictA[feat] = float(valuesA)


        class_dictB = {}

        with open(classB_infile, "rb") as opened_infile_B:
                for line in opened_infile_B:
                        items = line.split()
                        classB, feat, valuesB = items[:3]
                        class_dictB[feat] = float(valuesB)

        with open(outfile, "wb") as output_file:
                for key in class_dictA:
                        if key in class_dictB:
                                weight = (class_dictA[key] + class_dictB[key])/2
                                outstring = "\t".join([classA + "-" +  classB, key, str(weight)])
                                print outstring
                        else:
                                weight = class_dictA[key]
                                outstring = "\t".join([classA + "-" +  classB, key, str(weight)])
                output_file.write(outstring + "\n")

                for key in class_dictB:
                        if key not in class_dictA:
                                weight = class_dictB[key]
                                outstring = "\t".join([classA + "-" + classB, key, str(weight)])
                                output_file.write(outstring + "\n")
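As Script A is shown here, two things keep the intersection lines out of the output file: the if branch only prints outstring instead of writing it, and the output_file.write(outstring + "\n") call after the first loop is indented at the same level as that loop, so it runs once, after the loop finishes, and writes only the last outstring. Below is a minimal Python 3 sketch of the two-file merge with the write moved inside the loop. It assumes the two dictionaries have already been built as in Script A; the helper name merge_two and the in-memory io.StringIO buffer (used instead of a real output file) are illustrative choices, not part of the original script.

```python
import io


def merge_two(label, dict_a, dict_b, out):
    """Write one line per feature: the mean where both dicts have the
    feature, otherwise the single available value (hypothetical helper)."""
    for key in dict_a:
        if key in dict_b:
            weight = (dict_a[key] + dict_b[key]) / 2.0
        else:
            weight = dict_a[key]
        # write inside the loop, so every branch produces a line
        out.write("\t".join([label, key, str(weight)]) + "\n")
    for key in dict_b:
        if key not in dict_a:
            out.write("\t".join([label, key, str(dict_b[key])]) + "\n")


# sample values taken from File A and File B in the question
class_dictA = {"with-1-rosette-n": 2.0, "with-1-tyre-n": 1.0, "with-1-weight-n": 2.0}
class_dictB = {"with-1-rosette-n": 1.0, "with-1-tyre-n": 2.0, "about-bit-n": 1.0}

buf = io.StringIO()  # stands in for the opened output file
merge_two("abc-def", class_dictA, class_dictB, buf)
lines = buf.getvalue().splitlines()
print(buf.getvalue(), end="")
```

Because every branch ends in out.write, each feature produces exactly one line, whether or not it is shared between the two files.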

When I try to incorporate a third file, I run into a KeyError. Here, I am trying to check whether a key in File C is also in Files A and B; if so, I take the average of the values from all three files. The KeyError is raised as soon as the code enters the first if block, and I am having a hard time solving this problem.

Here is the script for 3 files:

Script B

def sumVectors(classA_infile, classB_infile, classC_infile, outfile):

        class_dictA = {}

        with open(classA_infile, "rb") as opened_infile_A:
                for line in opened_infile_A:
                        items = line.split()
                        classA, feat, valuesA = items[:3]
                        class_dictA[feat] = float(valuesA)


        class_dictB = {}

        with open(classB_infile, "rb") as opened_infile_B:
                for line in opened_infile_B:
                        items = line.split()
                        classB, feat, valuesB = items[:3]
                        class_dictB[feat] = float(valuesB)

        class_dictC = {}

        with open(classC_infile, "rb") as opened_infile_C:
                for line in opened_infile_C:
                        items = line.split()
                        classC, feat, valuesC = items[:3]
                        class_dictC[feat] = float(valuesC)

        with open(outfile, "wb") as output_file:
                for key in class_dictC:
                        if key in class_dictA and class_dictB:
                                weight = (class_dictA[key] + class_dictB[key]+ class_dictC[key])/3
                                outstring = "\t".join([classA + "-" +  classB + "-" +  classC, key, str(weight)])
                                print outstring
                        else:
                                weight = class_dictC[key]
                                outstring = "\t".join([classA + "-" +  classB + "-" +  classC,  key, str(weight)])
                                output_file.write(outstring + "\n")
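The KeyError in Script B most likely comes from the condition on its first if line, "key in class_dictA and class_dictB". The in operator binds more tightly than and, so Python reads this as "(key in class_dictA) and class_dictB"; the right-hand operand is just the dictionary itself, which is truthy whenever it is non-empty. A key from File C that appears in File A but not in File B therefore passes the test, and class_dictB[key] then raises. A small Python 3 demonstration with made-up dictionaries:

```python
# made-up sample dictionaries: the key is in A but not in B
class_dictA = {"with-1-rosette-n": 2.0, "with-1-weight-n": 2.0}
class_dictB = {"with-1-rosette-n": 1.0}
key = "with-1-weight-n"

# What the original condition actually evaluates to: the dict itself
buggy = key in class_dictA and class_dictB
print(bool(buggy))  # True, because class_dictB is non-empty

# ...so the lookup inside the if block fails
raised = False
try:
    class_dictB[key]
except KeyError as exc:
    raised = True
    print("KeyError:", exc)

# Test membership in each dictionary explicitly
fixed = key in class_dictA and key in class_dictB
print(fixed)  # False

# Equivalent form that scales to any number of dictionaries
fixed_all = all(key in d for d in (class_dictA, class_dictB))
print(fixed_all)  # False
```

The explicit form, "key in class_dictA and key in class_dictB", is the direct fix for the condition in Script B.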

In the case of Script A, the desired output (where we take the mean of the common elements in Col 2) would be:

abc-def with-1-rosette-n    1.5
abc-def with-1-tyre-n   1.5
abc-def with-1-weight-n 2
def about-bit-n 1

In the case of Script B, the desired output (where we take the mean of the common elements of all 3 files in Col 2) would be:

abc-def-ghi with-1-rosette-n    1.667
abc-def-ghi with-1-tyre-n   1.5
abc-def-ghi with-1-weight-n 2
abc-def-ghi about-bit-n 1
abc-def-ghi as+n-produce-v   1
abc-def-ghi then-damage-v  1

Can anyone help me see where I am going wrong? I am also unsure of the most Pythonic way to solve this. Thanks.


Solution

from collections import defaultdict

# Because you are looking for a union of files, we can treat
#  the input data as a simple concatenation of all input files;
# if you were after the intersection, we would have to deal with
#  each input file separately.
def chain_from_files(*filenames):
    for fname in filenames:
        with open(fname, "r") as inf:    # text mode: lines are str on Python 2 and 3
            for line in inf:
                yield line

# get the key and all related data for each line
def get_item(line):
    row = line.split()
    return row[1], (row[0], int(row[2]))    # <= e.g. ('with-1-rosette-n', ('abc', 2))

# iterate through the input,
# collect a list of related values for each key
def collect_items(lines, get_item):
    result = defaultdict(list)
    for line in lines:
        key, value = get_item(line)
        result[key].append(value)
    return result

# make an output-string for each key
# and its list of related values
def show_item(key, values):
    classes, nums = zip(*values)          # <= unpacks the tuples
    classes = '-'.join(sorted(set(classes)))
    average = float(sum(nums)) / len(nums)
    return "{} {} {}\n".format(classes, key, average)

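def main(classA_infile, classB_infile, classC_infile, outputfile):
    lines = chain_from_files(classA_infile, classB_infile, classC_infile)
    data  = collect_items(lines, get_item)

    with open(outputfile, "w") as outf:
        for key, value in data.items():
            outf.write(show_item(key, value))

if __name__=="__main__":
    # placeholder file names (hypothetical); replace with your own paths
    main("fileA.txt", "fileB.txt", "fileC.txt", "output.txt")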

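which gives the following output (as printed by Python 2; the line order follows dictionary iteration order, and Python 3 prints 1.6666666666666667 instead of 1.66666666667):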

ghi then-damage-v 1.0
abc-def with-1-tyre-n 1.5
abc-def-ghi with-1-rosette-n 1.66666666667
ghi as+n-produce-v 1.0
abc with-1-weight-n 2.0
def about-bit-n 1.0
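The comment at the top of the solution notes that an intersection, rather than a union, would require handling each file separately. A minimal Python 3 sketch of that variant follows, assuming the inputs fit in memory. The helper names parse and common_means are illustrative, not from the answer, and the sample lines are typed in from the question rather than read from disk.

```python
def parse(lines):
    """Map feature -> (class label, value) for one input file."""
    result = {}
    for line in lines:
        label, feat, value = line.split()[:3]
        result[feat] = (label, float(value))
    return result


def common_means(per_file):
    """Average the values of the features present in every file."""
    shared = set.intersection(*(set(d) for d in per_file))
    out = {}
    for feat in sorted(shared):
        labels = "-".join(d[feat][0] for d in per_file)
        mean = sum(d[feat][1] for d in per_file) / len(per_file)
        out[feat] = (labels, mean)
    return out


# sample lines copied from Files A, B and C in the question
file_a = ["abc with-1-rosette-n 2", "abc with-1-tyre-n 1", "abc with-1-weight-n 2"]
file_b = ["def with-1-rosette-n 1", "def with-1-tyre-n 2", "def about-bit-n 1"]
file_c = ["ghi with-1-rosette-n 2", "ghi as+n-produce-v 1", "ghi then-damage-v 1"]

per_file = [parse(f) for f in (file_a, file_b, file_c)]
result = common_means(per_file)
for feat, (labels, mean) in result.items():
    print(labels, feat, round(mean, 3))  # abc-def-ghi with-1-rosette-n 1.667
```

Passing only the first two parsed files, common_means(per_file[:2]), yields with-1-rosette-n and with-1-tyre-n, each with a mean of 1.5.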
Licensed under: CC-BY-SA with attribution