Question

I'm working on some code that is not very well written and involves some fairly complex logic which I wish to refactor. The topic is the validation of rules and reporting potential violations. Unfortunately the class design is rather weird, so I am stuck with some IEnumerable challenges.

As a simplified example, I have the following:

IEnumerable<RuleDefinition>
IEnumerable<Request>

where

public class RuleDefinition
{
    public RequestType ConcerningRequestType { get; set; }
    public int MinimumDistanceBetweenRequests { get; set; }
}

public class Request
{
    public int TimeIndex { get; set; }
    public RequestType TypeOfThisRequest { get; set; }
}

Obviously, the rule is violated when the request type matches and the interval (TimeIndex) between two Requests is too short. Now, I want to extract:

  • Whether there are any rule violations (that is fairly easy)
  • Which rules are violated
  • Which requests violate the rule

So in our case, I would like to obtain something like this:

public class Violation
{
    public RuleDefinition ViolatedRule { get; set; }
    public Request FirstRequest { get; set; }
    public Request SecondRequest { get; set; }
}

I think this is a fairly simple problem, but I have failed to come up with a solution that could be called readable and maintainable. I've tried various things, and it always turns out completely messy. (I just tried to implement this example, and it's terrible.)

Any ideas or patterns to use in this case? (ReSharper often correctly suggests .SelectMany, but that makes things even less readable.)

EDIT: Here's my long and ugly implementation. ;)

var ruleDefinitions = new List<RuleDefinition>
{ 
    new RuleDefinition { 
        ConcerningRequestType = RequestType.Exclusive, 
        MinimumDistanceBetweenRequests = 2} 
};
var requests = new List<Request>()
    {
        new Request { TimeIndex = 1, TypeOfThisRequest = RequestType.Normal },
        new Request { TimeIndex = 1, TypeOfThisRequest = RequestType.Normal },
        new Request { TimeIndex = 2, TypeOfThisRequest = RequestType.Normal },

        new Request { TimeIndex = 3, TypeOfThisRequest = RequestType.Exclusive },
        new Request { TimeIndex = 4, TypeOfThisRequest = RequestType.Exclusive },
    };

var violations = new List<Violation>();
foreach (var rule in ruleDefinitions)
{
    var requestsMatchingType = requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType);
    foreach (var firstRequest in requestsMatchingType)
    {
        var collidingRequest = requests.FirstOrDefault(secondRequest => 
            secondRequest.TimeIndex > firstRequest.TimeIndex &&
            Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < rule.MinimumDistanceBetweenRequests);

        if (collidingRequest != null)
        {
            violations.Add(new Violation
                {
                    ViolatedRule = rule,
                    FirstRequest = firstRequest,
                    SecondRequest = collidingRequest
                });
        }
    }
}
Console.WriteLine("found {0} violations.", violations.Count());
Console.ReadKey();

Solution

It's not a straightforward task, so the first thing I would do is define an interface to see what is needed here:

interface IViolationFinder
{
    IEnumerable<Violation> Search(
        IEnumerable<RuleDefinition> ruleDefinitions, 
        IEnumerable<Request> requests);
}

Now we can clearly see what we need to implement. Because your search logic is quite complex, I don't think you should express it as a single LINQ expression. You can, but you shouldn't. Those two nested foreach loops with LINQ embedded inside are quite nasty, and I don't think the code would be any cleaner written purely in LINQ.
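For comparison, this is roughly what the single-expression version could look like. It is only a sketch that mirrors the question's loop (for each request of the rule's type, take the first later request that is too close, searched across all requests). The RequestType enum and the SingleExpressionFinder name are assumptions, since the question does not show them.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Assumed enum: the question only shows the values Normal and Exclusive.
public enum RequestType { Normal, Exclusive }

public class RuleDefinition
{
    public RequestType ConcerningRequestType { get; set; }
    public int MinimumDistanceBetweenRequests { get; set; }
}

public class Request
{
    public int TimeIndex { get; set; }
    public RequestType TypeOfThisRequest { get; set; }
}

public class Violation
{
    public RuleDefinition ViolatedRule { get; set; }
    public Request FirstRequest { get; set; }
    public Request SecondRequest { get; set; }
}

// Hypothetical name, used only for this illustration.
public static class SingleExpressionFinder
{
    public static List<Violation> Search(
        IEnumerable<RuleDefinition> ruleDefinitions,
        IEnumerable<Request> requests)
    {
        // Materialize once, because the inner query walks the requests repeatedly.
        List<Request> all = requests.ToList();

        return ruleDefinitions
            .SelectMany(rule => all
                .Where(first => first.TypeOfThisRequest == rule.ConcerningRequestType)
                .Select(first => new Violation
                {
                    ViolatedRule = rule,
                    FirstRequest = first,
                    // First later request that is closer than the allowed minimum, if any.
                    SecondRequest = all.FirstOrDefault(second =>
                        second.TimeIndex > first.TimeIndex &&
                        second.TimeIndex - first.TimeIndex < rule.MinimumDistanceBetweenRequests)
                })
                // Drop the candidates for which no colliding request was found.
                .Where(violation => violation.SecondRequest != null))
            .ToList();
    }
}
```

It works, but the filtering, the collision search, and the construction of the result are all tangled into one statement, which is the readability problem described above.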

What you need is more methods inside your implementation. That will increase readability. The naive implementation would be this (it is simply your code):

class ViolationFinder : IViolationFinder
{
    public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
    {
        var violations = new List<Violation>();
        foreach (var rule in ruleDefinitions)
        {
            var requestsMatchingType = requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType);
            foreach (var firstRequest in requestsMatchingType)
            {
                var collidingRequest = requests.FirstOrDefault(secondRequest =>
                    secondRequest.TimeIndex > firstRequest.TimeIndex &&
                    Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < rule.MinimumDistanceBetweenRequests);

                if (collidingRequest != null)
                {
                    violations.Add(new Violation
                    {
                        ViolatedRule = rule,
                        FirstRequest = firstRequest,
                        SecondRequest = collidingRequest
                    });
                }
            }
        }

        return violations;
    }
}

Now you can start refactoring. Instead of thinking in terms of one method, let's extract the most obvious part:

class ViolationFinder : IViolationFinder
{
    public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
    {
        var violations = new List<Violation>();
        foreach (RuleDefinition rule in ruleDefinitions)
        {
            IEnumerable<Request> requestsMatchingType = requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType);
            violations.AddRange(
                FindViolationsInRequests(requestsMatchingType, requests, rule));
        }

        return violations;
    }

    private IEnumerable<Violation> FindViolationsInRequests(
        IEnumerable<Request> matchingRequests,
        IEnumerable<Request> allRequest,
        RuleDefinition rule)
    {
        foreach (Request firstRequest in matchingRequests)
        {
            var collidingRequest = allRequest.FirstOrDefault(secondRequest =>
                secondRequest.TimeIndex > firstRequest.TimeIndex &&
                Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < rule.MinimumDistanceBetweenRequests);

            if (collidingRequest != null)
            {
                yield return new Violation
                {
                    ViolatedRule = rule,
                    FirstRequest = firstRequest,
                    SecondRequest = collidingRequest
                };
            }
        }
    }
}

Search is almost clean, but FindViolationsInRequests already receives every request and the rule, so passing it the filtered request list as well is redundant. It can do the filtering itself:

class ViolationFinder : IViolationFinder
{
    public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
    {
        var violations = new List<Violation>();
        foreach (RuleDefinition rule in ruleDefinitions)
        {
            violations.AddRange(FindViolationsInRequests(requests, rule));
        }

        return violations;
    }

    private IEnumerable<Violation> FindViolationsInRequests(
        IEnumerable<Request> allRequest,
        RuleDefinition rule)
    {
        foreach (Request firstRequest in FindMatchingRequests(allRequest, rule))
        {
            var collidingRequest = allRequest.FirstOrDefault(secondRequest =>
                secondRequest.TimeIndex > firstRequest.TimeIndex &&
                Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < rule.MinimumDistanceBetweenRequests);

            if (collidingRequest != null)
            {
                yield return new Violation
                {
                    ViolatedRule = rule,
                    FirstRequest = firstRequest,
                    SecondRequest = collidingRequest
                };
            }
        }
    }

    private IEnumerable<Request> FindMatchingRequests(IEnumerable<Request> requests, RuleDefinition rule)
    {
        return requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType);
    }
}

Next, the expression

    var collidingRequest = allRequest.FirstOrDefault(secondRequest =>
        secondRequest.TimeIndex > firstRequest.TimeIndex &&
        Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < rule.MinimumDistanceBetweenRequests);

is complex enough to deserve a method of its own:

class ViolationFinder : IViolationFinder
{
    public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
    {
        var violations = new List<Violation>();

        foreach (RuleDefinition rule in ruleDefinitions)
        {
            violations.AddRange(FindViolationsInRequests(requests, rule));
        }

        return violations;
    }

    private IEnumerable<Violation> FindViolationsInRequests(
        IEnumerable<Request> allRequest,
        RuleDefinition rule)
    {
        foreach (Request firstRequest in FindMatchingRequests(allRequest, rule))
        {

            Request collidingRequest = FindCollidingRequest(allRequest, firstRequest, rule.MinimumDistanceBetweenRequests);

            if (collidingRequest != null)
            {
                yield return new Violation
                {
                    ViolatedRule = rule,
                    FirstRequest = firstRequest,
                    SecondRequest = collidingRequest
                };
            }
        }
    }

    private IEnumerable<Request> FindMatchingRequests(IEnumerable<Request> requests, RuleDefinition rule)
    {
        return requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType);
    }

    private Request FindCollidingRequest(IEnumerable<Request> requests, Request firstRequest, int minimumDistanceBetweenRequests)
    {
        return requests.FirstOrDefault(secondRequest => IsCollidingRequest(firstRequest, secondRequest, minimumDistanceBetweenRequests));
    }

    private bool IsCollidingRequest(Request firstRequest, Request secondRequest, int minimumDistanceBetweenRequests)
    {
        return secondRequest.TimeIndex > firstRequest.TimeIndex &&
               Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < minimumDistanceBetweenRequests;
    }
}

OK, it's getting cleaner. I can now tell the purpose of almost every method at a glance. With just a bit more work you end up with something like this:

class ViolationFinder : IViolationFinder
{
    public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
    {
        List<Request> requestList = requests.ToList();
        return ruleDefinitions.SelectMany(rule => FindViolationsInRequests(requestList, rule));
    }

    private IEnumerable<Violation> FindViolationsInRequests(IEnumerable<Request> allRequest, RuleDefinition rule)
    {
        return FindMatchingRequests(allRequest, rule)
                .Select(firstRequest => FindSingleViolation(allRequest, firstRequest, rule))
                .Where(violation => violation != null);
    }

    private Violation FindSingleViolation(IEnumerable<Request> allRequest, Request request, RuleDefinition rule)
    {
        Request collidingRequest = FindCollidingRequest(allRequest, request, rule.MinimumDistanceBetweenRequests);

        if (collidingRequest != null)
        {
            return new Violation
            {
                ViolatedRule = rule,
                FirstRequest = request,
                SecondRequest = collidingRequest
            };
        }

        return null;
    }

    private IEnumerable<Request> FindMatchingRequests(IEnumerable<Request> requests, RuleDefinition rule)
    {
        return requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType);
    }

    private Request FindCollidingRequest(IEnumerable<Request> requests, Request firstRequest, int minimumDistanceBetweenRequests)
    {
        return requests.FirstOrDefault(secondRequest => IsCollidingRequest(firstRequest, secondRequest, minimumDistanceBetweenRequests));
    }

    private bool IsCollidingRequest(Request firstRequest, Request secondRequest, int minimumDistanceBetweenRequests)
    {
        return secondRequest.TimeIndex > firstRequest.TimeIndex &&
               Math.Abs(secondRequest.TimeIndex - firstRequest.TimeIndex) < minimumDistanceBetweenRequests;
    }
}

Please note that the single responsibility principle applies to methods too. Everything except the Search method is part of the private implementation, but as you can see, each processing step got a method with a name. Each method has its single responsibility, so the implementation is much easier to read.

  • Search (the public)
  • FindViolationsInRequests
  • FindSingleViolation
  • FindMatchingRequests
  • FindCollidingRequest
  • IsCollidingRequest

These are the units of this implementation.

The refactoring would be much safer if you wrote unit tests for the original implementation first and only then started to refactor. That way you always know you haven't broken your logic. It is enough to write the tests against the first variation (the one where I put your full code into the Search method), that is, against the interface.
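A minimal sketch of such a test, written without any test framework so that it stays self-contained. The RequestType enum, the ViolationFinderChecks class, and LoopViolationFinder (a condensed copy of the question's loop standing in for the code under test) are assumptions for illustration; with xUnit or NUnit the check would simply become a test method.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Assumed enum: the question only shows the values Normal and Exclusive.
public enum RequestType { Normal, Exclusive }

public class RuleDefinition
{
    public RequestType ConcerningRequestType { get; set; }
    public int MinimumDistanceBetweenRequests { get; set; }
}

public class Request
{
    public int TimeIndex { get; set; }
    public RequestType TypeOfThisRequest { get; set; }
}

public class Violation
{
    public RuleDefinition ViolatedRule { get; set; }
    public Request FirstRequest { get; set; }
    public Request SecondRequest { get; set; }
}

public interface IViolationFinder
{
    IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests);
}

// Stand-in for the code under test: a condensed copy of the question's loop.
public class LoopViolationFinder : IViolationFinder
{
    public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
    {
        var violations = new List<Violation>();
        foreach (var rule in ruleDefinitions)
        {
            foreach (var first in requests.Where(r => r.TypeOfThisRequest == rule.ConcerningRequestType))
            {
                var second = requests.FirstOrDefault(s =>
                    s.TimeIndex > first.TimeIndex &&
                    s.TimeIndex - first.TimeIndex < rule.MinimumDistanceBetweenRequests);
                if (second != null)
                    violations.Add(new Violation { ViolatedRule = rule, FirstRequest = first, SecondRequest = second });
            }
        }
        return violations;
    }
}

// Hypothetical framework-free check that depends only on the interface.
public static class ViolationFinderChecks
{
    // Pins the behavior on the question's sample data: exactly one violation,
    // between the Exclusive requests at TimeIndex 3 and 4.
    public static void SampleDataYieldsOneViolation(IViolationFinder finder)
    {
        var rule = new RuleDefinition
        {
            ConcerningRequestType = RequestType.Exclusive,
            MinimumDistanceBetweenRequests = 2
        };
        var requests = new List<Request>
        {
            new Request { TimeIndex = 1, TypeOfThisRequest = RequestType.Normal },
            new Request { TimeIndex = 1, TypeOfThisRequest = RequestType.Normal },
            new Request { TimeIndex = 2, TypeOfThisRequest = RequestType.Normal },
            new Request { TimeIndex = 3, TypeOfThisRequest = RequestType.Exclusive },
            new Request { TimeIndex = 4, TypeOfThisRequest = RequestType.Exclusive },
        };

        List<Violation> violations = finder.Search(new[] { rule }, requests).ToList();

        Check(violations.Count == 1, "expected exactly one violation");
        Check(ReferenceEquals(violations[0].ViolatedRule, rule), "wrong rule reported");
        Check(violations[0].FirstRequest.TimeIndex == 3, "wrong first request");
        Check(violations[0].SecondRequest.TimeIndex == 4, "wrong second request");
    }

    private static void Check(bool condition, string message)
    {
        if (!condition) throw new InvalidOperationException(message);
    }
}
```

Because the check only talks to IViolationFinder, the same call can be repeated unchanged against each refactored ViolationFinder to confirm the behavior has not moved.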

Another small but important detail is this:

public IEnumerable<Violation> Search(IEnumerable<RuleDefinition> ruleDefinitions, IEnumerable<Request> requests)
{
    List<Request> requestList = requests.ToList();
    return ruleDefinitions.SelectMany(rule => FindViolationsInRequests(requestList, rule));
}

Here I copy the items into a list, so I can be sure the source IEnumerable is enumerated only once. Enumerating it repeatedly can cause problems for certain implementations (think of IQueryable, where each enumeration may run the query again).
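The effect is easy to observe with a lazy sequence that counts how often it is started. This is a small sketch; EnumerationDemo and its members are hypothetical names used only for the demonstration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class EnumerationDemo
{
    // Counts how many times the lazy source has been started.
    public static int SourceRuns;

    // A lazy sequence: the body runs again for every new enumeration.
    public static IEnumerable<int> LazyTimeIndexes()
    {
        SourceRuns++;
        yield return 3;
        yield return 4;
    }

    // Nested search directly on the lazy sequence, as in the un-materialized version.
    public static int RunsWithoutToList()
    {
        SourceRuns = 0;
        IEnumerable<int> source = LazyTimeIndexes();
        foreach (int first in source)
        {
            // Each inner query starts the source again from the beginning.
            _ = source.FirstOrDefault(second => second > first);
        }
        return SourceRuns;
    }

    // Same search, but the source is copied into a list first.
    public static int RunsWithToList()
    {
        SourceRuns = 0;
        List<int> list = LazyTimeIndexes().ToList();
        foreach (int first in list)
        {
            // The list is walked in memory; the source is not touched again.
            _ = list.FirstOrDefault(second => second > first);
        }
        return SourceRuns;
    }
}
```

For this two-element source, RunsWithoutToList() starts the source three times (once for the outer loop and once per inner search), while RunsWithToList() starts it exactly once. If the source were a database query, each of those starts could mean another round trip.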

OTHER TIPS

If you are not opposed to using query expressions, then you can write your implementation as:

var violations = from rule in ruleDefinitions
                 join r1 in requests on rule.ConcerningRequestType equals r1.TypeOfThisRequest
                 join r2 in requests on rule.ConcerningRequestType equals r2.TypeOfThisRequest
                 where r1 != r2 &&
                       r2.TimeIndex > r1.TimeIndex &&
                       Math.Abs(r2.TimeIndex - r1.TimeIndex) < rule.MinimumDistanceBetweenRequests
                 select new Violation() { FirstRequest = r1, SecondRequest = r2, ViolatedRule = rule };
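
Be aware that this query is not a drop-in replacement for the loop in the question: it reports every colliding pair rather than only the first collision per request, and it requires both requests to have the rule's type. The sketch below wraps the query in a method to make that visible. The RequestType enum and the QueryExpressionFinder name are assumptions, and the redundant r1 != r2 and Math.Abs parts are left out because r2.TimeIndex > r1.TimeIndex already implies them.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Assumed enum: the question only shows the values Normal and Exclusive.
public enum RequestType { Normal, Exclusive }

public class RuleDefinition
{
    public RequestType ConcerningRequestType { get; set; }
    public int MinimumDistanceBetweenRequests { get; set; }
}

public class Request
{
    public int TimeIndex { get; set; }
    public RequestType TypeOfThisRequest { get; set; }
}

public class Violation
{
    public RuleDefinition ViolatedRule { get; set; }
    public Request FirstRequest { get; set; }
    public Request SecondRequest { get; set; }
}

// Hypothetical name, used only for this illustration.
public static class QueryExpressionFinder
{
    public static List<Violation> Search(
        IEnumerable<RuleDefinition> ruleDefinitions,
        IEnumerable<Request> requests)
    {
        // Materialize once, because both joins read the requests.
        List<Request> all = requests.ToList();

        var violations =
            from rule in ruleDefinitions
            join r1 in all on rule.ConcerningRequestType equals r1.TypeOfThisRequest
            join r2 in all on rule.ConcerningRequestType equals r2.TypeOfThisRequest
            // r2 strictly later than r1 also rules out pairing a request with itself.
            where r2.TimeIndex > r1.TimeIndex &&
                  r2.TimeIndex - r1.TimeIndex < rule.MinimumDistanceBetweenRequests
            select new Violation { ViolatedRule = rule, FirstRequest = r1, SecondRequest = r2 };

        return violations.ToList();
    }
}
```

With Exclusive requests at TimeIndex 3, 4 and 5 and a minimum distance of 3, this yields three violations, (3, 4), (3, 5) and (4, 5), whereas the loop-based version would report only (3, 4) and (4, 5). Conversely, a Normal request at TimeIndex 4 right after an Exclusive one at 3 is not reported here, while the loop in the question would report it because it searches all requests for the second one. Which behavior is right depends on what the rule is meant to express.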
Licensed under: CC-BY-SA with attribution