diff --git a/join.go b/join.go index 7dd08db..aa32b85 100644 --- a/join.go +++ b/join.go @@ -39,17 +39,27 @@ func IsNotExist(err error) bool { // components in the returned string are not modified (in other words are not // replaced with symlinks on the filesystem) after this function has returned. // Such a symlink race is necessarily out-of-scope of SecureJoin. +// +// Volume names in unsafePath are always discarded, regardless if they are +// provided via direct input or when evaluating symlinks. Therefore: +// +// "C:\Temp" + "D:\path\to\file.txt" results in "C:\Temp\path\to\file.txt" func SecureJoinVFS(root, unsafePath string, vfs VFS) (string, error) { // Use the os.* VFS implementation if none was specified. if vfs == nil { vfs = osVFS{} } + unsafePath = filepath.FromSlash(unsafePath) var path bytes.Buffer n := 0 for unsafePath != "" { if n > 255 { - return "", &os.PathError{Op: "SecureJoin", Path: root + "/" + unsafePath, Err: syscall.ELOOP} + return "", &os.PathError{Op: "SecureJoin", Path: root + string(filepath.Separator) + unsafePath, Err: syscall.ELOOP} + } + + if v := filepath.VolumeName(unsafePath); v != "" { + unsafePath = unsafePath[len(v):] } // Next path component, p. diff --git a/join_test.go b/join_test.go index 754ded0..a234c82 100644 --- a/join_test.go +++ b/join_test.go @@ -9,6 +9,7 @@ import ( "io/ioutil" "os" "path/filepath" + "runtime" "syscall" "testing" ) @@ -22,6 +23,11 @@ func symlink(t *testing.T, oldname, newname string) { } } +type input struct { + root, unsafe string + expected string +} + // Test basic handling of symlink expansion. func TestSymlink(t *testing.T) { dir, err := ioutil.TempDir("", "TestSymlink") @@ -38,21 +44,26 @@ func TestSymlink(t *testing.T) { symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink")) symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd")) - for _, test := range []struct { - root, unsafe string - expected string - }{ + rootOrVol := string(filepath.Separator) + if vol := filepath.VolumeName(dir); vol != "" { + rootOrVol = vol + rootOrVol + } + + tc := []input{ // Make sure that expansion with a root of '/' proceeds in the expected fashion. - {"/", filepath.Join(dir, "passwd"), "/etc/passwd"}, - {"/", filepath.Join(dir, "etclink"), "/etc"}, - {"/", filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")}, + {rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")}, + {rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")}, + + {rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")}, // Now test scoped expansion. {dir, "passwd", filepath.Join(dir, "somepath", "passwd")}, {dir, "etclink", filepath.Join(dir, "somepath")}, {dir, "etc", filepath.Join(dir, "somepath")}, {dir, "etc/test", filepath.Join(dir, "somepath", "test")}, {dir, "etc/test/..", filepath.Join(dir, "somepath")}, - } { + } + + for _, test := range tc { got, err := SecureJoin(test.root, test.unsafe) if err != nil { t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) @@ -85,29 +96,31 @@ func TestNoSymlink(t *testing.T) { } defer os.RemoveAll(dir) - for _, test := range []struct { - root, unsafe string - }{ - // TODO: Do we need to have some conditional FromSlash handling here? - {dir, "somepath"}, - {dir, "even/more/path"}, - {dir, "/this/is/a/path"}, - {dir, "also/a/../path/././/with/some/./.././junk"}, - {dir, "yetanother/../path/././/with/some/./.././junk../../../../../../../../../../../../etc/passwd"}, - {dir, "/../../../../../../../../../../../../../../../../etc/passwd"}, - {dir, "../../../../../../../../../../../../../../../../somedir"}, - {dir, "../../../../../../../../../../../../../../../../"}, - {dir, "./../../.././././../../../../../../../../../../../../../../../../etc passwd"}, - } { - expected := filepath.Join(test.root, filepath.Clean(string(filepath.Separator)+test.unsafe)) + tc := []input{ + {dir, "somepath", filepath.Join(dir, "somepath")}, + {dir, "even/more/path", filepath.Join(dir, "even", "more", "path")}, + {dir, "/this/is/a/path", filepath.Join(dir, "this", "is", "a", "path")}, + {dir, "also/a/../path/././/with/some/./.././junk", filepath.Join(dir, "also", "path", "with", "junk")}, + {dir, "yetanother/../path/././/with/some/./.././junk../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "etc", "passwd")}, + {dir, "/../../../../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "etc", "passwd")}, + {dir, "../../../../../../../../../../../../../../../../somedir", filepath.Join(dir, "somedir")}, + {dir, "../../../../../../../../../../../../../../../../", filepath.Join(dir)}, + {dir, "./../../.././././../../../../../../../../../../../../../../../../etc passwd", filepath.Join(dir, "etc passwd")}, + } + + if runtime.GOOS == "windows" { + tc = append(tc, []input{ + {dir, "d:\\etc\\test", filepath.Join(dir, "etc", "test")}, + }...) + } + + for _, test := range tc { got, err := SecureJoin(test.root, test.unsafe) if err != nil { t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err) - continue } - if got != expected { - t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, expected, got) - continue + if got != test.expected { + t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got) } } } @@ -130,10 +143,7 @@ func TestNonLexical(t *testing.T) { symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) - for _, test := range []struct { - root, unsafe string - expected string - }{ + for _, test := range []input{ {dir, "subdir", filepath.Join(dir, "subdir")}, {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, @@ -188,7 +198,7 @@ func TestSymlinkLoop(t *testing.T) { } { got, err := SecureJoin(test.root, test.unsafe) if !errors.Is(err, syscall.ELOOP) { - t.Errorf("securejoin(%q, %q): expected ELOOP, got %v & %q", test.root, test.unsafe, err, got) + t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err) continue } } @@ -275,10 +285,7 @@ func TestSecureJoinVFS(t *testing.T) { symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2")) symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3")) - for _, test := range []struct { - root, unsafe string - expected string - }{ + for _, test := range []input{ {dir, "subdir", filepath.Join(dir, "subdir")}, {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")}, {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")},