mirror of
https://github.com/twofas/2fas-server.git
synced 2025-01-05 22:15:56 +01:00
feat(pass): tokens (#24)
feat(pass): tokens Add token signing and verification to be used by pass.
This commit is contained in:
parent
dbd4245b6f
commit
17fb204680
1
.gitignore
vendored
1
.gitignore
vendored
@ -22,3 +22,4 @@ licenses-errors
|
||||
|
||||
data/
|
||||
.env.testing
|
||||
|
||||
|
@ -93,6 +93,18 @@ services:
|
||||
env_file:
|
||||
- .env
|
||||
|
||||
localstack:
|
||||
container_name: "${LOCALSTACK_DOCKER_NAME:-localstack-main}"
|
||||
image: localstack/localstack
|
||||
ports:
|
||||
- "127.0.0.1:4566:4566"
|
||||
environment:
|
||||
- DEBUG=1
|
||||
volumes:
|
||||
- "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook
|
||||
- "./data/localstack:/var/lib/localstack"
|
||||
- "/var/run/docker.sock:/var/run/docker.sock"
|
||||
|
||||
volumes:
|
||||
go-modules:
|
||||
# shared-volume is used to share volume between api and admin. On producition AWS S3 is used,
|
||||
|
11
go.mod
11
go.mod
@ -12,15 +12,18 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.15.5
|
||||
github.com/go-redis/redis_rate/v10 v10.0.1
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||
github.com/google/uuid v1.3.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/jaswdr/faker v1.16.0
|
||||
github.com/kelseyhightower/envconfig v1.4.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/pressly/goose/v3 v3.15.1
|
||||
github.com/redis/go-redis/v9 v9.3.0
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/viper v1.17.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
golang.org/x/sync v0.4.0
|
||||
google.golang.org/api v0.147.0
|
||||
gorm.io/datatypes v1.2.0
|
||||
gorm.io/driver/mysql v1.5.2
|
||||
@ -60,7 +63,6 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
@ -82,13 +84,12 @@ require (
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.5.0 // indirect
|
||||
golang.org/x/crypto v0.14.0 // indirect
|
||||
golang.org/x/crypto v0.17.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/oauth2 v0.13.0 // indirect
|
||||
golang.org/x/sync v0.4.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/sys v0.16.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect
|
||||
google.golang.org/appengine v1.6.8 // indirect
|
||||
|
14
go.sum
14
go.sum
@ -144,6 +144,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
@ -384,8 +386,8 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@ -529,8 +531,8 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@ -544,8 +546,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
62
internal/pass/sign/lib.go
Normal file
62
internal/pass/sign/lib.go
Normal file
@ -0,0 +1,62 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
awsKeySpec = "ECC_NIST_P256"
|
||||
awsSigningAlgorithm = "ECDSA_SHA_256"
|
||||
jwtSigningAlgorithm = "ES256"
|
||||
|
||||
// since we control both signature and verification, and we always use the same
|
||||
// algorithm, jwt header part (first segment) is always the same.
|
||||
// we can skip it (as in not send it) to save bytes in QR code.
|
||||
// note: header has only key type, not key id.
|
||||
jwtHeader = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9."
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
publicKey *ecdsa.PublicKey
|
||||
signingMethod jwt.SigningMethod
|
||||
}
|
||||
|
||||
func NewService(keyID string, client *kms.KMS) (*Service, error) {
|
||||
resp, err := client.GetPublicKey(&kms.GetPublicKeyInput{
|
||||
KeyId: &keyID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch key for %q: %w", keyID, err)
|
||||
}
|
||||
if *resp.KeySpec != awsKeySpec {
|
||||
return nil, fmt.Errorf("the only supported key spec is %q, received: %q", awsKeySpec, *resp.KeySpec)
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKIXPublicKey(resp.PublicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response from KMSas public key: %w", err)
|
||||
}
|
||||
|
||||
return &Service{
|
||||
publicKey: key.(*ecdsa.PublicKey),
|
||||
signingMethod: kmsSigningMethod{
|
||||
client: client,
|
||||
keyID: keyID,
|
||||
hash: crypto.SHA256,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type ConnectionType string
|
||||
|
||||
const (
|
||||
ConnectionTypeBrowserExtensionWait ConnectionType = "be/wait"
|
||||
ConnectionTypeBrowserExtensionProxy ConnectionType = "be/proxy"
|
||||
ConnectionTypeMobileProxy ConnectionType = "mobile/proxy"
|
||||
)
|
169
internal/pass/sign/lib_test.go
Normal file
169
internal/pass/sign/lib_test.go
Normal file
@ -0,0 +1,169 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ecdsaSigningMethodWithStaticKey struct {
|
||||
privateKey *ecdsa.PrivateKey
|
||||
}
|
||||
|
||||
func (e ecdsaSigningMethodWithStaticKey) Verify(signingString string, sig []byte, key interface{}) error {
|
||||
panic("not needed")
|
||||
}
|
||||
|
||||
func (e ecdsaSigningMethodWithStaticKey) Sign(signingString string, key interface{}) ([]byte, error) {
|
||||
return jwt.SigningMethodES256.Sign(signingString, e.privateKey)
|
||||
}
|
||||
|
||||
func (e ecdsaSigningMethodWithStaticKey) Alg() string {
|
||||
return jwt.SigningMethodES256.Alg()
|
||||
}
|
||||
|
||||
func TestSignAndVerifyHappyPath(t *testing.T) {
|
||||
srv := createTestService(t)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
token, err := srv.SignAndEncode(Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
ConnectionType: ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestService(t *testing.T) Service {
|
||||
t.Helper()
|
||||
|
||||
pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
srv := Service{
|
||||
publicKey: &pk.PublicKey,
|
||||
signingMethod: ecdsaSigningMethodWithStaticKey{
|
||||
privateKey: pk,
|
||||
},
|
||||
}
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestSignAndVerify(t *testing.T) {
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String("us-east-1"),
|
||||
Credentials: credentials.NewStaticCredentials("test", "test", ""),
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
Endpoint: aws.String("http://localhost:4566"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kmsClient := kms.New(sess)
|
||||
srv := createTestService(t)
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenFn func() string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "not even jwt token",
|
||||
tokenFn: func() string {
|
||||
return "xxx"
|
||||
},
|
||||
expectedError: jwt.ErrTokenMalformed,
|
||||
},
|
||||
{
|
||||
name: "token is expired",
|
||||
tokenFn: func() string {
|
||||
token, err := srv.SignAndEncode(Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(-time.Hour),
|
||||
ConnectionType: ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
expectedError: jwt.ErrTokenExpired,
|
||||
},
|
||||
{
|
||||
name: "invalid claim",
|
||||
tokenFn: func() string {
|
||||
token, err := srv.SignAndEncode(Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
ConnectionType: ConnectionTypeBrowserExtensionWait,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
expectedError: ErrInvalidClaims,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenFn: func() string {
|
||||
resp, err := kmsClient.CreateKey(&kms.CreateKeyInput{
|
||||
KeySpec: aws.String("ECC_NIST_P256"),
|
||||
KeyUsage: aws.String("SIGN_VERIFY"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serviceWithAnotherKey, err := NewService(*resp.KeyMetadata.KeyId, kmsClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := serviceWithAnotherKey.SignAndEncode(Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(-time.Hour),
|
||||
ConnectionType: ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
expectedError: jwt.ErrTokenSignatureInvalid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token := tc.tokenFn()
|
||||
err := srv.CanI(token, ConnectionTypeBrowserExtensionProxy)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error %v, got nil", tc.expectedError)
|
||||
}
|
||||
if !errors.Is(err, tc.expectedError) {
|
||||
t.Fatalf("Expected error %v, got %v", tc.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
111
internal/pass/sign/sign.go
Normal file
111
internal/pass/sign/sign.go
Normal file
@ -0,0 +1,111 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ConnectionID string
|
||||
ExpiresAt time.Time
|
||||
ConnectionType ConnectionType
|
||||
}
|
||||
|
||||
// SignAndEncode information in the message. The result
|
||||
// is second and third part of jwt token. Since the first
|
||||
// part is constant it is omitted.
|
||||
func (s Service) SignAndEncode(m Message) (string, error) {
|
||||
token := jwt.NewWithClaims(s.signingMethod, jwt.MapClaims{
|
||||
"exp": m.ExpiresAt.Unix(),
|
||||
"aud": []string{string(m.ConnectionType)},
|
||||
"c_id": m.ConnectionID,
|
||||
})
|
||||
|
||||
// no key is needed, as we use custom signing method.
|
||||
signed, err := token.SignedString(nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign jwt: %w", err)
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(signed, jwtHeader) {
|
||||
return "", fmt.Errorf("unpexpected signed string format")
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(signed, jwtHeader), nil
|
||||
}
|
||||
|
||||
type kmsSigningMethod struct {
|
||||
client *kms.KMS
|
||||
keyID string
|
||||
hash crypto.Hash
|
||||
}
|
||||
|
||||
// Verify implements jwt.SigningMethod#Method. Because we
|
||||
// provide key to jwt library, this is never called.
|
||||
func (s kmsSigningMethod) Verify(signingString string, sig []byte, key interface{}) error {
|
||||
panic("should never be called")
|
||||
}
|
||||
|
||||
// Sign implements jwt.SigningMethod#Sign method.
|
||||
func (s kmsSigningMethod) Sign(signingString string, key interface{}) ([]byte, error) {
|
||||
messageType := "DIGEST"
|
||||
|
||||
hasher := s.hash.New()
|
||||
if _, err := hasher.Write([]byte(signingString)); err != nil {
|
||||
return nil, fmt.Errorf("failed to hash input")
|
||||
}
|
||||
hashedSigningString := hasher.Sum(nil)
|
||||
|
||||
resp, err := s.client.Sign(&kms.SignInput{
|
||||
KeyId: &s.keyID,
|
||||
Message: hashedSigningString,
|
||||
MessageType: &messageType,
|
||||
SigningAlgorithm: aws.String(awsSigningAlgorithm),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign the message: %w", err)
|
||||
}
|
||||
// We are using encryption method with SHA_256 digest. Hence, key has 256/8=32 bytes.
|
||||
keySizeInBytes := 256 / 8
|
||||
return formatKMSSignatureForJWT(keySizeInBytes, resp.Signature)
|
||||
}
|
||||
|
||||
// Alg implements jwt.SigningMethod#Method.
|
||||
func (s kmsSigningMethod) Alg() string {
|
||||
return jwtSigningAlgorithm
|
||||
}
|
||||
|
||||
// formatKMSSignatureForJWT translates asn1 encoded signature (returned by AWS)
|
||||
// to format expected by JWT standard.
|
||||
// It is an algorithm I found on the internet
|
||||
// (here: https://github.com/twofas/2fas-server/pull/24/files/4f68cc2e611dca18b9787942e5cf12fc16518dd4#r1452702669 )
|
||||
// It should be tested using e2e tests.
|
||||
func formatKMSSignatureForJWT(keyBytes int, sig []byte) ([]byte, error) {
|
||||
p := struct {
|
||||
R *big.Int
|
||||
S *big.Int
|
||||
}{}
|
||||
|
||||
_, err := asn1.Unmarshal(sig, &p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rBytes := p.R.Bytes()
|
||||
rBytesPadded := make([]byte, keyBytes)
|
||||
copy(rBytesPadded[keyBytes-len(rBytes):], rBytes)
|
||||
|
||||
sBytes := p.S.Bytes()
|
||||
sBytesPadded := make([]byte, keyBytes)
|
||||
copy(sBytesPadded[keyBytes-len(sBytes):], sBytes)
|
||||
|
||||
out := append(rBytesPadded, sBytesPadded...)
|
||||
return out, nil
|
||||
}
|
45
internal/pass/sign/verify.go
Normal file
45
internal/pass/sign/verify.go
Normal file
@ -0,0 +1,45 @@
|
||||
package sign
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var ErrInvalidClaims = errors.New("invalid claims")
|
||||
|
||||
// CanI establish connection with type tp given claims in token.
|
||||
func (s Service) CanI(tokenString string, ct ConnectionType) error {
|
||||
cl := jwt.MapClaims{}
|
||||
|
||||
// In Sign we removed `jwtHeader` from JWT before returning it.
|
||||
// We need to add it again before doing the verification.
|
||||
tokenString = jwtHeader + tokenString
|
||||
|
||||
token, err := jwt.ParseWithClaims(
|
||||
tokenString,
|
||||
&cl,
|
||||
func(token *jwt.Token) (interface{}, error) {
|
||||
return s.publicKey, nil
|
||||
},
|
||||
jwt.WithValidMethods([]string{"ES256"}),
|
||||
jwt.WithExpirationRequired(),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse token: %w", err)
|
||||
}
|
||||
|
||||
claims, err := token.Claims.GetAudience()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get claims: %w", err)
|
||||
}
|
||||
|
||||
for _, aud := range claims {
|
||||
if aud == string(ct) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%w: claim %q not found in claims", ErrInvalidClaims, ct)
|
||||
}
|
18
tests/localstack_init.sh
Executable file
18
tests/localstack_init.sh
Executable file
@ -0,0 +1,18 @@
|
||||
#!/bin/sh
|
||||
|
||||
apt install --assume-yes jq
|
||||
|
||||
AWS_REGION=us-east-1
|
||||
KEY_ALIAS=pass_service
|
||||
|
||||
response=$(awslocal kms create-key \
|
||||
--region $AWS_REGION \
|
||||
--key-usage SIGN_VERIFY \
|
||||
--customer-master-key-spec ECC_NIST_P256)
|
||||
|
||||
key_id=$(echo "${response}" | jq -r '.KeyMetadata.KeyId')
|
||||
|
||||
awslocal kms create-alias \
|
||||
--region $AWS_REGION \
|
||||
--alias-name "alias/$KEY_ALIAS" \
|
||||
--target-key-id "${key_id}"
|
152
tests/pass/kms_test.go
Normal file
152
tests/pass/kms_test.go
Normal file
@ -0,0 +1,152 @@
|
||||
package pass
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/kms"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/twofas/2fas-server/internal/pass/sign"
|
||||
)
|
||||
|
||||
func TestSignAndVerifyHappyPath(t *testing.T) {
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String("us-east-1"),
|
||||
Credentials: credentials.NewStaticCredentials("test", "test", ""),
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
Endpoint: aws.String("http://localhost:4566"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kmsClient := kms.New(sess)
|
||||
srv, err := sign.NewService("alias/pass_service", kmsClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
token, err := srv.SignAndEncode(sign.Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(token)
|
||||
t.Log("Length of the token is", len(token))
|
||||
|
||||
if err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignAndVerify(t *testing.T) {
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Region: aws.String("us-east-1"),
|
||||
Credentials: credentials.NewStaticCredentials("test", "test", ""),
|
||||
S3ForcePathStyle: aws.Bool(true),
|
||||
Endpoint: aws.String("http://localhost:4566"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
kmsClient := kms.New(sess)
|
||||
srv, err := sign.NewService("alias/pass_service", kmsClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenFn func() string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "not even jwt token",
|
||||
tokenFn: func() string {
|
||||
return "xxx"
|
||||
},
|
||||
expectedError: jwt.ErrTokenMalformed,
|
||||
},
|
||||
{
|
||||
name: "token is expired",
|
||||
tokenFn: func() string {
|
||||
token, err := srv.SignAndEncode(sign.Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(-time.Hour),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
expectedError: jwt.ErrTokenExpired,
|
||||
},
|
||||
{
|
||||
name: "invalid claim",
|
||||
tokenFn: func() string {
|
||||
token, err := srv.SignAndEncode(sign.Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionWait,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
expectedError: sign.ErrInvalidClaims,
|
||||
},
|
||||
{
|
||||
name: "invalid signature",
|
||||
tokenFn: func() string {
|
||||
resp, err := kmsClient.CreateKey(&kms.CreateKeyInput{
|
||||
KeySpec: aws.String("ECC_NIST_P256"),
|
||||
KeyUsage: aws.String("SIGN_VERIFY"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serviceWithAnotherKey, err := sign.NewService(*resp.KeyMetadata.KeyId, kmsClient)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
token, err := serviceWithAnotherKey.SignAndEncode(sign.Message{
|
||||
ConnectionID: uuid.New().String(),
|
||||
ExpiresAt: now.Add(-time.Hour),
|
||||
ConnectionType: sign.ConnectionTypeBrowserExtensionProxy,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return token
|
||||
},
|
||||
expectedError: jwt.ErrTokenSignatureInvalid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token := tc.tokenFn()
|
||||
err := srv.CanI(token, sign.ConnectionTypeBrowserExtensionProxy)
|
||||
if err == nil {
|
||||
t.Fatalf("Expected error %v, got nil", tc.expectedError)
|
||||
}
|
||||
if !errors.Is(err, tc.expectedError) {
|
||||
t.Fatalf("Expected error %v, got %v", tc.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user