diff --git a/go.mod b/go.mod index 5f926fa..295daff 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/golang-jwt/jwt/v5 v5.2.0 github.com/google/uuid v1.3.1 - github.com/gorilla/websocket v1.5.0 + github.com/gorilla/websocket v1.5.1 github.com/jaswdr/faker v1.16.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index 1427dc2..50279e4 100644 --- a/go.sum +++ b/go.sum @@ -229,6 +229,8 @@ github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qK github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= diff --git a/tests/pass/pass_test.go b/tests/pass/pass_test.go new file mode 100644 index 0000000..ad6ec60 --- /dev/null +++ b/tests/pass/pass_test.go @@ -0,0 +1,215 @@ +package pass + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" +) + +type ConfigureBrowserExtensionResponse struct { + BrowserExtensionPairingToken string `json:"browser_extension_pairing_token"` + ConnectionToken string `json:"connection_token"` +} + +var ( + httpClient = http.DefaultClient + wsDialer = websocket.DefaultDialer +) + +const api = "localhost:8082" + +func TestPassHappyFlow(t *testing.T) { + resp, err := configureBrowserExtension() + if err != nil { + t.Fatalf("Failed to configure browser extension: %v", err) + } + + browserExtensionDone := make(chan struct{}) + mobileDone := make(chan struct{}) + + go func() { + defer close(browserExtensionDone) + + err := browserExtensionWaitForConfirm(resp.BrowserExtensionPairingToken) + if err != nil { + t.Errorf("Error when Browser Extension waited for confirm: %v", err) + return + } + + err = proxyWebSocket( + "ws://"+api+"/browser_extension/proxy_to_mobile", + resp.BrowserExtensionPairingToken, + "sent from browser extension", + "sent from mobile") + if err != nil { + t.Errorf("Browser Extension: proxy failed: %v", err) + return + } + + }() + go func() { + defer close(mobileDone) + + err := confirmMobile(resp.ConnectionToken) + if err != nil { + t.Errorf("Mobile: confirm failed: %v", err) + return + } + + err = proxyWebSocket( + "ws://"+api+"/mobile/proxy_to_browser_extension", + resp.BrowserExtensionPairingToken, + "sent from mobile", + "sent from browser extension", + ) + if err != nil { + t.Errorf("Mobile: proxy failed: %v", err) + return + } + }() + + <-browserExtensionDone + <-mobileDone +} + +func browserExtensionWaitForConfirm(token string) error { + url := "ws://" + api + "/browser_extension/wait_for_connection" + + var resp struct { + Status string `json:"status"` + } + + conn, err := dialWS(url, token) + if err != nil { + return err + } + defer conn.Close() + + conn.SetReadDeadline(time.Now().Add(time.Second)) + _, message, err := conn.ReadMessage() + if err != nil { + return fmt.Errorf("error reading from connection: %w", err) + } + if err := json.Unmarshal(message, &resp); err != nil { + return fmt.Errorf("failed to decode message: %w", err) + } + const expectedStatus = "ok" + if resp.Status != expectedStatus { + return fmt.Errorf("received status %q, expected %q", resp.Status, expectedStatus) + } + return nil +} + +func configureBrowserExtension() (ConfigureBrowserExtensionResponse, error) { + url := "http://" + api + "/browser_extension/configure" + + req, err := http.NewRequest("POST", url, bytesPrintf(`{"extension_id":"%s"}`, uuid.New().String())) + if err != nil { + return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to create http request: %w", err) + } + httpResp, err := httpClient.Do(req) + if err != nil { + return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed perform the request: %w", err) + } + defer httpResp.Body.Close() + + bb, err := io.ReadAll(httpResp.Body) + if err != nil { + return ConfigureBrowserExtensionResponse{}, fmt.Errorf("failed to read body from response: %w", err) + } + + if httpResp.StatusCode >= 300 { + return ConfigureBrowserExtensionResponse{}, fmt.Errorf("received status %s and body %q", httpResp.Status, string(bb)) + } + + var resp ConfigureBrowserExtensionResponse + if err := json.Unmarshal(bb, &resp); err != nil { + return resp, fmt.Errorf("failed to decode the response: %w", err) + } + + return resp, nil +} + +func confirmMobile(connectionToken string) error { + url := "http://" + api + "/mobile/confirm" + + req, err := http.NewRequest("POST", url, bytesPrintf(`{"device_id":"%s"}`, uuid.New().String())) + if err != nil { + return fmt.Errorf("failed to prepare the reqest: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", connectionToken)) + + httpResp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to perform the reqest: %w", err) + } + defer httpResp.Body.Close() + + if httpResp.StatusCode > 299 { + return fmt.Errorf("unexpected response: %s", httpResp.Status) + } + + return nil +} + +// proxyWebSocket will dial `endpoint`, using `token` for auth. It will then write exactly one message and +// read exactly one message (and then check it is `expectedReadMsg`). +func proxyWebSocket(url, token string, writeMsg, expectedReadMsg string) error { + conn, err := dialWS(url, token) + if err != nil { + return nil + } + defer conn.Close() + + doneReading := make(chan error) + + go func() { + defer close(doneReading) + _, message, err := conn.ReadMessage() + if err != nil { + doneReading <- fmt.Errorf("faile to read message: %w", err) + return + } + if string(message) != expectedReadMsg { + doneReading <- fmt.Errorf("expected to read %q, read %q", expectedReadMsg, string(message)) + return + } + }() + + if err := conn.WriteMessage(websocket.TextMessage, []byte(writeMsg)); err != nil { + return fmt.Errorf("failed to write message: %w", err) + } + err, _ = <-doneReading + if err != nil { + return fmt.Errorf("error when reading: %w", err) + } + return nil +} + +func dialWS(url, auth string) (*websocket.Conn, error) { + authEncodedAsProtocol := fmt.Sprintf("base64url.bearer.authorization.2pass.io.%s", base64.RawURLEncoding.EncodeToString([]byte(auth))) + + conn, _, err := wsDialer.Dial(url, http.Header{ + "Sec-WebSocket-Protocol": []string{ + "2pass.io", + authEncodedAsProtocol, + }, + }) + if err != nil { + return nil, fmt.Errorf("failed to dial ws %q: %v", url, err) + } + return conn, nil +} + +func bytesPrintf(format string, ii ...interface{}) io.Reader { + s := fmt.Sprintf(format, ii...) + return bytes.NewBufferString(s) +}