How to access invocations through extension methods, methods in static classes and methods with ref/out parameters with Roslyn

StackOverflow https://stackoverflow.com/questions/23245368

08-07-2023

Question

I'm working on an open-source project that creates .NET UML sequence diagrams using a JavaScript library called js-sequence-diagrams. I am not sure Roslyn is the right tool for the job, but I thought I would give it a shot, so I put together some proof-of-concept code that collects all methods and their invocations and then outputs those invocations in a text form that js-sequence-diagrams can interpret.
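The text js-sequence-diagrams reads has one message per line: "A->B: text" draws a solid arrow and "A-->B: text" a dashed one, which is how the BuildCallInfo method in the code below represents a call and its return. A minimal, stdlib-only sketch of that formatting; SequenceText and Call are hypothetical names, and "\n" is used instead of Environment.NewLine so the output is predictable:

```csharp
using System.Text;

// Hypothetical helper: formats one call/return pair in the plain-text syntax
// read by js-sequence-diagrams ("A->B: msg" = solid arrow, "A-->B: msg" = dashed arrow).
public static class SequenceText
{
    public static string Call(string caller, string callee, string message, string returnType)
    {
        var sb = new StringBuilder();
        // Solid arrow from the calling class to the called class.
        sb.Append(caller).Append("->").Append(callee).Append(": ").Append(message).Append('\n');
        // Dashed arrow back, labelled with the return type.
        sb.Append(callee).Append("-->").Append(caller).Append(": returns ").Append(returnType).Append('\n');
        return sb.ToString();
    }
}
```

For example, SequenceText.Call("Program", "DiagramGenerator", "Main=>ProcessSolution()", "void") yields the two lines "Program->DiagramGenerator: Main=>ProcessSolution()" and "DiagramGenerator-->Program: returns void".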

The code generates some output, but it does not capture everything: I cannot seem to capture invocations of extension methods or invocations of static methods in static classes.

I do see invocations of methods with out parameters, but not in any form that derives from BaseMethodDeclarationSyntax.
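When GetSymbolInfo binds a call written in extension-method form, such as "hi".Shout(), it returns a reduced extension method symbol (MethodKind.ReducedExtension), which is not equal to the symbol GetDeclaredSymbol returns for the declaration; calls to generic methods likewise bind to a constructed method such as Same<int>. A plain Equals comparison therefore never matches these calls. A minimal sketch that maps each call back to its declaration through ReducedFrom and OriginalDefinition, assuming the Microsoft.CodeAnalysis.CSharp NuGet package (3.x or later, for SymbolEqualityComparer); InvocationResolver, Normalize, DescribeCalls and DemoSource are hypothetical names:

```csharp
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

// Sketch only: assumes the Microsoft.CodeAnalysis.CSharp NuGet package (3.x+).
// InvocationResolver and its members are hypothetical names.
public static class InvocationResolver
{
    public const string DemoSource = @"
static class Text { public static string Shout(this string s) { return s + ""!""; } }
static class Parser { public static bool TryRead(string s, out int n) { n = s.Length; return true; } }
static class Util { public static T Same<T>(T x) { return x; } }
class Program
{
    static void Main()
    {
        var loud = ""hi"".Shout();
        Parser.TryRead(""42"", out var n);
        Util.Same(5);
    }
}";

    // Maps the method bound at a call site back to the method as declared.
    public static IMethodSymbol Normalize(IMethodSymbol invoked)
    {
        // "hi".Shout() binds to a reduced extension method; ReducedFrom is the
        // static method declared in the static class (null for ordinary calls).
        IMethodSymbol method = invoked.ReducedFrom ?? invoked;
        // Util.Same(5) binds to the constructed Same<int>; OriginalDefinition is Same<T>.
        return method.OriginalDefinition;
    }

    // Lists "Caller->Type.Method" for every call that resolves to a method declared in source.
    public static List<string> DescribeCalls(string source)
    {
        SyntaxTree tree = CSharpSyntaxTree.ParseText(source);
        CSharpCompilation compilation = CSharpCompilation.Create(
            "Demo",
            new[] { tree },
            new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) },
            new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
        SemanticModel model = compilation.GetSemanticModel(tree);
        SyntaxNode root = tree.GetRoot();

        // Symbols of all method declarations, compared by symbol identity.
        var declared = new HashSet<ISymbol>(
            root.DescendantNodes().OfType<MethodDeclarationSyntax>().Select(m => (ISymbol)model.GetDeclaredSymbol(m)),
            SymbolEqualityComparer.Default);

        var calls = new List<string>();
        foreach (var invocation in root.DescendantNodes().OfType<InvocationExpressionSyntax>())
        {
            // Symbol is null when binding fails, so test the type instead of calling Equals on it.
            if (!(model.GetSymbolInfo(invocation).Symbol is IMethodSymbol invoked))
                continue;

            IMethodSymbol callee = Normalize(invoked);
            if (!declared.Contains(callee))
                continue;

            var caller = invocation.Ancestors().OfType<MethodDeclarationSyntax>().FirstOrDefault();
            calls.Add((caller?.Identifier.ValueText ?? "?") + "->" + callee.ContainingType.Name + "." + callee.Name);
        }
        return calls;
    }
}
```

With the normalization step, the static-class, out-parameter and generic calls in DemoSource all resolve to their declarations; without ReducedFrom and OriginalDefinition, the Shout and Same calls would not be found in the declared set.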

Here is the code (keep in mind that this is proof-of-concept code, so I did not entirely follow best practices, but I am not requesting a code review here. Also, I am used to using Tasks, so I am experimenting with await, but I am not entirely sure I am using it properly yet):

https://gist.github.com/SoundLogic/11193841

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection.Emit;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.MSBuild;
using Microsoft.CodeAnalysis.FindSymbols;
using System.Collections.Immutable;

namespace Diagrams
{
    class Program
    {
        static void Main(string[] args)
        {
            string solutionName = "Diagrams";
            string solutionExtension = ".sln";
            string solutionFileName = solutionName + solutionExtension;
            string rootPath = @"C:\Workspace\";
            string solutionPath = rootPath + solutionName + @"\" + solutionFileName;

            MSBuildWorkspace workspace = MSBuildWorkspace.Create();
            DiagramGenerator diagramGenerator = new DiagramGenerator( solutionPath, workspace );
            diagramGenerator.ProcessSolution();

            #region reference

            //TODO: would ReferencedSymbol.Locations be a better way of accessing MethodDeclarationSyntaxes? 
            //INamedTypeSymbol programClass = compilation.GetTypeByMetadataName("DotNetDiagrams.Program");

            //IMethodSymbol barMethod = programClass.GetMembers("Bar").First(s => s.Kind == SymbolKind.Method) as IMethodSymbol;
            //IMethodSymbol fooMethod = programClass.GetMembers("Foo").First(s => s.Kind == SymbolKind.Method) as IMethodSymbol;

            //ITypeSymbol fooSymbol = fooMethod.ContainingType;
            //ITypeSymbol barSymbol = barMethod.ContainingType;

            //Debug.Assert(barMethod != null);
            //Debug.Assert(fooMethod != null);

            //List<ReferencedSymbol> barReferencedSymbols = SymbolFinder.FindReferencesAsync(barMethod, solution).Result.ToList();
            //List<ReferencedSymbol> fooReferencedSymbols = SymbolFinder.FindReferencesAsync(fooMethod, solution).Result.ToList();

            //Debug.Assert(barReferencedSymbols.First().Locations.Count() == 1);
            //Debug.Assert(fooReferencedSymbols.First().Locations.Count() == 0);

            #endregion

            Console.ReadKey();
        }
    }

    class DiagramGenerator
    {
        private Solution _solution;

        public DiagramGenerator( string solutionPath, MSBuildWorkspace workspace )
        {
            _solution = workspace.OpenSolutionAsync(solutionPath).Result;
        }

        public async void ProcessSolution()
        {
            foreach (Project project in _solution.Projects)
            {
                Compilation compilation = await project.GetCompilationAsync();
                ProcessCompilation(compilation);
            }
        }

        private async void ProcessCompilation(Compilation compilation)
        {
            var trees = compilation.SyntaxTrees;

            foreach (var tree in trees)
            {
                var root = await tree.GetRootAsync();
                var classes = root.DescendantNodes().OfType<ClassDeclarationSyntax>();

                foreach (var @class in classes)
                {
                    ProcessClass( @class, compilation, tree, root );
                }
            }
        }

        private void ProcessClass(
              ClassDeclarationSyntax @class
            , Compilation compilation
            , SyntaxTree tree
            , SyntaxNode root)
        {
            var methods = @class.DescendantNodes().OfType<MethodDeclarationSyntax>();

            foreach (var method in methods)
            {
                var model = compilation.GetSemanticModel(tree);
                // Get MethodSymbol corresponding to method
                var methodSymbol = model.GetDeclaredSymbol(method);
                // Get all InvocationExpressionSyntax in the above code.
                var allInvocations = root.DescendantNodes().OfType<InvocationExpressionSyntax>();
                // Use GetSymbolInfo() to find invocations of target method (Symbol can be null, so use the static Equals)
                var matchingInvocations =
                    allInvocations.Where(i => Equals(model.GetSymbolInfo(i).Symbol, methodSymbol));

                ProcessMethod( matchingInvocations, method, @class);
            }

            var delegates = @class.DescendantNodes().OfType<DelegateDeclarationSyntax>();

            foreach (var @delegate in delegates)
            {
                var model = compilation.GetSemanticModel(tree);
                // Get MethodSymbol corresponding to method
                var methodSymbol = model.GetDeclaredSymbol(@delegate);
                // Get all InvocationExpressionSyntax in the above code.
                var allInvocations = tree.GetRoot().DescendantNodes().OfType<InvocationExpressionSyntax>();
                // Use GetSymbolInfo() to find invocations of target method (Symbol can be null, so use the static Equals)
                var matchingInvocations =
                    allInvocations.Where(i => Equals(model.GetSymbolInfo(i).Symbol, methodSymbol));

                ProcessDelegates(matchingInvocations, @delegate, @class);
            }

        }

        private void ProcessMethod(
              IEnumerable<InvocationExpressionSyntax> matchingInvocations
            , MethodDeclarationSyntax methodDeclarationSyntax
            , ClassDeclarationSyntax classDeclarationSyntax )
        {
            foreach (var invocation in matchingInvocations)
            {
                MethodDeclarationSyntax actingMethodDeclarationSyntax = null;
                if (SyntaxNodeHelper.TryGetParentSyntax(invocation, out actingMethodDeclarationSyntax))
                {
                    var r = methodDeclarationSyntax;
                    var m = actingMethodDeclarationSyntax;

                    PrintCallerInfo(
                        invocation
                        , classDeclarationSyntax
                        , m.Identifier.ToFullString()
                        , r.ReturnType.ToFullString()
                        , r.Identifier.ToFullString()
                        , r.ParameterList.ToFullString()
                        , r.TypeParameterList != null ? r.TypeParameterList.ToFullString() : String.Empty
                        );
                }
            }
        }

        private void ProcessDelegates( 
              IEnumerable<InvocationExpressionSyntax> matchingInvocations
            , DelegateDeclarationSyntax delegateDeclarationSyntax
            , ClassDeclarationSyntax classDeclarationSyntax )
        {
            foreach (var invocation in matchingInvocations)
            {
                DelegateDeclarationSyntax actingMethodDeclarationSyntax = null;

                if (SyntaxNodeHelper.TryGetParentSyntax(invocation, out actingMethodDeclarationSyntax))
                {
                    var r = delegateDeclarationSyntax;
                    var m = actingMethodDeclarationSyntax;

                    PrintCallerInfo(
                        invocation
                        , classDeclarationSyntax
                        , m.Identifier.ToFullString()
                        , r.ReturnType.ToFullString()
                        , r.Identifier.ToFullString()
                        , r.ParameterList.ToFullString()
                        , r.TypeParameterList != null ? r.TypeParameterList.ToFullString() : String.Empty
                    );
                }
            }
        }

        private void PrintCallerInfo(
              InvocationExpressionSyntax invocation
            , ClassDeclarationSyntax classBeingCalled
            , string callingMethodName
            , string returnType
            , string calledMethodName
            , string calledMethodArguments
            , string calledMethodTypeParameters = null )
        {
            ClassDeclarationSyntax parentClassDeclarationSyntax = null;
            if (!SyntaxNodeHelper.TryGetParentSyntax(invocation, out parentClassDeclarationSyntax))
            {
                throw new Exception();
            }

            calledMethodTypeParameters = calledMethodTypeParameters ?? String.Empty;

            var actedUpon = classBeingCalled.Identifier.ValueText;
            var actor = parentClassDeclarationSyntax.Identifier.ValueText;
            var callInfo = callingMethodName + "=>" + calledMethodName + calledMethodTypeParameters + calledMethodArguments;
            var returnCallInfo = returnType;

            string info = BuildCallInfo(
                  actor
                , actedUpon
                , callInfo
                , returnCallInfo);

            Console.Write(info);
        }

        private string BuildCallInfo(string actor, string actedUpon, string callInfo, string returnInfo)
        {
            const string calls = "->";
            const string returns = "-->";
            const string descriptionSeparator = ": ";

            string callingInfo = actor + calls + actedUpon + descriptionSeparator + callInfo;
            string returningInfo = actedUpon + returns + actor + descriptionSeparator + "returns " + returnInfo;

            callingInfo = callingInfo.RemoveNewLines(true);
            returningInfo = returningInfo.RemoveNewLines(true);

            string result = callingInfo + Environment.NewLine;
            result += returningInfo + Environment.NewLine;

            return result;
        }
    }

    static class SyntaxNodeHelper
    {
        public static bool TryGetParentSyntax<T>(SyntaxNode syntaxNode, out T result) 
            where T : SyntaxNode
        {
            // set defaults
            result = null;

            if (syntaxNode == null)
            {
                return false;
            }

            try
            {
                syntaxNode = syntaxNode.Parent;

                if (syntaxNode == null)
                {
                    return false;
                }

                if (syntaxNode.GetType() == typeof (T))
                {
                    result = syntaxNode as T;
                    return true;
                }

                return TryGetParentSyntax<T>(syntaxNode, out result);
            }
            catch
            {
                return false;
            }
        }
    }

    public static class StringEx
    {
        public static string RemoveNewLines(this string stringWithNewLines, bool cleanWhitespace = false)
        {
            string stringWithoutNewLines = null;
            List<char> splitElementList = Environment.NewLine.ToCharArray().ToList();

            if (cleanWhitespace)
            {
                splitElementList.AddRange(" ".ToCharArray().ToList());
            }

            char[] splitElements = splitElementList.ToArray();

            var stringElements = stringWithNewLines.Split(splitElements, StringSplitOptions.RemoveEmptyEntries);
            if (stringElements.Any())
            {
                stringWithoutNewLines = stringElements.Aggregate(stringWithoutNewLines, (current, element) => current + (current == null ? element : " " + element));
            }

            return stringWithoutNewLines ?? stringWithNewLines;
        }
    }
}
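Since the question is unsure about await: a method declared async void cannot be awaited, so its caller gets no signal when the work finishes and cannot catch its exceptions, whereas an async Task method can be awaited or waited on. A minimal stdlib-only sketch of the per-project loop in that style; AsyncPipeline, CompileAsync and ProcessSolutionAsync are hypothetical names, and Task.Delay stands in for Project.GetCompilationAsync():

```csharp
using System.Collections.Generic;
using System.Threading.Tasks;

// Hypothetical stand-in for the per-project loop. Because the method returns Task,
// the caller can await (or block on) completion and observe any exception.
public static class AsyncPipeline
{
    // Simulates Project.GetCompilationAsync() with an asynchronous delay.
    static async Task<string> CompileAsync(string project)
    {
        await Task.Delay(10);
        return project + ".dll";
    }

    public static async Task<List<string>> ProcessSolutionAsync(IEnumerable<string> projects)
    {
        var outputs = new List<string>();
        foreach (var project in projects)
        {
            // Awaiting inside the loop keeps projects in order and lets exceptions propagate.
            outputs.Add(await CompileAsync(project));
        }
        return outputs;
    }
}
```

From a synchronous Main (before C# 7.1 added async Main), AsyncPipeline.ProcessSolutionAsync(projects).GetAwaiter().GetResult() blocks until every project has been processed, so nothing depends on Console.ReadKey() to keep the process alive.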

Any guidance here would be much appreciated!


Solution

Using the methodSymbol from the ProcessClass method, I took Andy's suggestion and came up with the following (although I imagine there may be an easier way to go about this):

private async Task<List<MethodDeclarationSyntax>> GetMethodSymbolReferences( IMethodSymbol methodSymbol )
{
    var references = new List<MethodDeclarationSyntax>();

    // Every symbol in the solution that calls methodSymbol, with the call-site locations inside it
    var referencingSymbols = await SymbolFinder.FindCallersAsync(methodSymbol, _solution);
    var referencingSymbolsList = referencingSymbols as IList<SymbolCallerInfo> ?? referencingSymbols.ToList();

    if (!referencingSymbolsList.Any(s => s.Locations.Any()))
    {
        return references;
    }

    foreach (var referenceSymbol in referencingSymbolsList)
    {
        foreach (var location in referenceSymbol.Locations)
        {
            // Map each call site back to the method declaration that contains it.
            // Constructors and property accessors are not MethodDeclarationSyntax, so calls made from them are skipped here.
            var position = location.SourceSpan.Start;
            var root = await location.SourceTree.GetRootAsync();
            var nodes = root.FindToken(position).Parent.AncestorsAndSelf().OfType<MethodDeclarationSyntax>();

            references.AddRange(nodes);
        }
    }

    return references;
}

The resulting diagram was generated by plugging the output text into js-sequence-diagrams. (I have updated the GitHub gist with the full source, should anyone find it useful. I excluded method parameters so that the diagram is easy to digest, but they can optionally be turned back on.)
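FindCallersAsync searches every document in the solution, whereas the ProcessClass method in the question only enumerates invocations in the same syntax tree that declares the class (and uses that tree's semantic model), so calls made from other files are never examined. A self-contained sketch of FindCallersAsync, assuming the Microsoft.CodeAnalysis.CSharp.Workspaces NuGet package; an AdhocWorkspace with two in-memory documents stands in for the MSBuild solution, and CallerFinder and FindShoutCallersAsync are hypothetical names:

```csharp
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.FindSymbols;
using Microsoft.CodeAnalysis.Text;

// Sketch only: assumes the Microsoft.CodeAnalysis.CSharp.Workspaces NuGet package.
// CallerFinder is a hypothetical name; AdhocWorkspace replaces MSBuildWorkspace here.
public static class CallerFinder
{
    public static async Task<List<string>> FindShoutCallersAsync()
    {
        var workspace = new AdhocWorkspace();
        ProjectInfo projectInfo = ProjectInfo.Create(
                ProjectId.CreateNewId(), VersionStamp.Create(), "Demo", "Demo", LanguageNames.CSharp)
            .WithCompilationOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary))
            .WithMetadataReferences(new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) });
        Project project = workspace.AddProject(projectInfo);

        // The extension method and its callers are in different documents (syntax trees).
        workspace.AddDocument(project.Id, "Text.cs", SourceText.From(
            "static class Text { public static string Shout(this string s) { return s + \"!\"; } }"));
        workspace.AddDocument(project.Id, "Program.cs", SourceText.From(
            "class Program { static void Main() { \"hi\".Shout(); } static void Other() { Text.Shout(\"x\"); } }"));

        Solution solution = workspace.CurrentSolution;
        Compilation compilation = await solution.GetProject(project.Id).GetCompilationAsync();
        IMethodSymbol shout = compilation.GetTypeByMetadataName("Text")
            .GetMembers("Shout").OfType<IMethodSymbol>().Single();

        // Each SymbolCallerInfo names a calling symbol and the call-site Locations inside it.
        IEnumerable<SymbolCallerInfo> callers = await SymbolFinder.FindCallersAsync(shout, solution);
        return callers.Select(c => c.CallingSymbol.Name).OrderBy(name => name).ToList();
    }
}
```

Running it should report both Main and Other, covering the call written in extension-method form and the ordinary static call, even though neither is in the document that declares Text.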

Edit:

I've updated the code (see the GitHub gist) so that calls are now shown in the order they were made, based on the span start location of each called method within the calling method, as reported by FindCallersAsync:

[Image: the updated sequence diagram rendered by js-sequence-diagrams]
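Ordering by span start amounts to sorting each call site by its position within the calling method. A minimal stdlib-only sketch of that step; CallSite and CallOrdering are hypothetical names, and in the Roslyn version the position would come from location.SourceSpan.Start for each entry in SymbolCallerInfo.Locations:

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical record of one call site found via FindCallersAsync.
public sealed class CallSite
{
    public CallSite(string caller, string callee, int spanStart)
    {
        Caller = caller;
        Callee = callee;
        SpanStart = spanStart;
    }

    public string Caller { get; }
    public string Callee { get; }
    // Character offset of the call within its document.
    public int SpanStart { get; }
}

public static class CallOrdering
{
    // Returns the callees of one caller in source order, regardless of the
    // order in which the call sites were discovered.
    public static List<string> CalleesInSourceOrder(IEnumerable<CallSite> sites, string caller)
    {
        return sites
            .Where(s => s.Caller == caller)
            .OrderBy(s => s.SpanStart)
            .Select(s => s.Callee)
            .ToList();
    }
}
```

SpanStart is an offset within a single document, so this gives textual order inside one method body; it approximates, but does not guarantee, the order in which calls execute at runtime (branches, loops and callbacks are not modeled).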

Licensed under: CC-BY-SA with attribution
Not affiliated with StackOverflow