diff --git a/session/session.go b/session/session.go index b1ffac8..0197de6 100644 --- a/session/session.go +++ b/session/session.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/hmac" + "crypto/rand" "crypto/sha1" "encoding/base64" "encoding/hex" @@ -20,12 +21,16 @@ func generateSecret(base, salt string) []byte { return pbkdf2.Key([]byte(base), []byte(salt), keyIterNum, keySize, sha1.New) } -// The origin of this snippet can be found at https://gist.github.com/doitian/2a89dc9e4372e55335c9111f576b47bf -func verifySign(encryptedData, sign, base, signSalt string) (bool, error) { +func signData(encryptedData, base, signSalt string) []byte { signKey := generateSecret(base, signSalt) signHmac := hmac.New(sha1.New, signKey) signHmac.Write([]byte(encryptedData)) - verifySign := signHmac.Sum(nil) + return signHmac.Sum(nil) +} + +// The origin of this snippet can be found at https://gist.github.com/doitian/2a89dc9e4372e55335c9111f576b47bf +func verifySign(encryptedData, sign, base, signSalt string) (bool, error) { + verifySign := signData(encryptedData, base, signSalt) signDecoded, err := hex.DecodeString(sign) if err != nil { return false, err @@ -36,6 +41,12 @@ func verifySign(encryptedData, sign, base, signSalt string) (bool, error) { return true, nil } +// sign and join data with signature using "--" (needs to be url.QueryEscape'd) +func signJoiner(encryptedData, base, signSalt string) string { + postfix := hex.EncodeToString(signData(encryptedData, base, signSalt)) + return strings.Join([]string{encryptedData, postfix}, "--") +} + func decodeCookieData(cookie []byte) (data, iv []byte, err error) { vectors := strings.SplitN(string(cookie), "--", 2) @@ -52,6 +63,13 @@ func decodeCookieData(cookie []byte) (data, iv []byte, err error) { return } +func encodeCookieData(data, iv []byte) (cookie []byte) { + datas := base64.StdEncoding.EncodeToString(data) + ivs := base64.StdEncoding.EncodeToString(iv) + cookie = []byte(strings.Join([]string{datas, ivs}, "--")) + return +} + func decryptCookie(cookie []byte, secret []byte) (dd []byte, err error) { data, iv, err := decodeCookieData(cookie) @@ -67,6 +85,60 @@ func decryptCookie(cookie []byte, secret []byte) (dd []byte, err error) { return } +// padSession implements PKCS#7 padding for the plaintext +// https://en.wikipedia.org/wiki/Padding_(cryptography)#PKCS7 +func padSession(session []byte, blockSize int) []byte { + sesslen := len(session) + padsize := blockSize - (sesslen % blockSize) + if padsize == blockSize { + return session + } + newlen := sesslen + padsize + padbyte := byte(padsize) + padded := make([]byte, newlen) + copy(padded, session) + for i := sesslen; i < newlen; i++ { + padded[i] = padbyte + } + + return padded +} + +func encryptCookie(dd, secret []byte) (cookie []byte, err error) { + c, err := aes.NewCipher(secret[:32]) + if err != nil { + return + } + padded := padSession(dd, c.BlockSize()) + iv := make([]byte, c.BlockSize()) + // rails uses a random iv, so this should be fine: https://github.com/rails/rails/blob/master/activesupport/lib/active_support/message_encryptor.rb#L172 + _, err = rand.Read(iv) + if err != nil { + return + } + + cfb := cipher.NewCBCEncrypter(c, iv) + data := make([]byte, len(padded)) + cfb.CryptBlocks(data, padded) + + cookie = encodeCookieData(data, iv) + + return +} + +// EncryptSignedCookie encrypts and signs session to produce a cookie that rails can read +func EncryptSignedCookie(session []byte, secretKeyBase, salt, signSalt string) (signedCookie string, err error) { + data, err := encryptCookie(session, generateSecret(secretKeyBase, salt)) + if err != nil { + return + } + + datastr := base64.StdEncoding.EncodeToString(data) + cookie := signJoiner(datastr, secretKeyBase, signSalt) + signedCookie = url.QueryEscape(cookie) + return +} + func DecryptSignedCookie(signedCookie, secretKeyBase, salt, signSalt string) (session []byte, err error) { cookie, err := url.QueryUnescape(signedCookie) if err != nil { diff --git a/session/session_test.go b/session/session_test.go index 035c196..2e5def0 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -1,7 +1,7 @@ package session import ( - "encoding/json" + "bytes" "net/url" "strings" "testing" @@ -9,14 +9,21 @@ import ( const ( secretKeyBase = "fe98c394d54eeae9edff39c1934b156607e4376188463d397d460eef9585cf15c0dd23f353877552d1c9b0565a03b7fdeadfb33907c6d582eb02319a7409610b" - salt = "encrypted cookie" - signSalt = "signed encrypted cookie" - // The cookie's original content is: // map[flash:map[discard:[] flashes:map[notice:Welcome! You have signed up successfully.]] session_id:b85897340bfedc7e03b7e9479c271439 _csrf_token:dTDcQiGuEE8n6KUQmXNhIoXsLQJlqrBPUAsspGMpkdg= warden.user.user.key:[[1] $2a$11$6omJ7/e3Ni7Pl7jZbCdDBu]] signedCookie = "RkpiOStFLzExVm42aXZiMFZWaDB3c09rbEE4aTUvcEg5Q1VnaTNDOTBwMTdSUGFsdjZqbWZpQmV3eXhQbEJieE1EYXZCQXNGNFhKREI5aUx0aXVFZE1vaXQzSTdtYzc5S1NmeXBEZG93Mm1PQmQ2RVMvdjRqbTdsTW1qTjcxRTZFSVpCZFBUcTByN0ZYQmhWWVZPVE45RUsyS2NRcEV5QkdsajRUL3FGYjNmdUZrYmZ5TVZxSlpucllOaXlTN0pZZG85eHlMNEN0MVdYayttdE8wNTBTSElDYTRqditGMmpoL09hcDhkTFZ0dngyM244aG53aWNLNWRvVTN3K2dpUWd0eGttRXZUdGx2TGJHS0xlN0hKWFI2aVhuQlE4Y3NvYWx1QTZvcDRkbDJZdjl4NGJ1b1B1WW9QdXdEOVpzcCtBR1BCVDkxZkNSVENJZkVqMkgzR3pxQ1lVVEJmQlBYK0ZIQWJ5WHRpOC84PS0taDluekdrZE1LbzVrZDVlMHFSSzNjdz09--5f676b46cb0671630fd33bfec08b6fbf3f858c6a" + salt = "encrypted cookie" + signSalt = "signed encrypted cookie" + + // The cookie's original content is: + plainjson = `{"session_id":"b85897340bfedc7e03b7e9479c271439","_csrf_token":"dTDcQiGuEE8n6KUQmXNhIoXsLQJlqrBPUAsspGMpkdg=","warden.user.user.key":[[1],"$2a$11$6omJ7/e3Ni7Pl7jZbCdDBu"],"flash":{"discard":[],"flashes":{"notice":"Welcome! You have signed up successfully."}}}` ) +type sessionObj struct { + SessionID string `json:"session_id"` + CSRF string `json:"_csrf_token"` +} + func TestVerifySign(t *testing.T) { cookie, _ := url.QueryUnescape(signedCookie) vectors := strings.SplitN(cookie, "--", 2) @@ -36,16 +43,52 @@ func TestVerifySign(t *testing.T) { } } -func TestDecryptSignedCookie(t *testing.T) { - cookieData, err := DecryptSignedCookie(signedCookie, secretKeyBase, salt, signSalt) +func TestSignJoiner(t *testing.T) { + encryptedData := "RkpiOStFLzExVm42aXZiMFZWaDB3c09rbEE4aTUvcEg5Q1VnaTNDOTBwMTdSUGFsdjZqbWZpQmV3eXhQbEJieE1EYXZCQXNGNFhKREI5aUx0aXVFZE1vaXQzSTdtYzc5S1NmeXBEZG93Mm1PQmQ2RVMvdjRqbTdsTW1qTjcxRTZFSVpCZFBUcTByN0ZYQmhWWVZPVE45RUsyS2NRcEV5QkdsajRUL3FGYjNmdUZrYmZ5TVZxSlpucllOaXlTN0pZZG85eHlMNEN0MVdYayttdE8wNTBTSElDYTRqditGMmpoL09hcDhkTFZ0dngyM244aG53aWNLNWRvVTN3K2dpUWd0eGttRXZUdGx2TGJHS0xlN0hKWFI2aVhuQlE4Y3NvYWx1QTZvcDRkbDJZdjl4NGJ1b1B1WW9QdXdEOVpzcCtBR1BCVDkxZkNSVENJZkVqMkgzR3pxQ1lVVEJmQlBYK0ZIQWJ5WHRpOC84PS0taDluekdrZE1LbzVrZDVlMHFSSzNjdz09" + if want, got := signedCookie, url.QueryEscape(signJoiner(encryptedData, secretKeyBase, signSalt)); want != got { + t.Errorf("expected: %q\ngot: %q", want, got) + } +} + +func TestEncryptCookie(t *testing.T) { + // get plaintext + plaintext, _ := DecryptSignedCookie(signedCookie, secretKeyBase, salt, signSalt) + ciphertext, _ := encryptCookie(plaintext, generateSecret(secretKeyBase, salt)) + plaintext2, _ := decryptCookie(ciphertext, generateSecret(secretKeyBase, salt)) + if !bytes.Equal(plaintext, plaintext2) { + t.Errorf("ciphertext output by encryptCookie cannot be decrypted by decryptCookie") + } +} + +func TestEncryptSignedCookie(t *testing.T) { + want := make([]byte, len(plainjson)+13) + copy(want, plainjson) + for i := len(plainjson); i < len(plainjson)+13; i++ { + want[i] = 13 + } + + ciphertext, _ := EncryptSignedCookie([]byte(plainjson), secretKeyBase, salt, signSalt) + got, err := DecryptSignedCookie(ciphertext, secretKeyBase, salt, signSalt) if err != nil { - t.Errorf("DecryptSignedCookie test failure: %v", err) + t.Errorf("got error from decrypting ciphertext: %v", err) + } + if !bytes.Equal(want, got) { + t.Errorf("decrypted ciphertext %q does not match plaintext %q", string(got), string(want)) } - var jsonData map[string]interface{} - if err := json.Unmarshal(cookieData, &jsonData); err != nil { +} + +func TestDecryptSignedCookie(t *testing.T) { + want := make([]byte, len(plainjson)+13) + copy(want, plainjson) + for i := len(plainjson); i < len(plainjson)+13; i++ { + want[i] = 13 + } + + got, err := DecryptSignedCookie(signedCookie, secretKeyBase, salt, signSalt) + if err != nil { t.Errorf("DecryptSignedCookie test failure: %v", err) } - if jsonData["session_id"] != "b85897340bfedc7e03b7e9479c271439" { - t.Error("DecryptSignedCookie get wrong values after deserialization") + if !bytes.Equal(want, got) { + t.Errorf("decrypted ciphertext %q does not match plaintext %q", string(got), string(want)) } }