mirror of
https://github.com/twofas/2fas-server.git
synced 2024-12-12 12:09:56 +01:00
112 lines
3.0 KiB
Go
112 lines
3.0 KiB
Go
|
package sign
|
||
|
|
||
|
import (
|
||
|
"crypto"
|
||
|
"encoding/asn1"
|
||
|
"fmt"
|
||
|
"math/big"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/aws/aws-sdk-go/aws"
|
||
|
"github.com/aws/aws-sdk-go/service/kms"
|
||
|
"github.com/golang-jwt/jwt/v5"
|
||
|
)
|
||
|
|
||
|
type Message struct {
|
||
|
ConnectionID string
|
||
|
ExpiresAt time.Time
|
||
|
ConnectionType ConnectionType
|
||
|
}
|
||
|
|
||
|
// SignAndEncode information in the message. The result
|
||
|
// is second and third part of jwt token. Since the first
|
||
|
// part is constant it is omitted.
|
||
|
func (s Service) SignAndEncode(m Message) (string, error) {
|
||
|
token := jwt.NewWithClaims(s.signingMethod, jwt.MapClaims{
|
||
|
"exp": m.ExpiresAt.Unix(),
|
||
|
"aud": []string{string(m.ConnectionType)},
|
||
|
"c_id": m.ConnectionID,
|
||
|
})
|
||
|
|
||
|
// no key is needed, as we use custom signing method.
|
||
|
signed, err := token.SignedString(nil)
|
||
|
if err != nil {
|
||
|
return "", fmt.Errorf("failed to sign jwt: %w", err)
|
||
|
}
|
||
|
|
||
|
if !strings.HasPrefix(signed, jwtHeader) {
|
||
|
return "", fmt.Errorf("unpexpected signed string format")
|
||
|
}
|
||
|
|
||
|
return strings.TrimPrefix(signed, jwtHeader), nil
|
||
|
}
|
||
|
|
||
|
type kmsSigningMethod struct {
|
||
|
client *kms.KMS
|
||
|
keyID string
|
||
|
hash crypto.Hash
|
||
|
}
|
||
|
|
||
|
// Verify implements jwt.SigningMethod#Method. Because we
|
||
|
// provide key to jwt library, this is never called.
|
||
|
func (s kmsSigningMethod) Verify(signingString string, sig []byte, key interface{}) error {
|
||
|
panic("should never be called")
|
||
|
}
|
||
|
|
||
|
// Sign implements jwt.SigningMethod#Sign method.
|
||
|
func (s kmsSigningMethod) Sign(signingString string, key interface{}) ([]byte, error) {
|
||
|
messageType := "DIGEST"
|
||
|
|
||
|
hasher := s.hash.New()
|
||
|
if _, err := hasher.Write([]byte(signingString)); err != nil {
|
||
|
return nil, fmt.Errorf("failed to hash input")
|
||
|
}
|
||
|
hashedSigningString := hasher.Sum(nil)
|
||
|
|
||
|
resp, err := s.client.Sign(&kms.SignInput{
|
||
|
KeyId: &s.keyID,
|
||
|
Message: hashedSigningString,
|
||
|
MessageType: &messageType,
|
||
|
SigningAlgorithm: aws.String(awsSigningAlgorithm),
|
||
|
})
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to sign the message: %w", err)
|
||
|
}
|
||
|
// We are using encryption method with SHA_256 digest. Hence, key has 256/8=32 bytes.
|
||
|
keySizeInBytes := 256 / 8
|
||
|
return formatKMSSignatureForJWT(keySizeInBytes, resp.Signature)
|
||
|
}
|
||
|
|
||
|
// Alg implements jwt.SigningMethod#Method.
|
||
|
func (s kmsSigningMethod) Alg() string {
|
||
|
return jwtSigningAlgorithm
|
||
|
}
|
||
|
|
||
|
// formatKMSSignatureForJWT translates an ASN.1-encoded signature, as returned
// by AWS, into the format expected by the JWT standard. The input is a SEQUENCE
// of two INTEGERs, R and S. The output is R followed by S, each written
// big-endian and left-padded with zeros to exactly keyBytes bytes, so the
// result is always 2*keyBytes long.
//
// An error is returned in these cases:
//   - keyBytes is not positive.
//   - sig is not a valid ASN.1 value of that shape.
//   - sig carries trailing bytes.
//   - R or S is negative.
//   - R or S does not fit into keyBytes bytes.
//
// The algorithm is based on:
// https://github.com/twofas/2fas-server/pull/24/files/4f68cc2e611dca18b9787942e5cf12fc16518dd4#r1452702669
// It should be tested using e2e tests.
func formatKMSSignatureForJWT(keyBytes int, sig []byte) ([]byte, error) {
	if keyBytes <= 0 {
		return nil, fmt.Errorf("invalid signature component size: %d bytes", keyBytes)
	}

	var p struct {
		R *big.Int
		S *big.Int
	}
	rest, err := asn1.Unmarshal(sig, &p)
	if err != nil {
		return nil, fmt.Errorf("decoding asn1 signature: %w", err)
	}
	if len(rest) != 0 {
		return nil, fmt.Errorf("unexpected %d trailing bytes after asn1 signature", len(rest))
	}

	out := make([]byte, 2*keyBytes)
	for i, v := range [...]*big.Int{p.R, p.S} {
		// FillBytes stores the absolute value and panics when it does not fit,
		// so reject negative and oversized integers explicitly.
		if v.Sign() < 0 {
			return nil, fmt.Errorf("negative integer in asn1 signature")
		}
		if n := (v.BitLen() + 7) / 8; n > keyBytes {
			return nil, fmt.Errorf("signature integer is %d bytes, expected at most %d", n, keyBytes)
		}
		v.FillBytes(out[i*keyBytes : (i+1)*keyBytes])
	}
	return out, nil
}
|