diff --git a/changelog.md b/changelog.md index a3c7be6d01..9449fa8c62 100644 --- a/changelog.md +++ b/changelog.md @@ -2,6 +2,10 @@ ## Unreleased +### Features + +- [#4108](https://github.com/ignite/cli/pull/4108) Add `xast` package (cherry-picked from [#3770](https://github.com/ignite/cli/pull/3770)) + ### Changes - [#3959](https://github.com/ignite/cli/pull/3959) Remove app name prefix from the `.gitignore` file diff --git a/ignite/pkg/goanalysis/goanalysis.go b/ignite/pkg/goanalysis/goanalysis.go index 36f16a472f..701ae3d67b 100644 --- a/ignite/pkg/goanalysis/goanalysis.go +++ b/ignite/pkg/goanalysis/goanalysis.go @@ -45,7 +45,6 @@ func DiscoverMain(path string) (pkgPaths []string, err error) { return nil }) - if err != nil { return nil, err } @@ -338,3 +337,59 @@ func ReplaceCode(pkgPath, oldFunctionName, newFunction string) (err error) { } return nil } + +// HasAnyStructFieldsInPkg finds the struct within a package folder and checks +// if any of the fields are defined in the struct. +func HasAnyStructFieldsInPkg(pkgPath, structName string, fields []string) (bool, error) { + absPath, err := filepath.Abs(pkgPath) + if err != nil { + return false, err + } + fileSet := token.NewFileSet() + all, err := parser.ParseDir(fileSet, absPath, nil, parser.ParseComments) + if err != nil { + return false, err + } + + fieldsNames := make(map[string]struct{}) + for _, field := range fields { + fieldsNames[strings.ToLower(field)] = struct{}{} + } + + exist := false + for _, pkg := range all { + for _, f := range pkg.Files { + ast.Inspect(f, func(x ast.Node) bool { + typeSpec, ok := x.(*ast.TypeSpec) + if !ok { + return true + } + + if _, ok := typeSpec.Type.(*ast.StructType); !ok || + typeSpec.Name.Name != structName || + typeSpec.Type == nil { + return true + } + + // Check if the struct has fields. + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + return true + } + + // Iterate through the fields of the struct. + for _, field := range structType.Fields.List { + for _, fieldName := range field.Names { + if _, ok := fieldsNames[strings.ToLower(fieldName.Name)]; !ok { + continue + } + exist = true + return false + } + } + return true + }) + } + } + return exist, nil +} diff --git a/ignite/pkg/goanalysis/goanalysis_test.go b/ignite/pkg/goanalysis/goanalysis_test.go index 4d1297e3b0..8b7f5b2768 100644 --- a/ignite/pkg/goanalysis/goanalysis_test.go +++ b/ignite/pkg/goanalysis/goanalysis_test.go @@ -145,84 +145,84 @@ func createMainFiles(tmpDir string, mainFiles []string) (pathsWithMain []string, func TestFuncVarExists(t *testing.T) { tests := []struct { name string - testfile string + testFile string goImport string methodSignature string want bool }{ { name: "test a declaration inside a method success", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "Background", goImport: "context", want: true, }, { name: "test global declaration success", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "Join", goImport: "path/filepath", want: true, }, { name: "test a declaration inside an if and inside a method success", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "SplitList", goImport: "path/filepath", want: true, }, { name: "test global variable success assign", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "New", goImport: "errors", want: true, }, { name: "test invalid import", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "Join", goImport: "errors", want: false, }, { name: "test invalid case sensitive assign", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "join", goImport: "context", want: false, }, { name: "test invalid struct assign", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "fooStruct", goImport: "context", want: false, }, { name: "test invalid method signature", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "fooMethod", goImport: "context", want: false, }, { name: "test not found name", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "Invalid", goImport: "context", want: false, }, { name: "test invalid assign with wrong", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "invalid.New", goImport: "context", want: false, }, { name: "test invalid assign with wrong", - testfile: "testdata/varexist", + testFile: "testdata/varexist", methodSignature: "SplitList", goImport: "path/filepath", want: true, @@ -230,7 +230,7 @@ func TestFuncVarExists(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - appPkg, _, err := xast.ParseFile(tt.testfile) + appPkg, _, err := xast.ParseFile(tt.testFile) require.NoError(t, err) got := goanalysis.FuncVarExists(appPkg, tt.goImport, tt.methodSignature) @@ -577,3 +577,95 @@ func NewMethod1() { }) } } + +func TestHasStructFieldsInPkg(t *testing.T) { + tests := []struct { + name string + path string + structName string + fields []string + err error + want bool + }{ + { + name: "test a value with an empty struct", + path: "testdata", + structName: "emptyStruct", + fields: []string{"name"}, + want: false, + }, + { + name: "test no value with an empty struct", + path: "testdata", + structName: "emptyStruct", + fields: []string{""}, + want: false, + }, + { + name: "test a valid field into single field struct", + path: "testdata", + structName: "fooStruct", + fields: []string{"name"}, + want: true, + }, + { + name: "test a not valid field into single field struct", + path: "testdata", + structName: "fooStruct", + fields: []string{"baz"}, + want: false, + }, + { + name: "test a not valid field into struct", + path: "testdata", + structName: "bazStruct", + fields: []string{"baz"}, + want: false, + }, + { + name: "test a valid field into struct", + path: "testdata", + structName: "bazStruct", + fields: []string{"name"}, + want: true, + }, + { + name: "test two valid fields into struct", + path: "testdata", + structName: "bazStruct", + fields: []string{"name", "title"}, + want: true, + }, + { + name: "test a valid and a not valid fields into struct", + path: "testdata", + structName: "bazStruct", + fields: []string{"foo", "title"}, + want: true, + }, + { + name: "test three not valid fields into struct", + path: "testdata", + structName: "bazStruct", + fields: []string{"foo", "baz", "bla"}, + want: false, + }, + { + name: "invalid path", + path: "invalid_path", + err: os.ErrNotExist, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := goanalysis.HasAnyStructFieldsInPkg(tt.path, tt.structName, tt.fields) + if tt.err != nil { + require.Error(t, err) + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/ignite/pkg/goanalysis/testdata/fieldexist.go b/ignite/pkg/goanalysis/testdata/fieldexist.go new file mode 100644 index 0000000000..34d9e6c6bc --- /dev/null +++ b/ignite/pkg/goanalysis/testdata/fieldexist.go @@ -0,0 +1,13 @@ +package goanalysis + +type ( + emptyStruct struct{} + fooStruct struct { + name string + } + bazStruct struct { + name string + title string + description string + } +) diff --git a/ignite/pkg/xast/function.go b/ignite/pkg/xast/function.go new file mode 100644 index 0000000000..3469a32f68 --- /dev/null +++ b/ignite/pkg/xast/function.go @@ -0,0 +1,412 @@ +package xast + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "strings" + + "github.com/ignite/cli/v28/ignite/pkg/errors" +) + +type ( + // functionOpts represent the options for functions. + functionOpts struct { + newParams []param + body string + newLines []line + insideCall []call + insideStruct []str + appendCode []string + returnVars []string + } + + // FunctionOptions configures code generation. + FunctionOptions func(*functionOpts) + + str struct { + structName string + paramName string + code string + index int + } + call struct { + name string + code string + index int + } + param struct { + name string + varType string + index int + } + line struct { + code string + number uint64 + } +) + +// AppendFuncParams add a new param value. +func AppendFuncParams(name, varType string, index int) FunctionOptions { + return func(c *functionOpts) { + c.newParams = append(c.newParams, param{ + name: name, + varType: varType, + index: index, + }) + } +} + +// ReplaceFuncBody replace all body of the function, the method will replace first and apply the other options after. +func ReplaceFuncBody(body string) FunctionOptions { + return func(c *functionOpts) { + c.body = body + } +} + +// AppendFuncCode append code before the end or the return, if exists, of a function in Go source code content. +func AppendFuncCode(code string) FunctionOptions { + return func(c *functionOpts) { + c.appendCode = append(c.appendCode, code) + } +} + +// AppendFuncAtLine append a new code at line. +func AppendFuncAtLine(code string, lineNumber uint64) FunctionOptions { + return func(c *functionOpts) { + c.newLines = append(c.newLines, line{ + code: code, + number: lineNumber, + }) + } +} + +// AppendInsideFuncCall add code inside another function call. For instances, the method have a parameter a +// call 'New(param1, param2)' and we want to add the param3 the result will be 'New(param1, param2, param3)'. +func AppendInsideFuncCall(callName, code string, index int) FunctionOptions { + return func(c *functionOpts) { + c.insideCall = append(c.insideCall, call{ + name: callName, + code: code, + index: index, + }) + } +} + +// AppendInsideFuncStruct add code inside another function call. For instances, +// the struct have only one parameter 'Params{Param1: param1}' and we want to add +// the param2 the result will be 'Params{Param1: param1, Param2: param2}'. +func AppendInsideFuncStruct(structName, paramName, code string, index int) FunctionOptions { + return func(c *functionOpts) { + c.insideStruct = append(c.insideStruct, str{ + structName: structName, + paramName: paramName, + code: code, + index: index, + }) + } +} + +// NewFuncReturn replaces return statements in a Go function with a new return statement. +func NewFuncReturn(returnVars ...string) FunctionOptions { + return func(c *functionOpts) { + c.returnVars = append(c.returnVars, returnVars...) + } +} + +func newFunctionOptions() functionOpts { + return functionOpts{ + newParams: make([]param, 0), + body: "", + newLines: make([]line, 0), + insideCall: make([]call, 0), + insideStruct: make([]str, 0), + appendCode: make([]string, 0), + returnVars: make([]string, 0), + } +} + +// ModifyFunction modify a function based in the options. +func ModifyFunction(fileContent, functionName string, functions ...FunctionOptions) (modifiedContent string, err error) { + // Apply function options. + opts := newFunctionOptions() + for _, o := range functions { + o(&opts) + } + + fileSet := token.NewFileSet() + + // Parse the Go source code content. + f, err := parser.ParseFile(fileSet, "", fileContent, parser.ParseComments) + if err != nil { + return "", err + } + + // Parse the content of the new function into an ast. + var newFunctionBody *ast.BlockStmt + if opts.body != "" { + newFuncContent := fmt.Sprintf("package p; func _() { %s }", strings.TrimSpace(opts.body)) + newContent, err := parser.ParseFile(fileSet, "", newFuncContent, parser.ParseComments) + if err != nil { + return "", err + } + newFunctionBody = newContent.Decls[0].(*ast.FuncDecl).Body + } + + // Parse the content of the append code an ast. + appendCode := make([]ast.Stmt, 0) + for _, codeToInsert := range opts.appendCode { + newFuncContent := fmt.Sprintf("package p; func _() { %s }", strings.TrimSpace(codeToInsert)) + newContent, err := parser.ParseFile(fileSet, "", newFuncContent, parser.ParseComments) + if err != nil { + return "", err + } + appendCode = append(appendCode, newContent.Decls[0].(*ast.FuncDecl).Body.List...) + } + + // Parse the content of the return vars into an ast. + returnStmts := make([]ast.Expr, 0) + for _, returnVar := range opts.returnVars { + // Parse the new return var to expression. + newRetExpr, err := parser.ParseExprFrom(fileSet, "", []byte(returnVar), parser.ParseComments) + if err != nil { + return "", err + } + returnStmts = append(returnStmts, newRetExpr) + } + + callMap := make(map[string][]call) + callMapCheck := make(map[string][]call) + for _, c := range opts.insideCall { + calls, ok := callMap[c.name] + if !ok { + calls = []call{} + } + callMap[c.name] = append(calls, c) + callMapCheck[c.name] = append(calls, c) + } + + structMap := make(map[string][]str) + structMapCheck := make(map[string][]str) + for _, s := range opts.insideStruct { + structs, ok := structMap[s.structName] + if !ok { + structs = []str{} + } + structMap[s.structName] = append(structs, s) + structMapCheck[s.structName] = append(structs, s) + } + + // Parse the Go code to insert. + var ( + found bool + errInspect error + ) + ast.Inspect(f, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok || funcDecl.Name.Name != functionName { + return true + } + + for _, p := range opts.newParams { + fieldParam := &ast.Field{ + Names: []*ast.Ident{ast.NewIdent(p.name)}, + Type: ast.NewIdent(p.varType), + } + switch { + case p.index == -1: + // Append the new argument to the end + funcDecl.Type.Params.List = append(funcDecl.Type.Params.List, fieldParam) + case p.index >= 0 && p.index <= len(funcDecl.Type.Params.List): + // Insert the new argument at the specified index + funcDecl.Type.Params.List = append( + funcDecl.Type.Params.List[:p.index], + append([]*ast.Field{fieldParam}, funcDecl.Type.Params.List[p.index:]...)..., + ) + default: + errInspect = errors.Errorf("params index %d out of range", p.index) + return false + } + } + + // Check if the function has the code you want to replace. + if newFunctionBody != nil { + funcDecl.Body = newFunctionBody + } + + // Add the new code at line. + for _, newLine := range opts.newLines { + // Check if the function body has enough lines. + if newLine.number > uint64(len(funcDecl.Body.List))-1 { + errInspect = errors.Errorf("line number %d out of range", newLine.number) + return false + } + // Parse the Go code to insert. + insertionExpr, err := parser.ParseExprFrom(fileSet, "", []byte(newLine.code), parser.ParseComments) + if err != nil { + errInspect = err + return false + } + // Insert code at the specified line number. + funcDecl.Body.List = append( + funcDecl.Body.List[:newLine.number], + append([]ast.Stmt{&ast.ExprStmt{X: insertionExpr}}, funcDecl.Body.List[newLine.number:]...)..., + ) + } + + // Check if there is a return statement in the function. + if len(funcDecl.Body.List) > 0 { + lastStmt := funcDecl.Body.List[len(funcDecl.Body.List)-1] + switch stmt := lastStmt.(type) { + case *ast.ReturnStmt: + // Replace the return statements. + if len(returnStmts) > 0 { + // Remove existing return statements. + stmt.Results = nil + // Add the new return statement. + stmt.Results = append(stmt.Results, returnStmts...) + } + if len(appendCode) > 0 { + // If there is a return, insert before it. + appendCode = append(appendCode, stmt) + funcDecl.Body.List = append(funcDecl.Body.List[:len(funcDecl.Body.List)-1], appendCode...) + } + default: + if len(returnStmts) > 0 { + errInspect = errors.New("return statement not found") + return false + } + // If there is no return, insert at the end of the function body. + if len(appendCode) > 0 { + funcDecl.Body.List = append(funcDecl.Body.List, appendCode...) + } + } + } else { + if len(returnStmts) > 0 { + errInspect = errors.New("return statement not found") + return false + } + // If there are no statements in the function body, insert at the end of the function body. + if len(appendCode) > 0 { + funcDecl.Body.List = append(funcDecl.Body.List, appendCode...) + } + } + + // Add new code to the function callers. + ast.Inspect(funcDecl, func(n ast.Node) bool { + switch expr := n.(type) { + case *ast.CallExpr: // Add a new parameter to a function call. + // Check if the call expression matches the function call name. + name := "" + switch exp := expr.Fun.(type) { + case *ast.Ident: + name = exp.Name + case *ast.SelectorExpr: + name = exp.Sel.Name + default: + return true + } + + calls, ok := callMap[name] + if !ok { + return true + } + + // Construct the new argument to be added + for _, c := range calls { + newArg := ast.NewIdent(c.code) + newArg.NamePos = token.Pos(c.index) + switch { + case c.index == -1: + // Append the new argument to the end + expr.Args = append(expr.Args, newArg) + case c.index >= 0 && c.index <= len(expr.Args): + // Insert the new argument at the specified index + expr.Args = append(expr.Args[:c.index], append([]ast.Expr{newArg}, expr.Args[c.index:]...)...) + default: + errInspect = errors.Errorf("function call index %d out of range", c.index) + return false // Stop the inspection, an error occurred + } + } + delete(callMapCheck, name) + case *ast.CompositeLit: // Add a new parameter to a literal struct. + // Check if the call expression matches the function call name. + name := "" + switch exp := expr.Type.(type) { + case *ast.Ident: + name = exp.Name + case *ast.SelectorExpr: + name = exp.Sel.Name + default: + return true + } + + structs, ok := structMap[name] + if !ok { + return true + } + + // Construct the new argument to be added + for _, s := range structs { + var newArg ast.Expr = ast.NewIdent(s.code) + if s.paramName != "" { + newArg = &ast.KeyValueExpr{ + Key: ast.NewIdent(s.paramName), + Colon: token.Pos(s.index), + Value: ast.NewIdent(s.code), + } + } + + switch { + case s.index == -1: + // Append the new argument to the end + expr.Elts = append(expr.Elts, newArg) + case s.index >= 0 && s.index <= len(expr.Elts): + // Insert the new argument at the specified index + expr.Elts = append(expr.Elts[:s.index], append([]ast.Expr{newArg}, expr.Elts[s.index:]...)...) + default: + errInspect = errors.Errorf("function call index %d out of range", s.index) + return false // Stop the inspection, an error occurred + } + } + delete(structMapCheck, name) + default: + return true + } + return true // Continue the inspection for duplicated calls + }) + if errInspect != nil { + return false + } + if len(callMapCheck) > 0 { + errInspect = errors.Errorf("function calls not found: %v", callMapCheck) + return false + } + if len(structMapCheck) > 0 { + errInspect = errors.Errorf("function structs not found: %v", structMapCheck) + return false + } + + // everything is ok, mark as found and stop the inspect + found = true + return false + }) + if errInspect != nil { + return "", errInspect + } + if !found { + return "", errors.Errorf("function %s not found in file content", functionName) + } + + // Format the modified AST. + var buf bytes.Buffer + if err := format.Node(&buf, fileSet, f); err != nil { + return "", err + } + + // Return the modified content. + return buf.String(), nil +} diff --git a/ignite/pkg/xast/function_test.go b/ignite/pkg/xast/function_test.go new file mode 100644 index 0000000000..1b40988bae --- /dev/null +++ b/ignite/pkg/xast/function_test.go @@ -0,0 +1,321 @@ +package xast + +import ( + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ignite/cli/v28/ignite/pkg/errors" +) + +func TestModifyFunction(t *testing.T) { + existingContent := `package main + +import ( + "fmt" +) + +func main() { + fmt.Println("Hello, world!") + New(param1, param2) +} + +func anotherFunction() bool { + p := bla.NewParam() + p.CallSomething("Another call") + return true +}` + + type args struct { + fileContent string + functionName string + functions []FunctionOptions + } + tests := []struct { + name string + args args + want string + err error + }{ + { + name: "add all modifications type", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{ + AppendFuncParams("param1", "string", 0), + ReplaceFuncBody(`return false`), + AppendFuncAtLine(`fmt.Println("Appended at line 0.")`, 0), + AppendFuncAtLine(`SimpleCall(foo, bar)`, 1), + AppendFuncCode(`fmt.Println("Appended code.")`), + AppendFuncCode(`Param{Baz: baz, Foo: foo}`), + NewFuncReturn("1"), + AppendInsideFuncCall("SimpleCall", "baz", 0), + AppendInsideFuncCall("SimpleCall", "bla", -1), + AppendInsideFuncCall("Println", strconv.Quote("test"), -1), + AppendInsideFuncStruct("Param", "Bar", strconv.Quote("bar"), -1), + }, + }, + want: `package main + +import ( + "fmt" +) + +func main() { + fmt.Println("Hello, world!") + New(param1, param2) +} + +func anotherFunction(param1 string) bool { + fmt.Println("Appended at line 0.", "test") + SimpleCall(baz, foo, bar, bla) + fmt.Println("Appended code.", "test") + Param{Baz: baz, Foo: foo, Bar: "bar"} + return 1 +} +`, + }, + { + name: "add the replace body", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{ReplaceFuncBody(`return false`)}, + }, + want: `package main + +import ( + "fmt" +) + +func main() { + fmt.Println("Hello, world!") + New(param1, param2) +} + +func anotherFunction() bool { return false } +`, + }, + { + name: "add append line and code modification", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{ + AppendFuncAtLine(`fmt.Println("Appended at line 0.")`, 0), + AppendFuncAtLine(`SimpleCall(foo, bar)`, 1), + AppendFuncCode(`fmt.Println("Appended code.")`), + }, + }, + want: `package main + +import ( + "fmt" +) + +func main() { + fmt.Println("Hello, world!") + New(param1, param2) +} + +func anotherFunction() bool { + fmt.Println("Appended at line 0.") + SimpleCall(foo, bar) + + p := bla.NewParam() + p.CallSomething("Another call") + fmt.Println("Appended code.") + + return true +} +`, + }, + { + name: "add all modifications type", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{NewFuncReturn("1")}, + }, + want: strings.ReplaceAll(existingContent, "return true", "return 1\n") + "\n", + }, + { + name: "add inside call modifications", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{ + AppendInsideFuncCall("NewParam", "baz", 0), + AppendInsideFuncCall("NewParam", "bla", -1), + AppendInsideFuncCall("CallSomething", strconv.Quote("test1"), -1), + AppendInsideFuncCall("CallSomething", strconv.Quote("test2"), 0), + }, + }, + want: `package main + +import ( + "fmt" +) + +func main() { + fmt.Println("Hello, world!") + New(param1, param2) +} + +func anotherFunction() bool { + p := bla.NewParam(baz, bla) + p.CallSomething("test2", "Another call", "test1") + return true +} +`, + }, + { + name: "add inside struct modifications", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +func anotherFunction() bool { + Param{Baz: baz, Foo: foo} + Client{baz, foo} + return true +}`, + functionName: "anotherFunction", + functions: []FunctionOptions{ + AppendInsideFuncStruct("Param", "Bar", "bar", -1), + AppendInsideFuncStruct("Param", "Bla", "bla", 1), + AppendInsideFuncStruct("Client", "", "bar", 0), + }, + }, + want: `package main + +import ( + "fmt" +) + +func anotherFunction() bool { + Param{Baz: baz, Bla: bla, Foo: foo, Bar: bar} + Client{bar, baz, foo} + return true +} +`, + }, + { + name: "params out of range", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendFuncParams("param1", "string", 1)}, + }, + err: errors.New("params index 1 out of range"), + }, + { + name: "invalid params", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendFuncParams("9#.(c", "string", 0)}, + }, + err: errors.New("format.Node internal error (12:22: expected ')', found 9 (and 1 more errors))"), + }, + { + name: "invalid content for replace body", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{ReplaceFuncBody("9#.(c")}, + }, + err: errors.New("1:24: illegal character U+0023 '#'"), + }, + { + name: "line number out of range", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendFuncAtLine(`fmt.Println("")`, 4)}, + }, + err: errors.New("line number 4 out of range"), + }, + { + name: "invalid code for append at line", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendFuncAtLine("9#.(c", 0)}, + }, + err: errors.New("1:2: illegal character U+0023 '#'"), + }, + { + name: "invalid code append", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendFuncCode("9#.(c")}, + }, + err: errors.New("1:24: illegal character U+0023 '#'"), + }, + { + name: "invalid new return", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{NewFuncReturn("9#.(c")}, + }, + err: errors.New("1:2: illegal character U+0023 '#'"), + }, + { + name: "call name not found", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendInsideFuncCall("FooFunction", "baz", 0)}, + }, + err: errors.New("function calls not found: map[FooFunction:[{FooFunction baz 0}]]"), + }, + { + name: "invalid call param", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendInsideFuncCall("NewParam", "9#.(c", 0)}, + }, + err: errors.New("format.Node internal error (13:21: illegal character U+0023 '#' (and 2 more errors))"), + }, + { + name: "call params out of range", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{AppendInsideFuncCall("NewParam", "baz", 1)}, + }, + err: errors.New("function call index 1 out of range"), + }, + { + name: "empty modifications", + args: args{ + fileContent: existingContent, + functionName: "anotherFunction", + functions: []FunctionOptions{}, + }, + want: existingContent + "\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ModifyFunction(tt.args.fileContent, tt.args.functionName, tt.args.functions...) + if tt.err != nil { + require.Error(t, err) + require.Equal(t, tt.err.Error(), err.Error()) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/ignite/pkg/xast/global.go b/ignite/pkg/xast/global.go new file mode 100644 index 0000000000..4be31ff807 --- /dev/null +++ b/ignite/pkg/xast/global.go @@ -0,0 +1,181 @@ +package xast + +import ( + "bytes" + "go/ast" + "go/format" + "go/parser" + "go/token" + + "github.com/ignite/cli/v28/ignite/pkg/errors" +) + +type ( + // globalOpts represent the options for globals. + globalOpts struct { + globals []global + } + + // GlobalOptions configures code generation. + GlobalOptions func(*globalOpts) + + global struct { + name, varType, value string + } + + // GlobalType represents the global type. + GlobalType string +) + +const ( + GlobalTypeVar GlobalType = "var" + GlobalTypeConst GlobalType = "const" +) + +// WithGlobal add a new global. +func WithGlobal(name, varType, value string) GlobalOptions { + return func(c *globalOpts) { + c.globals = append(c.globals, global{ + name: name, + varType: varType, + value: value, + }) + } +} + +func newGlobalOptions() globalOpts { + return globalOpts{ + globals: make([]global, 0), + } +} + +// InsertGlobal inserts global variables or constants into the provided Go source code content after the import section. +// The function parses the provided content, locates the import section, and inserts the global declarations immediately after it. +// The type of globals (variables or constants) is specified by the globalType parameter. +// Each global declaration is defined by calling WithGlobal function with appropriate arguments. +// The function returns the modified content with the inserted global declarations. +func InsertGlobal(fileContent string, globalType GlobalType, globals ...GlobalOptions) (modifiedContent string, err error) { + // apply global options. + opts := newGlobalOptions() + for _, o := range globals { + o(&opts) + } + + fileSet := token.NewFileSet() + + // Parse the Go source code content. + f, err := parser.ParseFile(fileSet, "", fileContent, parser.ParseComments) + if err != nil { + return "", err + } + + // Find the index of the import declaration or package declaration if no imports. + var insertIndex int + for i, decl := range f.Decls { + if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { + insertIndex = i + 1 + break + } else if funcDecl, ok := decl.(*ast.FuncDecl); ok { + insertIndex = i + if funcDecl.Doc == nil { + insertIndex++ + } + break + } + } + + // Create global variable/constant declarations. + for _, global := range opts.globals { + // Create an identifier for the global. + ident := ast.NewIdent(global.name) + + // Create a value expression if provided. + var valueExpr ast.Expr + if global.value != "" { + valueExpr, err = parser.ParseExprFrom(fileSet, "", []byte(global.value), parser.ParseComments) + if err != nil { + return "", err + } + } + + // Create a declaration based on the global type. + var spec ast.Spec + switch globalType { + case GlobalTypeVar: + spec = &ast.ValueSpec{ + Names: []*ast.Ident{ident}, + Type: ast.NewIdent(global.varType), + Values: []ast.Expr{valueExpr}, + } + case GlobalTypeConst: + spec = &ast.ValueSpec{ + Names: []*ast.Ident{ident}, + Type: ast.NewIdent(global.varType), + Values: []ast.Expr{valueExpr}, + } + default: + return "", errors.Errorf("unsupported global type: %s", string(globalType)) + } + + // Insert the declaration after the import section or package declaration if no imports. + f.Decls = append( + f.Decls[:insertIndex], + append([]ast.Decl{ + &ast.GenDecl{ + TokPos: 1, + Tok: token.Lookup(string(globalType)), + Specs: []ast.Spec{spec}, + }, + }, f.Decls[insertIndex:]...)...) + insertIndex++ + } + + // Format the modified AST. + var buf bytes.Buffer + if err := format.Node(&buf, fileSet, f); err != nil { + return "", err + } + + // Return the modified content. + return buf.String(), nil +} + +// AppendFunction appends a new function to the end of the Go source code content. +func AppendFunction(fileContent string, function string) (modifiedContent string, err error) { + fileSet := token.NewFileSet() + + // Parse the function body as a separate file. + funcFile, err := parser.ParseFile(fileSet, "", "package main\n"+function, parser.AllErrors) + if err != nil { + return "", err + } + + // Extract the first declaration, assuming it's a function declaration. + var funcDecl *ast.FuncDecl + for _, decl := range funcFile.Decls { + if fDecl, ok := decl.(*ast.FuncDecl); ok { + funcDecl = fDecl + break + } + } + if funcDecl == nil { + return "", errors.Errorf("no function declaration found in the provided function body") + } + + // Parse the Go source code content. + f, err := parser.ParseFile(fileSet, "", fileContent, parser.ParseComments) + if err != nil { + return "", err + } + + // Append the function declaration to the file's declarations. + f.Decls = append(f.Decls, funcDecl) + + // Format the modified AST. + var buf bytes.Buffer + if err := format.Node(&buf, fileSet, f); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/ignite/pkg/xast/global_test.go b/ignite/pkg/xast/global_test.go new file mode 100644 index 0000000000..5ea4d9f259 --- /dev/null +++ b/ignite/pkg/xast/global_test.go @@ -0,0 +1,352 @@ +package xast + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ignite/cli/v28/ignite/pkg/errors" +) + +func TestInsertGlobal(t *testing.T) { + type args struct { + fileContent string + globalType GlobalType + globals []GlobalOptions + } + tests := []struct { + name string + args args + want string + err error + }{ + { + name: "Insert global int var", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +`, + globalType: GlobalTypeVar, + globals: []GlobalOptions{ + WithGlobal("myIntVar", "int", "42"), + }, + }, + want: `package main + +import ( + "fmt" +) + +var myIntVar int = 42 +`, + }, + { + name: "Insert global int const", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +`, + globalType: GlobalTypeConst, + globals: []GlobalOptions{ + WithGlobal("myIntConst", "int", "42"), + }, + }, + want: `package main + +import ( + "fmt" +) + +const myIntConst int = 42 +`, + }, + { + name: "Insert string const", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +`, + globalType: GlobalTypeConst, + globals: []GlobalOptions{ + WithGlobal("myStringConst", "string", `"hello"`), + }, + }, + want: `package main + +import ( + "fmt" +) + +const myStringConst string = "hello" +`, + }, + { + name: "Insert multiples consts", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +`, + globalType: GlobalTypeConst, + globals: []GlobalOptions{ + WithGlobal("myStringConst", "string", `"hello"`), + WithGlobal("myBoolConst", "bool", "true"), + WithGlobal("myUintConst", "uint64", "40"), + }, + }, + want: `package main + +import ( + "fmt" +) + +const myStringConst string = "hello" +const myBoolConst bool = true +const myUintConst uint64 = 40 +`, + }, + { + name: "Insert global int var with not imports", + args: args{ + fileContent: `package main +`, + globalType: GlobalTypeVar, + globals: []GlobalOptions{ + WithGlobal("myIntVar", "int", "42"), + }, + }, + want: `package main + +var myIntVar int = 42 +`, + }, + { + name: "Insert global int var int an empty file", + args: args{ + fileContent: ``, + globalType: GlobalTypeVar, + globals: []GlobalOptions{ + WithGlobal("myIntVar", "int", "42"), + }, + }, + err: errors.New("1:1: expected 'package', found 'EOF'"), + }, + { + name: "Insert a custom var", + args: args{ + fileContent: `package main`, + globalType: GlobalTypeVar, + globals: []GlobalOptions{ + WithGlobal("fooVar", "foo", "42"), + }, + }, + want: `package main + +var fooVar foo = 42 +`, + }, + { + name: "Insert an invalid var", + args: args{ + fileContent: `package main`, + globalType: GlobalTypeVar, + globals: []GlobalOptions{ + WithGlobal("myInvalidVar", "invalid", "AEF#3fa."), + }, + }, + err: errors.New("1:4: illegal character U+0023 '#'"), + }, + { + name: "Insert an invalid type", + args: args{ + fileContent: `package main`, + globalType: "invalid", + globals: []GlobalOptions{ + WithGlobal("fooVar", "foo", "42"), + }, + }, + err: errors.New("unsupported global type: invalid"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := InsertGlobal(tt.args.fileContent, tt.args.globalType, tt.args.globals...) + if tt.err != nil { + require.Error(t, err) + require.Equal(t, tt.err.Error(), err.Error()) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestAppendFunction(t *testing.T) { + type args struct { + fileContent string + function string + } + tests := []struct { + name string + args args + want string + err error + }{ + { + name: "Append a function after the package declaration", + args: args{ + fileContent: `package main`, + function: `func add(a, b int) int { + return a + b +}`, + }, + want: `package main + +func add(a, b int) int { + return a + b +} +`, + }, + { + name: "Append a function after a var", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +var myIntVar int = 42 +`, + function: `func add(a, b int) int { + return a + b +}`, + }, + want: `package main + +import ( + "fmt" +) + +var myIntVar int = 42 + +func add(a, b int) int { + return a + b +} +`, + }, + { + name: "Append a function after the import", + args: args{ + fileContent: `package main + +import ( + "fmt" +) +`, + function: `func add(a, b int) int { + return a + b +}`, + }, + want: `package main + +import ( + "fmt" +) + +func add(a, b int) int { + return a + b +} +`, + }, + { + name: "Append a function after another function", + args: args{ + fileContent: `package main + +import ( + "fmt" +) + +var myIntVar int = 42 + +func myFunction() int { + return 42 +} +`, + function: `func add(a, b int) int { + return a + b +}`, + }, + want: `package main + +import ( + "fmt" +) + +var myIntVar int = 42 + +func myFunction() int { + return 42 +} +func add(a, b int) int { + return a + b +} +`, + }, + { + name: "Append a function in an empty file", + args: args{ + fileContent: ``, + function: `func add(a, b int) int { + return a + b +}`, + }, + err: errors.New("1:1: expected 'package', found 'EOF'"), + }, + { + name: "Append a empty function", + args: args{ + fileContent: `package main`, + function: ``, + }, + err: errors.New("no function declaration found in the provided function body"), + }, + { + name: "Append an invalid function", + args: args{ + fileContent: `package main`, + function: `@,.l.e,`, + }, + err: errors.New("2:1: illegal character U+0040 '@'"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := AppendFunction(tt.args.fileContent, tt.args.function) + if tt.err != nil { + require.Error(t, err) + require.Equal(t, tt.err.Error(), err.Error()) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/ignite/pkg/xast/import.go b/ignite/pkg/xast/import.go new file mode 100644 index 0000000000..0fb57271b3 --- /dev/null +++ b/ignite/pkg/xast/import.go @@ -0,0 +1,161 @@ +package xast + +import ( + "bytes" + "go/ast" + "go/format" + "go/parser" + "go/token" + "strconv" + + "github.com/ignite/cli/v28/ignite/pkg/errors" +) + +type ( + // importOpts represent the options for imp. + importOpts struct { + imports []imp + } + + // ImportOptions configures code generation. + ImportOptions func(*importOpts) + + imp struct { + repo string + name string + index int + } +) + +// WithLastImport add a new import int the end. +func WithLastImport(repo string) ImportOptions { + return func(c *importOpts) { + c.imports = append(c.imports, imp{ + repo: repo, + name: "", + index: -1, + }) + } +} + +// WithImport add a new import. If the index is -1 will append in the end of the imports. +func WithImport(repo string, index int) ImportOptions { + return func(c *importOpts) { + c.imports = append(c.imports, imp{ + repo: repo, + name: "", + index: index, + }) + } +} + +// WithNamedImport add a new import with name. If the index is -1 will append in the end of the imports. +func WithNamedImport(name, repo string, index int) ImportOptions { + return func(c *importOpts) { + c.imports = append(c.imports, imp{ + name: name, + repo: repo, + index: index, + }) + } +} + +// WithLastNamedImport add a new import with name in the end of the imports. +func WithLastNamedImport(name, repo string) ImportOptions { + return func(c *importOpts) { + c.imports = append(c.imports, imp{ + name: name, + repo: repo, + index: -1, + }) + } +} + +func newImportOptions() importOpts { + return importOpts{ + imports: make([]imp, 0), + } +} + +// AppendImports appends import statements to the existing import block in Go source code content. +func AppendImports(fileContent string, imports ...ImportOptions) (string, error) { + // apply global options. + opts := newImportOptions() + for _, o := range imports { + o(&opts) + } + + fileSet := token.NewFileSet() + + // Parse the Go source code content. + f, err := parser.ParseFile(fileSet, "", fileContent, parser.ParseComments) + if err != nil { + return "", err + } + + // Find the existing import declaration. + var importDecl *ast.GenDecl + for _, decl := range f.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.IMPORT || len(genDecl.Specs) == 0 { + continue + } + importDecl = genDecl + break + } + + if importDecl == nil { + // If no existing import declaration found, create a new one. + importDecl = &ast.GenDecl{ + Tok: token.IMPORT, + Specs: make([]ast.Spec, 0), + } + f.Decls = append([]ast.Decl{importDecl}, f.Decls...) + } + + // Check existing imports to avoid duplicates. + existImports := make(map[string]struct{}) + for _, spec := range importDecl.Specs { + importSpec, ok := spec.(*ast.ImportSpec) + if !ok { + continue + } + existImports[importSpec.Path.Value] = struct{}{} + } + + // Add new import statements. + for _, importStmt := range opts.imports { + // Check if the import already exists. + path := strconv.Quote(importStmt.repo) + if _, ok := existImports[path]; ok { + continue + } + // Create a new import spec. + spec := &ast.ImportSpec{ + Name: ast.NewIdent(importStmt.name), + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: path, + }, + } + + switch { + case importStmt.index == -1: + // Append the new argument to the end + importDecl.Specs = append(importDecl.Specs, spec) + case importStmt.index >= 0 && importStmt.index <= len(importDecl.Specs): + // Insert the new argument at the specified index + importDecl.Specs = append(importDecl.Specs[:importStmt.index], append([]ast.Spec{spec}, importDecl.Specs[importStmt.index:]...)...) + default: + return "", errors.Errorf("index out of range") // Stop the inspection, an error occurred + } + } + + // Format the modified AST. + var buf bytes.Buffer + if err := format.Node(&buf, fileSet, f); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/ignite/pkg/xast/import_test.go b/ignite/pkg/xast/import_test.go new file mode 100644 index 0000000000..afa0c28794 --- /dev/null +++ b/ignite/pkg/xast/import_test.go @@ -0,0 +1,252 @@ +package xast + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ignite/cli/v28/ignite/pkg/errors" +) + +func TestAppendImports(t *testing.T) { + existingContent := `package main + +import ( + "fmt" +) + +func main() { + fmt.Println("Hello, world!") +}` + + type args struct { + fileContent string + imports []ImportOptions + } + tests := []struct { + name string + args args + want string + err error + }{ + { + name: "add single import statement", + args: args{ + fileContent: existingContent, + imports: []ImportOptions{ + WithImport("strings", -1), + }, + }, + want: `package main + +import ( + "fmt" + "strings" +) + +func main() { + fmt.Println("Hello, world!") +} +`, + }, + { + name: "add multiple import statements", + args: args{ + fileContent: existingContent, + imports: []ImportOptions{ + WithNamedImport("st", "strings", -1), + WithImport("strconv", -1), + WithLastImport("os"), + }, + }, + want: `package main + +import ( + "fmt" + "os" + "strconv" + st "strings" +) + +func main() { + fmt.Println("Hello, world!") +} +`, + }, + { + name: "add multiple import statements with an existing one", + args: args{ + fileContent: existingContent, + imports: []ImportOptions{ + WithNamedImport("st", "strings", -1), + WithImport("strconv", -1), + WithImport("os", -1), + WithLastImport("fmt"), + }, + }, + want: `package main + +import ( + "fmt" + "os" + "strconv" + st "strings" +) + +func main() { + fmt.Println("Hello, world!") +} +`, + }, + { + name: "add import to specific index", + args: args{ + fileContent: `package main + +import ( + "fmt" + "os" + st "strings" +)`, + imports: []ImportOptions{ + WithImport("strconv", 1), + }, + }, + want: `package main + +import ( + "fmt" + "os" + "strconv" + st "strings" +) +`, + }, + { + name: "add multiple imports to specific index", + args: args{ + fileContent: `package main + +import ( + "fmt" + "os" + st "strings" +)`, + imports: []ImportOptions{ + WithImport("strconv", 0), + WithNamedImport("", "testing", 3), + WithLastImport("bytes"), + }, + }, + want: `package main + +import ( + "bytes" + "fmt" + "os" + "strconv" + st "strings" + "testing" +) +`, + }, + { + name: "add duplicate import statement", + args: args{ + fileContent: existingContent, + imports: []ImportOptions{ + WithLastImport("fmt"), + }, + }, + want: existingContent + "\n", + }, + { + name: "no import statement", + args: args{ + fileContent: `package main + +func main() { + fmt.Println("Hello, world!") +}`, + imports: []ImportOptions{ + WithImport("fmt", -1), + }, + }, + want: `package main + +import "fmt" + +func main() { + fmt.Println("Hello, world!") +} +`, + }, + { + name: "no import statement and add two imports", + args: args{ + fileContent: `package main + +func main() { + fmt.Println("Hello, world!") +}`, + imports: []ImportOptions{ + WithImport("fmt", -1), + WithLastImport("os"), + }, + }, + want: `package main + +import ( + "fmt" + "os" +) + +func main() { + fmt.Println("Hello, world!") +} +`, + }, + { + name: "invalid index", + args: args{ + fileContent: existingContent, + imports: []ImportOptions{ + WithImport("strings", 10), + }, + }, + err: errors.New("index out of range"), + }, + { + name: "add invalid import name", + args: args{ + fileContent: existingContent, + imports: []ImportOptions{ + WithNamedImport("fmt\"", "fmt\"", -1), + }, + }, + err: errors.New("format.Node internal error (5:8: expected ';', found fmt (and 2 more errors))"), + }, + { + name: "add empty file content", + args: args{ + fileContent: "", + imports: []ImportOptions{ + WithImport("fmt", -1), + }, + }, + err: errors.New("1:1: expected 'package', found 'EOF'"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := AppendImports(tt.args.fileContent, tt.args.imports...) + if tt.err != nil { + require.Error(t, err) + require.Equal(t, tt.err.Error(), err.Error()) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +}