Skip to content
This repository has been archived by the owner on Apr 19, 2024. It is now read-only.

Commit

Permalink
Add WriteAnswer support for promoted fields (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch authored Aug 3, 2021
1 parent 3fff19a commit 8a89877
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 39 deletions.
54 changes: 36 additions & 18 deletions core/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ type OptionAnswer struct {
Index int
}

type reflectField struct {
value reflect.Value
fieldType reflect.StructField
}

func OptionAnswerList(incoming []string) []OptionAnswer {
list := []OptionAnswer{}
for i, opt := range incoming {
Expand Down Expand Up @@ -63,13 +68,12 @@ func WriteAnswer(t interface{}, name string, v interface{}) (err error) {
}

// get the name of the field that matches the string we were given
fieldIndex, err := findFieldIndex(elem, name)
field, _, err := findField(elem, name)
// if something went wrong
if err != nil {
// bubble up
return err
}
field := elem.Field(fieldIndex)
// handle references to the Settable interface aswell
if s, ok := field.Interface().(Settable); ok {
// use the interface method
Expand Down Expand Up @@ -156,37 +160,51 @@ func IsFieldNotMatch(err error) (string, bool) {

// BUG(AlecAivazis): the current implementation might cause weird conflicts if there are
// two fields with same name that only differ by casing.
func findFieldIndex(s reflect.Value, name string) (int, error) {
// the type of the value
sType := s.Type()
func findField(s reflect.Value, name string) (reflect.Value, reflect.StructField, error) {

// first look for matching tags so we can overwrite matching field names
for i := 0; i < sType.NumField(); i++ {
// the field we are current scanning
field := sType.Field(i)
fields := flattenFields(s)

// first look for matching tags so we can overwrite matching field names
for _, f := range fields {
// the value of the survey tag
tag := field.Tag.Get(tagName)
tag := f.fieldType.Tag.Get(tagName)
// if the tag matches the name we are looking for
if tag != "" && tag == name {
// then we found our index
return i, nil
return f.value, f.fieldType, nil
}
}

// then look for matching names
for i := 0; i < sType.NumField(); i++ {
// the field we are current scanning
field := sType.Field(i)

for _, f := range fields {
// if the name of the field matches what we're looking for
if strings.ToLower(field.Name) == strings.ToLower(name) {
return i, nil
if strings.ToLower(f.fieldType.Name) == strings.ToLower(name) {
return f.value, f.fieldType, nil
}
}

// we didn't find the field
return -1, errFieldNotMatch{name}
return reflect.Value{}, reflect.StructField{}, errFieldNotMatch{name}
}

func flattenFields(s reflect.Value) []reflectField {
sType := s.Type()
numField := sType.NumField()
fields := make([]reflectField, 0, numField)
for i := 0; i < numField; i++ {
fieldType := sType.Field(i)
field := s.Field(i)

if field.Kind() == reflect.Struct && fieldType.Anonymous {
// field is a promoted structure
for _, f := range flattenFields(field) {
fields = append(fields, f)
}
continue
}
fields = append(fields, reflectField{field, fieldType})
}
return fields
}

// isList returns true if the element is something we can Len()
Expand Down
178 changes: 157 additions & 21 deletions core/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,12 @@ func TestWriteAnswer_returnsErrWhenFieldNotFound(t *testing.T) {
}
}

func TestFindFieldIndex_canFindExportedField(t *testing.T) {
func TestFindField_canFindExportedField(t *testing.T) {
// create a reflective wrapper over the struct to look through
val := reflect.ValueOf(struct{ Name string }{})
val := reflect.ValueOf(struct{ Name string }{Name: "Jack"})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "name")
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
Expand All @@ -319,20 +319,28 @@ func TestFindFieldIndex_canFindExportedField(t *testing.T) {
}

// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Name" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right field type
if fieldType.Name != "Name" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
}
}

func TestFindFieldIndex_canFindTaggedField(t *testing.T) {
func TestFindField_canFindTaggedField(t *testing.T) {
// the struct to look through
val := reflect.ValueOf(struct {
Username string `survey:"name"`
}{})
}{
Username: "Jack",
})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "name")
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
Expand All @@ -341,52 +349,180 @@ func TestFindFieldIndex_canFindTaggedField(t *testing.T) {
}

// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Username" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

func TestFindFieldIndex_canHandleCapitalAnswerNames(t *testing.T) {
func TestFindField_canHandleCapitalAnswerNames(t *testing.T) {
// create a reflective wrapper over the struct to look through
val := reflect.ValueOf(struct{ Name string }{})
val := reflect.ValueOf(struct{ Name string }{Name: "Jack"})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "Name")
field, fieldType, err := findField(val, "Name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}

// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Name" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Name" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
}
}

func TestFindFieldIndex_tagOverwriteFieldName(t *testing.T) {
func TestFindField_tagOverwriteFieldName(t *testing.T) {
// the struct to look through
val := reflect.ValueOf(struct {
Name string
Username string `survey:"name"`
}{})
}{
Name: "Ralf",
Username: "Jack",
})

// find the field matching "name"
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}

// make sure we got the right value
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

func TestFindField_supportsPromotedFields(t *testing.T) {
// create a reflective wrapper over the struct to look through
type Common struct {
Name string
}

type Strct struct {
Common // Name field added by composition
Username string
}

val := reflect.ValueOf(Strct{Common: Common{Name: "Jack"}})

// find the field matching "name"
field, fieldType, err := findField(val, "Name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}
// make sure we got the right value
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Name" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Name' found %v.", fieldType.Name)
}
}

func TestFindField_promotedFieldsWithTag(t *testing.T) {
// create a reflective wrapper over the struct to look through
type Common struct {
Username string `survey:"name"`
}

type Strct struct {
Common // Name field added by composition
Name string
}

val := reflect.ValueOf(Strct{
Common: Common{Username: "Jack"},
Name: "Ralf",
})

// find the field matching "name"
fieldIndex, err := findFieldIndex(val, "name")
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}
// make sure we got the right value
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

func TestFindField_promotedFieldsDontHavePriorityOverTags(t *testing.T) {
// create a reflective wrapper over the struct to look through
type Common struct {
Name string
}

type Strct struct {
Common // Name field added by composition
Username string `survey:"name"`
}

val := reflect.ValueOf(Strct{
Common: Common{Name: "Ralf"},
Username: "Jack",
})

// find the field matching "name"
field, fieldType, err := findField(val, "name")
// if something went wrong
if err != nil {
// the test failed
t.Error(err.Error())
return
}
// make sure we got the right value
if val.Type().Field(fieldIndex).Name != "Username" {
if field.Interface() != "Jack" {
// the test failed
t.Errorf("Did not find the correct field value. Expected 'Jack' found %v.", field.Interface())
}

// make sure we got the right fieldType
if fieldType.Name != "Username" {
// the test failed
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", val.Type().Field(fieldIndex).Name)
t.Errorf("Did not find the correct field name. Expected 'Username' found %v.", fieldType.Name)
}
}

Expand Down

0 comments on commit 8a89877

Please sign in to comment.