代码改变世界

go/parser的使用

2023-01-30 08:21  轩脉刃  阅读(1083)  评论(0)  编辑  收藏  举报

想用 Go 来解析一个 Go 项目,可以使用官方标准库 go/parser。

我们要先解析.gitignore:

// parse .gitignore
// Read the .gitignore of the project being analysed (the same folder whose
// go.mod is parsed next) instead of a path hard-coded to one machine.
// A project without a .gitignore simply has nothing extra to ignore.
ignoreFile := []string{}
ignoreFolder := []string{}
content, err := os.ReadFile(filepath.Join(folder, ".gitignore"))
if err != nil && !os.IsNotExist(err) {
	return err
}
for _, line := range strings.Split(string(content), "\n") {
	// Tolerate CRLF line endings and stray spaces; skip blanks and comments,
	// which would otherwise be treated as ignore patterns.
	line = strings.TrimSpace(line)
	if line == "" || strings.HasPrefix(line, "#") {
		continue
	}
	if strings.HasSuffix(line, "/") {
		ignoreFolder = append(ignoreFolder, line)
	} else {
		ignoreFile = append(ignoreFile, line)
	}
}
// Editor and VCS metadata are skipped even if .gitignore does not list them.
ignoreFolder = append(ignoreFolder, ".idea", ".git")
fmt.Println("ignoreFile len:", len(ignoreFile))
fmt.Println("ignoreFolder len:", len(ignoreFolder))

然后解析go.mod

// parse go.mod
goModFile := filepath.Join(folder, "go.mod")
data, err := os.ReadFile(goModFile)
if err != nil {
	return errors.Wrap(err, "read go.mod file error")
}

f, err := modfile.Parse(goModFile, data, nil)
if err != nil {
	return errors.Wrap(err, "parse go.mod file error")
}
// A go.mod without a module directive parses successfully but leaves
// f.Module nil; dereferencing it would panic.
if f.Module == nil {
	return errors.New("go.mod file has no module directive")
}
// The bare module path is the prefix of every import path in the module
// (Mod.String() would append "@version" whenever a version is set).
modName := f.Module.Mod.Path

最后遍历整个目录,逐个解析其中的所有包(package):

allPkgs := map[string]*ast.Package{}

// isIgnored reports whether path is covered by a .gitignore entry. An entry
// matches the path itself or anything below it, compared on whole path
// components so "build" does not also swallow "buildtools". As in git, an
// entry without a "/" (such as "*.log" or ".git") also matches that name at
// any depth. The walk root itself is never ignored.
isIgnored := func(path string) bool {
	if path == folder {
		return false
	}
	for _, list := range [][]string{ignoreFolder, ignoreFile} {
		for _, ignore := range list {
			pattern := strings.TrimSuffix(ignore, "/")
			root := filepath.Join(folder, pattern)
			if path == root || strings.HasPrefix(path, root+string(filepath.Separator)) {
				return true
			}
			if !strings.Contains(pattern, "/") {
				if ok, _ := filepath.Match(pattern, filepath.Base(path)); ok {
					return true
				}
			}
		}
	}
	return false
}

// visit all files in the folder
err = filepath.Walk(folder, func(path string, info os.FileInfo, err error) error {
	if isIgnored(path) {
		if err == nil && info.IsDir() {
			// Prune the whole subtree (e.g. all of .git) instead of visiting
			// and rejecting every entry inside it one by one.
			return filepath.SkipDir
		}
		return nil
	}
	if err != nil {
		return err
	}
	if !info.IsDir() {
		return nil
	}

	fset := token.NewFileSet()
	pkgs, parseErr := parser.ParseDir(fset, path, nil, parser.ParseComments)
	if parseErr != nil {
		// Reported once by the caller below; not logged here as well.
		return fmt.Errorf("parse dir %s: %w", path, parseErr)
	}

	// ParseDir returns one entry per package name, and an external test
	// package ("xxx_test") shares the directory with the real one. Drop test
	// packages and choose deterministically among the rest rather than let
	// map iteration order decide which one lands in allPkgs.
	names := make([]string, 0, len(pkgs))
	for name := range pkgs {
		if !strings.HasSuffix(name, "_test") {
			names = append(names, name)
		}
	}
	if len(names) == 0 {
		return nil
	}
	sort.Strings(names)

	// Import path = module path + directory relative to the module root,
	// always slash-separated. filepath.Rel also copes with a trailing "/" on
	// folder, which a plain TrimPrefix did not.
	rel, relErr := filepath.Rel(folder, path)
	if relErr != nil {
		return relErr
	}
	pkgPath := modName
	if rel != "." {
		pkgPath += "/" + filepath.ToSlash(rel)
	}

	fmt.Println("parse:" + path)
	allPkgs[pkgPath] = pkgs[names[0]]
	fmt.Println("package:", pkgPath)
	return nil
})

if err != nil {
	fmt.Printf("%+v\n", err)
}

spew.Dump(len(allPkgs))

这个基本就成型了,下面就是要理解一下parser.ParseDir的返回结构。

返回结构解析

ParseDir 返回的是 map[string]*ast.Package,即包名到 ast.Package 的映射,所以核心结构是 ast.Package。

返回结构的UML图如图所示。

这里再描述一下:

Package 包含了多个 File 结构,每个 File 结构最主要的字段是 Decls,即声明列表。每个声明可以是 GenDecl,也可以是 FuncDecl(语法错误处还可能是 BadDecl)。其中 GenDecl 代表 import、常量、类型(如结构体)和变量的声明;而 FuncDecl 代表函数和方法的声明。

GenDecl 最重要的是 Specs 字段,其中的元素分为 ImportSpec、ValueSpec 和 TypeSpec 三种。其中 TypeSpec 最关键:它的 Type 字段说明所声明的类型是 struct、interface、array 还是 map 等。

FuncDecl 用 Recv *FieldList 来区分这个定义是方法(method)还是普通函数(function):Recv 为 nil 时是普通函数,否则是方法。它最主要的字段是 Body,包含整个函数体。

因此,要获取一个文件中所有 struct 的字段和方法名,并输出为 Mermaid 类图(classDiagram),代码如下:

package parsego

import (
	"fmt"
	"github.com/gohade/hade/framework/cobra"
	"github.com/pkg/errors"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
)

// genClassFile receives the --file / -f flag of the genClass subcommand:
// the path of the Go source file to draw.
var genClassFile string

// InitMermaidCmd registers the --file flag on GenClassCommand, attaches
// GenClassCommand under MermaidCommand, and returns MermaidCommand so the
// caller can add it to the root command.
func InitMermaidCmd() *cobra.Command {
	flags := GenClassCommand.Flags()
	flags.StringVarP(&genClassFile, "file", "f", "", "file path")
	// The flag is not marked required (MarkFlagRequired("file") is left
	// disabled); an empty --file then fails GenClassCommand's os.Stat check.
	//_ = GenClassCommand.MarkFlagRequired("file")
	MermaidCommand.AddCommand(GenClassCommand)
	return MermaidCommand
}

// MermaidCommand is the parent "mermaid" command; diagram generators such as
// genClass are attached to it as subcommands by InitMermaidCmd.
var MermaidCommand = &cobra.Command{
	Use:   "mermaid",
	Short: "mermaid",
	// Invoked without a subcommand there is nothing to generate, so print the
	// usage (which lists the subcommands) instead of silently exiting.
	RunE: func(c *cobra.Command, args []string) error {
		return c.Help()
	},
}

// GenClassCommand prints a Mermaid classDiagram for the Go file given by
// --file: one class per struct declared in that file (its fields, then the
// methods declared on it in the same file), followed by one <<interface>>
// class per interface. Classes are printed in declaration order, so the
// output is identical from run to run.
var GenClassCommand = &cobra.Command{
	Use:   "genClass",
	Short: "genClass",
	RunE: func(c *cobra.Command, args []string) error {
		// check file exist
		if _, err := os.Stat(genClassFile); os.IsNotExist(err) {
			return errors.Wrap(err, "file not exist")
		}

		// Read the file once and parse exactly those bytes, so every AST
		// offset indexes the same buffer that member text is sliced from.
		src, err := os.ReadFile(genClassFile)
		if err != nil {
			return errors.Wrap(err, "read file error")
		}
		fset := token.NewFileSet()
		asf, err := parser.ParseFile(fset, genClassFile, src, parser.ParseComments)
		if err != nil {
			return errors.Wrap(err, "parse file error")
		}

		// Index the file's interfaces by name first, so an embedded interface
		// can be expanded even when it is declared further down the file.
		localIfaces := map[string]*ast.InterfaceType{}
		for _, decl := range asf.Decls {
			genDecl, ok := decl.(*ast.GenDecl)
			if !ok || genDecl.Tok != token.TYPE {
				continue
			}
			for _, spec := range genDecl.Specs {
				typeSpec := spec.(*ast.TypeSpec)
				if it, ok := typeSpec.Type.(*ast.InterfaceType); ok {
					localIfaces[typeSpec.Name.Name] = it
				}
			}
		}

		// Declaration order is kept in slices; the maps hold the members.
		var classOrder, interfaceOrder []string
		classFieldMap := map[string][]string{}
		classFuncMap := map[string][]string{}
		interfaceFuncMap := map[string][]string{}

		for _, decl := range asf.Decls {
			switch d := decl.(type) {
			case *ast.GenDecl:
				if d.Tok != token.TYPE {
					continue
				}
				for _, spec := range d.Specs {
					typeSpec := spec.(*ast.TypeSpec)
					name := typeSpec.Name.Name
					switch t := typeSpec.Type.(type) {
					case *ast.StructType:
						classOrder = append(classOrder, name)
						classFieldMap[name] = mermaidStructFields(fset, src, t)
					case *ast.InterfaceType:
						interfaceOrder = append(interfaceOrder, name)
						interfaceFuncMap[name] = mermaidInterfaceMembers(fset, src, t,
							localIfaces, map[*ast.InterfaceType]bool{})
					}
				}
			case *ast.FuncDecl:
				if d.Recv == nil || len(d.Recv.List) == 0 {
					continue // a plain function, not a method
				}
				if recv := mermaidReceiverName(d.Recv.List[0].Type); recv != "" {
					classFuncMap[recv] = append(classFuncMap[recv], d.Name.Name)
				}
			}
		}

		// print class diagrams
		tab := "\t"
		fmt.Println("classDiagram")

		for _, className := range classOrder {
			fmt.Println(tab + "class " + className)
			for _, field := range classFieldMap[className] {
				fmt.Println(tab + className + ": " + field)
			}
			for _, funcName := range classFuncMap[className] {
				fmt.Println(tab + className + ": " + funcName + "()")
			}
			fmt.Println()
		}

		for _, interfaceName := range interfaceOrder {
			fmt.Println(tab + "class " + interfaceName)
			fmt.Println(tab + "<<interface>> " + interfaceName)
			for _, funcName := range interfaceFuncMap[interfaceName] {
				fmt.Println(tab + interfaceName + ": " + funcName)
			}
			fmt.Println()
		}

		return nil
	},
}

// mermaidNodeText returns the exact source text of node, sliced from src,
// the buffer the file was parsed from.
func mermaidNodeText(fset *token.FileSet, src []byte, node ast.Node) string {
	file := fset.File(node.Pos())
	return string(src[file.Offset(node.Pos()):file.Offset(node.End())])
}

// mermaidStructFields renders each field of st as "name Type", one entry per
// name (so "a, b int" yields two). Embedded fields have no name and are
// rendered as their type alone instead of being dropped.
func mermaidStructFields(fset *token.FileSet, src []byte, st *ast.StructType) []string {
	var fields []string
	if st.Fields == nil {
		return fields
	}
	for _, field := range st.Fields.List {
		fieldType := mermaidNodeText(fset, src, field.Type)
		if len(field.Names) == 0 {
			fields = append(fields, fieldType)
			continue
		}
		for _, name := range field.Names {
			fields = append(fields, name.Name+" "+fieldType)
		}
	}
	return fields
}

// mermaidInterfaceMembers lists the methods of it as "Name (params) results".
// Embedded interfaces declared in this file (local) are expanded recursively,
// keeping their method names; visited stops cycles and duplicate expansion.
// Any other embedded element, such as error, io.Reader or a type-set union,
// is listed by its source text rather than dereferencing unresolved objects.
func mermaidInterfaceMembers(fset *token.FileSet, src []byte, it *ast.InterfaceType,
	local map[string]*ast.InterfaceType, visited map[*ast.InterfaceType]bool) []string {
	if it == nil || it.Methods == nil || visited[it] {
		return nil
	}
	visited[it] = true

	var members []string
	for _, m := range it.Methods.List {
		if len(m.Names) > 0 {
			methodType := mermaidNodeText(fset, src, m.Type)
			for _, name := range m.Names {
				members = append(members, name.Name+" "+methodType)
			}
			continue
		}
		if ident, ok := m.Type.(*ast.Ident); ok {
			if embedded, ok := local[ident.Name]; ok {
				members = append(members, mermaidInterfaceMembers(fset, src, embedded, local, visited)...)
				continue
			}
		}
		members = append(members, mermaidNodeText(fset, src, m.Type))
	}
	return members
}

// mermaidReceiverName returns the base type name of a method receiver,
// unwrapping pointers, parentheses and type parameters, so that (T), (*T),
// (*T[K]) and (T[K, V]) all map to "T". It returns "" for anything else.
func mermaidReceiverName(expr ast.Expr) string {
	for {
		switch t := expr.(type) {
		case *ast.Ident:
			return t.Name
		case *ast.StarExpr:
			expr = t.X
		case *ast.ParenExpr:
			expr = t.X
		case *ast.IndexExpr:
			expr = t.X
		case *ast.IndexListExpr:
			expr = t.X
		default:
			return ""
		}
	}
}


参考

https://medium.com/justforfunc/understanding-go-programs-with-go-parser-c4e88a6edb87