diff --git a/hotp.go b/hotp.go index 9ab34a4..292838a 100644 --- a/hotp.go +++ b/hotp.go @@ -88,9 +88,9 @@ func hotpFromURL(u *url.URL) (*HOTP, string, error) { } } - key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret) + key, err := base32.StdEncoding.DecodeString(Pad(secret)) if err != nil { - // secret isn't base32 encoded + // assume secret isn't base32 encoded key = []byte(secret) } otp := NewHOTP(key, counter, digits) diff --git a/otp_test.go b/otp_test.go index 3b5e993..418baff 100644 --- a/otp_test.go +++ b/otp_test.go @@ -69,6 +69,46 @@ func TestURL(t *testing.T) { } } +// This test makes sure we can generate codes for padded and non-padded +// entries +func TestPaddedURL(t *testing.T) { + var urlList = []string{ + "otpauth://hotp/?secret=ME", + "otpauth://hotp/?secret=MEFR", + "otpauth://hotp/?secret=MFRGG", + "otpauth://hotp/?secret=MFRGGZA", + "otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by=======", + "otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by", + "otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by%3D%3D%3D%3D%3D%3D%3D", + } + var codeList = []string{ + "413198", + "770938", + "670717", + "402378", + "069864", + "069864", + "069864", + } + + for i := range urlList { + if o, id, err := FromURL(urlList[i]); err != nil { + fmt.Println("hotp: URL should have parsed successfully") + fmt.Printf("\turl was: %s\n", urlList[i]) + t.FailNow() + fmt.Printf("\t%s, %s\n", o.OTP(), id) + } else { + code2 := o.OTP() + if code2 != codeList[i] { + fmt.Printf("hotp: mismatched OTPs\n") + fmt.Printf("\texpected: %s\n", codeList[i]) + fmt.Printf("\t actual: %s\n", code2) + t.FailNow() + } + } + } +} + // This test attempts a variety of invalid urls against the parser // to ensure they fail. func TestBadURL(t *testing.T) { diff --git a/totp.go b/totp.go index 4628fa8..789de4a 100644 --- a/totp.go +++ b/totp.go @@ -150,9 +150,9 @@ func totpFromURL(u *url.URL) (*TOTP, string, error) { } } - key, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(secret) + key, err := base32.StdEncoding.DecodeString(Pad(secret)) if err != nil { - // secret isn't base32 encoded + // assume secret isn't base32 encoded key = []byte(secret) } otp := NewTOTP(key, 0, period, digits, algo) diff --git a/util.go b/util.go new file mode 100644 index 0000000..af15c0f --- /dev/null +++ b/util.go @@ -0,0 +1,16 @@ +package twofactor + +import ( + "strings" +) + +// Pad calculates the number of '='s to add to our encoded string +// to make base32.StdEncoding.DecodeString happy +func Pad(s string) string { + if !strings.HasSuffix(s, "=") && len(s)%8 != 0 { + for len(s)%8 != 0 { + s += "=" + } + } + return s +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..647a931 --- /dev/null +++ b/util_test.go @@ -0,0 +1,53 @@ +package twofactor + +import ( + "encoding/base32" + "fmt" + "math/rand" + "strings" + "testing" +) + +const letters = "1234567890!@#$%^&*()abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randString() string { + b := make([]byte, rand.Intn(len(letters))) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return base32.StdEncoding.EncodeToString(b) +} + +func TestPadding(t *testing.T) { + for i := 0; i < 300; i++ { + b := randString() + origEncoding := string(b) + modEncoding := strings.Replace(string(b), "=", "", -1) + str, err := base32.StdEncoding.DecodeString(origEncoding) + if err != nil { + fmt.Println("Can't decode: ", string(b)) + t.FailNow() + } + + paddedEncoding := Pad(modEncoding) + if origEncoding != paddedEncoding { + fmt.Println("Padding failed:") + fmt.Printf("Expected: '%s'", origEncoding) + fmt.Printf("Got: '%s'", paddedEncoding) + t.FailNow() + } else { + mstr, err := base32.StdEncoding.DecodeString(paddedEncoding) + if err != nil { + fmt.Println("Can't decode: ", paddedEncoding) + t.FailNow() + } + + if string(mstr) != string(str) { + fmt.Println("Re-padding failed:") + fmt.Printf("Expected: '%s'", str) + fmt.Printf("Got: '%s'", mstr) + t.FailNow() + } + } + } +}