feature: 修复下标越界的bug,调整enter ast 代码

This commit is contained in:
pixelMax(奇淼
2024-07-23 18:36:35 +08:00
parent de782e629f
commit 4512aa7ff8
2 changed files with 36 additions and 53 deletions

View File

@@ -2,6 +2,7 @@ package ast
import (
"go/ast"
"go/token"
"io"
)
@@ -31,66 +32,48 @@ func (a *PackageEnter) Parse(filename string, writer io.Writer) (file *ast.File,
}
func (a *PackageEnter) Rollback(file *ast.File) error {
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.GenDecl)
if o1 {
for j := 0; j < len(v1.Specs); j++ {
v2, o2 := v1.Specs[j].(*ast.TypeSpec)
if o2 {
if v2.Name.Name != a.Type.Group() {
continue
}
v3, o3 := v2.Type.(*ast.StructType)
if o3 {
for k := 0; k < len(v3.Fields.List); k++ {
if len(v3.Fields.List[k].Names) >= 1 && v3.Fields.List[k].Names[0].Name == a.StructName {
_ = NewImport(a.ImportPath).Rollback(file)
v3.Fields.List = append(v3.Fields.List[:k], v3.Fields.List[k+1:]...)
}
}
}
}
}
}
}
// 无需回滚
return nil
}
func (a *PackageEnter) Injection(file *ast.File) error {
_ = NewImport(a.ImportPath).Injection(file)
for i := 0; i < len(file.Decls); i++ {
v1, o1 := file.Decls[i].(*ast.GenDecl)
if o1 {
for j := 0; j < len(v1.Specs); j++ {
v2, o2 := v1.Specs[j].(*ast.TypeSpec)
if o2 {
if v2.Name.Name != a.Type.Group() {
continue
}
v3, o3 := v2.Type.(*ast.StructType)
if o3 {
var has bool
for k := 0; k < len(v3.Fields.List); k++ {
if len(v3.Fields.List[k].Names) == 1 && v3.Fields.List[k].Names[0].Name == a.StructName {
has = true
break
}
}
if !has {
field := &ast.Field{
Names: []*ast.Ident{{Name: a.StructName}},
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.PackageStructName},
},
}
v3.Fields.List = append(v3.Fields.List, field)
}
}
ast.Inspect(file, func(n ast.Node) bool {
genDecl, ok := n.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
return true
}
for _, spec := range genDecl.Specs {
typeSpec, specok := spec.(*ast.TypeSpec)
if !specok || typeSpec.Name.Name != a.Type.Group() {
continue
}
structType, structTypeOK := typeSpec.Type.(*ast.StructType)
if !structTypeOK {
continue
}
for _, field := range structType.Fields.List {
if len(field.Names) == 1 && field.Names[0].Name == a.StructName {
return true
}
}
field := &ast.Field{
Names: []*ast.Ident{{Name: a.StructName}},
Type: &ast.SelectorExpr{
X: &ast.Ident{Name: a.PackageName},
Sel: &ast.Ident{Name: a.PackageStructName},
},
}
structType.Fields.List = append(structType.Fields.List, field)
return false
}
}
return true
})
return nil
}