Golang ast 的使用
从文件中获取注释信息
// Command commentdump parses a Go source file and prints the first line of
// the doc comment attached to every function or method declaration in it.
//
// Usage:
//
//	go run . [-ast] [file.go]
//
// The file defaults to ./demo.go. With -ast, the AST of each function
// declaration is dumped as well.
package main

import (
	"flag"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io"
	"log"
	"os"
	"path/filepath"
)

// Visitor implements ast.Visitor and reports the doc comment of each
// *ast.FuncDecl it encounters.
type Visitor struct {
	fset    *token.FileSet
	out     io.Writer // destination for printed output; nil means os.Stdout
	dumpAST bool      // also dump the AST of each function declaration
}

// Visit prints the first line of the doc comment attached to a function
// declaration. Declarations without a doc comment are skipped instead of
// dereferencing a nil Doc. Returning v keeps ast.Walk descending into the
// node's children.
func (v *Visitor) Visit(node ast.Node) ast.Visitor {
	fn, ok := node.(*ast.FuncDecl)
	if !ok {
		return v
	}

	out := v.out
	if out == nil {
		out = os.Stdout
	}

	if fn.Doc != nil && len(fn.Doc.List) > 0 {
		fmt.Fprintln(out, fn.Doc.List[0].Text)
	}

	if v.dumpAST {
		// Same filter ast.Print uses: omit nil fields from the dump.
		if err := ast.Fprint(out, v.fset, fn, ast.NotNilFilter); err != nil {
			log.Printf("dump AST of %s: %v", fn.Name.Name, err)
		}
	}
	return v
}

func main() {
	dumpAST := flag.Bool("ast", false, "also print the AST of each function declaration")
	flag.Parse()

	target := "./demo.go"
	if flag.NArg() > 0 {
		target = flag.Arg(0)
	}

	path, err := filepath.Abs(target)
	if err != nil {
		log.Fatalf("resolve %q: %v", target, err)
	}

	fset := token.NewFileSet()
	// ParseComments is required: without it the parser drops comments and
	// every FuncDecl.Doc is nil.
	f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
	if err != nil {
		log.Fatal(err)
	}

	ast.Walk(&Visitor{fset: fset, dumpAST: *dumpAST}, f)
}
demo.go（被分析的文件）
// file comment: this group is the file's package doc (attached to the package
// clause), not to any function declaration.
package main

// Article is the receiver type for the two annotated sample methods below.
type Article struct{}

// @GET("/get")
func (a Article) Get() {
	// Empty on purpose: only the route annotation above, which is this
	// method's doc comment, matters to the comment-extraction example.
}

// @POST("/save")
func (a Article) Save() {
	// Empty on purpose, like Get.
}
输出（注释部分）：

// @GET("/get")
// @POST("/save")
向文件中添加代码（在 GetModels 返回的切片字面量中追加元素）
// Command addmodel rewrites ./models/demo.go so that the composite literal
// returned by GetModels also contains &Test{}, then prints the size and the
// new source.
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"log"
	"os"
	"path/filepath"
)

// Visitor implements ast.Visitor. It finds the function named funcName and
// appends &typeName{} to the first composite literal that one of its
// top-level return statements returns.
type Visitor struct {
	fset     *token.FileSet
	funcName string // function whose returned composite literal is extended
	typeName string // type added to that literal as &typeName{}
	found    bool   // set once the target literal has been located
}

// Visit edits the target function's returned literal in place. The element is
// only appended if it is not already present, so running the tool twice does
// not duplicate it. Functions without a body, with an empty body, or with a
// bare return are skipped instead of panicking on an index out of range.
func (v *Visitor) Visit(node ast.Node) ast.Visitor {
	fn, ok := node.(*ast.FuncDecl)
	if !ok || fn.Name.Name != v.funcName || fn.Body == nil {
		return v
	}

	for _, stmt := range fn.Body.List {
		ret, ok := stmt.(*ast.ReturnStmt)
		if !ok || len(ret.Results) == 0 {
			continue
		}
		lit, ok := ret.Results[0].(*ast.CompositeLit)
		if !ok {
			continue
		}
		if !containsAddrOf(lit, v.typeName) {
			lit.Elts = append(lit.Elts, &ast.UnaryExpr{
				Op: token.AND,
				X:  &ast.CompositeLit{Type: ast.NewIdent(v.typeName)},
			})
		}
		v.found = true
		// Target handled; no need to walk the rest of this function.
		return nil
	}
	return v
}

// containsAddrOf reports whether lit already has an element of the form
// &name{...}.
func containsAddrOf(lit *ast.CompositeLit, name string) bool {
	for _, elt := range lit.Elts {
		u, ok := elt.(*ast.UnaryExpr)
		if !ok || u.Op != token.AND {
			continue
		}
		c, ok := u.X.(*ast.CompositeLit)
		if !ok {
			continue
		}
		if id, ok := c.Type.(*ast.Ident); ok && id.Name == name {
			return true
		}
	}
	return false
}

func main() {
	path, err := filepath.Abs("./models/demo.go")
	if err != nil {
		log.Fatal(err)
	}

	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
	if err != nil {
		log.Fatal(err)
	}

	v := &Visitor{fset: fset, funcName: "GetModels", typeName: "Test"}
	ast.Walk(v, f)
	if !v.found {
		log.Fatalf("%s: no returned composite literal found in %s", path, v.funcName)
	}

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, f); err != nil {
		log.Fatal(err)
	}

	// Write the Go source back. os.WriteFile truncates before writing; the
	// previous O_RDWR open did not, so a result shorter than the old file
	// left stale bytes at the end. The perm only applies if the file is
	// created.
	if err := os.WriteFile(path, buf.Bytes(), 0o644); err != nil {
		log.Fatalf("write err %s", err.Error())
	}
	fmt.Println(buf.Len())
	fmt.Println(buf.String())
}
models/demo.go（被修改的文件）
/* * @Description: * @Author: gphper * @Date: 2021-07-08 20:12:04 */ package models import ( "context" "github.com/gphper/ginadmin/configs" "github.com/gphper/ginadmin/pkg/loggers" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) var mapDB map[string]*gorm.DB type GaTabler interface { schema.Tabler FillData(*gorm.DB) GetConnName() string } type BaseModle struct { ConnName string `gorm:"-" json:"-"` } func (b *BaseModle) TableName() string { return "" } func (b *BaseModle) FillData(db *gorm.DB) {} func (b *BaseModle) GetConnName() string { return b.ConnName } // 获取链接 func GetDB(model GaTabler) *gorm.DB { db, ok := mapDB[model.GetConnName()] if !ok { errMsg := fmt.Sprintf("connection name%s no exists", model.GetConnName()) loggers.LogError(context.Background(), "get_db_error", "GetDB", map[string]string{ "msg": errMsg, }) } return db } func GetModels() []interface{} { return []interface{}{ &AdminUsers{}, &Article{}, &UploadType{}, &User{} } }
最终结果：GetModels 返回的切片中追加了 &Test{}。