Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiple SSH keys for the same host #1433

Merged
merged 1 commit into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/credentials/gitcreds/creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func flags(fs *flag.FlagSet) {
basicConfig = basicGitConfig{entries: make(map[string]basicEntry)}
fs.Var(&basicConfig, basicAuthFlag, "List of secret=url pairs.")

sshConfig = sshGitConfig{entries: make(map[string]sshEntry)}
sshConfig = sshGitConfig{entries: make(map[string][]sshEntry)}
fs.Var(&sshConfig, sshFlag, "List of secret=url pairs.")
}

Expand Down
54 changes: 16 additions & 38 deletions pkg/credentials/gitcreds/creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ func TestSSHFlagHandling(t *testing.T) {

expectedSSHConfig := fmt.Sprintf(`Host github.com
HostName github.com
IdentityFile %s
Port 22
`, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"))
if string(b) != expectedSSHConfig {
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig)
IdentityFile %s/.ssh/id_foo
`, credentials.VolumePath)
if d := cmp.Diff(expectedSSHConfig, string(b)); d != "" {
t.Errorf("ssh_config diff: %s", d)
}

b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "known_hosts"))
Expand Down Expand Up @@ -283,8 +283,10 @@ func TestSSHFlagHandlingThrice(t *testing.T) {
fs := flag.NewFlagSet("test", flag.ContinueOnError)
flags(fs)
err := fs.Parse([]string{
// Two secrets target github.com, and both will end up in the
// ssh config.
"-ssh-git=foo=github.com",
"-ssh-git=bar=gitlab.com",
"-ssh-git=bar=github.com",
"-ssh-git=baz=gitlab.example.com:2222",
})
if err != nil {
Expand All @@ -303,21 +305,16 @@ func TestSSHFlagHandlingThrice(t *testing.T) {

expectedSSHConfig := fmt.Sprintf(`Host github.com
HostName github.com
IdentityFile %s
Port 22
Host gitlab.com
HostName gitlab.com
IdentityFile %s
Port 22
IdentityFile %s/.ssh/id_foo
IdentityFile %s/.ssh/id_bar
Host gitlab.example.com
HostName gitlab.example.com
IdentityFile %s
Port 2222
`, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"),
filepath.Join(os.Getenv("HOME"), ".ssh", "id_bar"),
filepath.Join(os.Getenv("HOME"), ".ssh", "id_baz"))
if string(b) != expectedSSHConfig {
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig)
IdentityFile %s/.ssh/id_baz
`, credentials.VolumePath, credentials.VolumePath, credentials.VolumePath)
if d := cmp.Diff(expectedSSHConfig, string(b)); d != "" {
t.Errorf("ssh_config diff: %s", d)
}

b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "known_hosts"))
Expand All @@ -327,8 +324,8 @@ Host gitlab.example.com
expectedSSHKnownHosts := `ssh-rsa aaaa
ssh-rsa bbbb
ssh-rsa cccc`
if string(b) != expectedSSHKnownHosts {
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHKnownHosts)
if d := cmp.Diff(expectedSSHKnownHosts, string(b)); d != "" {
t.Errorf("known_hosts diff: %s", d)
}

b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "id_foo"))
Expand Down Expand Up @@ -370,31 +367,12 @@ func TestSSHFlagHandlingMissingFiles(t *testing.T) {
}
// No ssh-privatekey files yields an error.

cfg := sshGitConfig{entries: make(map[string]sshEntry)}
cfg := sshGitConfig{entries: make(map[string][]sshEntry)}
if err := cfg.Set("not-found=github.com"); err == nil {
t.Error("Set(); got success, wanted error.")
}
}

func TestSSHFlagHandlingURLCollision(t *testing.T) {
credentials.VolumePath, _ = ioutil.TempDir("", "")
dir := credentials.VolumeName("foo")
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
t.Fatalf("os.MkdirAll(%s) = %v", dir, err)
}
if err := ioutil.WriteFile(filepath.Join(dir, corev1.SSHAuthPrivateKey), []byte("bar"), 0777); err != nil {
t.Fatalf("ioutil.WriteFile(ssh-privatekey) = %v", err)
}

cfg := sshGitConfig{entries: make(map[string]sshEntry)}
if err := cfg.Set("foo=github.com"); err != nil {
t.Fatalf("First Set() = %v", err)
}
if err := cfg.Set("bar=github.com"); err == nil {
t.Error("Second Set(); got success, wanted error.")
}
}

func TestBasicMalformedValues(t *testing.T) {
tests := []string{
"bar=baz=blah",
Expand Down
53 changes: 27 additions & 26 deletions pkg/credentials/gitcreds/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const sshKnownHosts = "known_hosts"
// As the flag is read, this status is populated.
// sshGitConfig implements flag.Value
type sshGitConfig struct {
entries map[string]sshEntry
entries map[string][]sshEntry
// The order we see things, for iterating over the above.
order []string
}
Expand All @@ -48,8 +48,9 @@ func (dc *sshGitConfig) String() string {
}
var urls []string
for _, k := range dc.order {
v := dc.entries[k]
urls = append(urls, fmt.Sprintf("%s=%s", v.secret, k))
for _, e := range dc.entries[k] {
urls = append(urls, fmt.Sprintf("%s=%s", e.secretName, k))
}
}
return strings.Join(urls, ",")
}
Expand All @@ -59,19 +60,17 @@ func (dc *sshGitConfig) Set(value string) error {
if len(parts) != 2 {
return xerrors.Errorf("Expect entries of the form secret=url, got: %v", value)
}
secret := parts[0]
secretName := parts[0]
url := parts[1]

if _, ok := dc.entries[url]; ok {
return xerrors.Errorf("Multiple entries for url: %v", url)
}

e, err := newSshEntry(url, secret)
e, err := newSshEntry(url, secretName)
if err != nil {
return err
}
dc.entries[url] = *e
dc.order = append(dc.order, url)
if _, exists := dc.entries[url]; !exists {
dc.order = append(dc.order, url)
}
dc.entries[url] = append(dc.entries[url], *e)
return nil
}

Expand All @@ -82,7 +81,7 @@ func (dc *sshGitConfig) Write() error {
}

// Walk each of the entries and for each do three things:
// 1. Write out: ~/.ssh/id_{secret} with the secret key
// 1. Write out: ~/.ssh/id_{secretName} with the secret key
// 2. Compute its part of "~/.ssh/config"
// 3. Compute its part of "~/.ssh/known_hosts"
var configEntries []string
Expand All @@ -95,17 +94,19 @@ func (dc *sshGitConfig) Write() error {
host = k
port = defaultPort
}
v := dc.entries[k]
if err := v.Write(sshDir); err != nil {
return err
}
configEntries = append(configEntries, fmt.Sprintf(`Host %s
configEntry := fmt.Sprintf(`Host %s
HostName %s
IdentityFile %s
Port %s
`, host, host, v.path(sshDir), port))

knownHosts = append(knownHosts, v.knownHosts)
`, host, host, port)
for _, e := range dc.entries[k] {
if err := e.Write(sshDir); err != nil {
return err
}
configEntry += fmt.Sprintf(` IdentityFile %s
`, e.path(sshDir))
knownHosts = append(knownHosts, e.knownHosts)
}
configEntries = append(configEntries, configEntry)
}
configPath := filepath.Join(sshDir, "config")
configContent := strings.Join(configEntries, "")
Expand All @@ -118,13 +119,13 @@ func (dc *sshGitConfig) Write() error {
}

type sshEntry struct {
secret string
secretName string
privateKey string
knownHosts string
}

func (be *sshEntry) path(sshDir string) string {
return filepath.Join(sshDir, "id_"+be.secret)
return filepath.Join(sshDir, "id_"+be.secretName)
}

func sshKeyScan(domain string) ([]byte, error) {
Expand All @@ -142,8 +143,8 @@ func (be *sshEntry) Write(sshDir string) error {
return ioutil.WriteFile(be.path(sshDir), []byte(be.privateKey), 0600)
}

func newSshEntry(u, secret string) (*sshEntry, error) {
secretPath := credentials.VolumeName(secret)
func newSshEntry(u, secretName string) (*sshEntry, error) {
secretPath := credentials.VolumeName(secretName)

pk, err := ioutil.ReadFile(filepath.Join(secretPath, corev1.SSHAuthPrivateKey))
if err != nil {
Expand All @@ -161,7 +162,7 @@ func newSshEntry(u, secret string) (*sshEntry, error) {
knownHosts := string(kh)

return &sshEntry{
secret: secret,
secretName: secretName,
privateKey: privateKey,
knownHosts: knownHosts,
}, nil
Expand Down