سؤال

I am trying to implement a KD-tree for use with DBSCAN. The problem is that I need to find all the neighbours of all points that meet a distance criteria. The problem is I don't get the same output when using the naive search (which is the desired output) as when I use the nearestNeighbours method in my implementation. My implementation is adapted from a python implementation. Here's what I've got so far:

//Point.java
package dbscan_gui;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;

public class Point {

    final HashSet<Point> neighbours = new HashSet<Point>();
    int[] points;
    boolean visited = false;
    public Point(int... is) {
        this.points = is;
    }
    public String toString() {
        return Arrays.toString(points);
    }

    public double squareDistance(Point p) {
        double sum = 0;
        for (int i = 0;i < points.length;i++) {
            sum += Math.pow(points[i] - p.points[i],2);
        }
        return sum;
    }
    public double distance(Point p) {
        return Math.sqrt(squareDistance(p));
    }
    public void addNeighbours(ArrayList<Point> ps) {
        neighbours.addAll(ps);
    }
    public void addNeighbour(Point p) {
        if (p != this)
            neighbours.add(p);
    }
}

//KDTree.java
package dbscan_gui;


import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;


public class KDTree {
    KDTreeNode root;
    PointComparator[] comps;
    public KDTree(ArrayList<Point> list) {
        int axes = list.get(0).points.length;

        comps = new PointComparator[axes];
        for (int i = 0; i < axes; i++) {
            comps[i] = new PointComparator(i);
        }
        root = new KDTreeNode(list,0);
    }
    private class PointComparator implements Comparator<Point> {
        private int axis;
        public PointComparator(int axis) {
            this.axis = axis;
        }
        @Override
        public int compare(Point p1, Point p2) {
            return p1.points[axis] - p2.points[axis];
        } 
    }

    /**
     * Adapted from https://code.google.com/p/python-kdtree/
     * Stores points in a tree, sorted by axis
     */
    public class KDTreeNode {
        KDTreeNode leftChild = null;
        KDTreeNode rightChild = null;
        Point location;

        public KDTreeNode(ArrayList<Point> list, int depth) {
            if(list.isEmpty())
                return;
            final int axis = depth % (list.get(0).points.length);

            Collections.sort(list, comps[axis] );
            int median = list.size()/2;
            location = list.get(median);
            List<Point> leftPoints = list.subList(0, median);
            List<Point> rightPoints = list.subList(median+1, list.size());
            if(!leftPoints.isEmpty())
                leftChild  = new KDTreeNode(new ArrayList<Point>(leftPoints), depth+1);
            if(!rightPoints.isEmpty())
                rightChild = new KDTreeNode(new ArrayList<Point>(rightPoints),depth+1);
        }

        /**
         * @return true if this node has no children
         */
        public boolean isLeaf() {
            return leftChild == null && rightChild == null;
        }

    }
    /**
     * Finds the nearest neighbours of a point that fall within a given distance
     * @param queryPoint the point to find the neighbours of
     * @param epsilon the distance threshold
     * @return the list of points
     */
    public ArrayList<Point> nearestNeighbours(Point queryPoint, int epsilon) {
        KDNeighbours neighbours = new KDNeighbours(queryPoint);
        nearestNeighbours_(root, queryPoint, 0, neighbours);
        return neighbours.getBest(epsilon);
    }
    /**
     * @param node
     * @param queryPoint
     * @param depth
     * @param bestNeighbours
     */
    private void nearestNeighbours_(KDTreeNode node, Point queryPoint, int depth, KDNeighbours bestNeighbours) {
        if(node == null)
            return;
        if(node.isLeaf()) {
            bestNeighbours.add(node.location);
            return;
        }
        int axis = depth % (queryPoint.points.length);
        KDTreeNode nearSubtree = node.rightChild;
        KDTreeNode farSubtree  = node.leftChild;
        if(queryPoint.points[axis] < node.location.points[axis]) {
            nearSubtree = node.leftChild;
            farSubtree = node.rightChild;
        }
        nearestNeighbours_(nearSubtree, queryPoint,  depth+1, bestNeighbours);
        if(node.location != queryPoint)
            bestNeighbours.add(node.location);       
        if(Math.pow(node.location.points[axis] - queryPoint.points[axis],2) <= bestNeighbours.largestDistance)
            nearestNeighbours_(farSubtree, queryPoint, depth+1,bestNeighbours);
        return;
    }
    /**
     * Private datastructure for holding the neighbours of a point
     */
    private class KDNeighbours {
        Point queryPoint;
        double largetsDistance = 0;
        TreeSet<Tuple> currentBest = new TreeSet<Tuple>(new Comparator<Tuple>() {

            @Override
            public int compare(Tuple o1, Tuple o2) {
                return (int) (o1.y-o2.y);
            }

        });
        KDNeighbours(Point queryPoint) {
            this.queryPoint = queryPoint;
        }
        public ArrayList<Point> getBest(int epsilon) {
            ArrayList<Point> best = new ArrayList<Point>();
            Iterator<Tuple> it = currentBest.iterator();
            while(it.hasNext()) {
                Tuple t =it.next();
                if(t.y > epsilon*epsilon)
                    break;
                else if(t.x != queryPoint)
                    best.add(t.x);
            }
            return best;
        }

        public void add(Point p) {
            currentBest.add(new Tuple(p, p.squareDistance(queryPoint)));
            largestDistance = currentBest.last().y;
        }
        private class Tuple  {
            Point x;
            double y;
            Tuple(Point x, double y) {
                this.x = x;
                this.y = y;
            }
        }
    }

    public static void main(String[] args) {
        int epsilon = 3;

        System.out.println("Epsilon: "+epsilon);
        ArrayList<Point> points = new ArrayList<Point>();
        Random r = new Random();
        for (int i = 0; i < 10; i++) {
            points.add(new Point(r.nextInt(10), r.nextInt(10)));
        }
        System.out.println("Points "+points );
        System.out.println("----------------");
        System.out.println("Neighbouring Kd");
        KDTree tree = new KDTree(points);
        for (Point p : points) {
            ArrayList<Point> neighbours = tree.nearestNeighbours(p, epsilon);
            for (Point q : neighbours) {
                q.addNeighbour(p);
            }
            p.addNeighbours(neighbours);
            p.printNeighbours();
            p.neighbours.clear();
        }
        System.out.println("------------------");
        System.out.println("Neighbouring O(n^2)");
        for (int i = 0; i < points.size(); i++) {
            for (int j = i + 1; j < points.size(); j++) {
                Point p = points.get(i), q = points.get(j);
                if (p.distance(q) <= epsilon) {
                    p.addNeighbour(q);
                    q.addNeighbour(p);
                }
            }
        }
        for (Point point : points) {
            point.printNeighbours();
        }

    }
}

When I run this I get the following output (the latter part being the model output):

Epsilon: 3
Points [[9, 5], [4, 7], [3, 1], [0, 0], [5, 7], [0, 1], [5, 5], [1, 2], [9, 2], [9, 9]]
----------------
Neighbouring Kd
Neighbours of [0, 0] are: [[0, 1]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 7]]
Neighbours of [5, 7] are: [[4, 7]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []
------------------
Neighbouring O(n^2)
Neighbours of [0, 0] are: [[0, 1], [1, 2]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [0, 0], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 5], [5, 7]]
Neighbours of [5, 7] are: [[4, 7], [5, 5]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []

I can't figure out why the neighbours aren't the same, it seems that it can find that a->b is a neighbouring, but not that b->a is also a neighbouring.

هل كانت مفيدة؟

المحلول

You may want to use ELKI which includes DBSCAN and index structures such as the R*-tree for nearest neighbors search. When parameterized right, it's really really fast. I saw in the trac that the next version will also have a KD-tree.

From a quick look at your code, I have to agree with @ThomasJungblut - you do not backtrack and then try the other branch as necessary, which is why you miss a lot of neighbors. You may need to look at both branches!

مرخصة بموجب: CC-BY-SA مع الإسناد
لا تنتمي إلى StackOverflow
scroll top