diff --git a/.env b/.env index ac39e60..cfd79d7 100644 --- a/.env +++ b/.env @@ -15,6 +15,7 @@ SECURITY_RATE_LIMIT_BE=100 SECURITY_RATE_LIMIT_MOBILE=100 PASS_ADDR=:8082 +FAKE_MOBILE_PUSH=true AWS_ACCESS_KEY_ID=test AWS_SECRET_ACCESS_KEY=test diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..5a02b69 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,11 @@ +blank_issues_enabled: true +contact_links: + - name: Help center + url: https://2fas.com/help-center/ + about: Check out our extensive FaQ and video guides! + - name: Discord + url: https://discord.gg/q4cP6qh2g5 + about: Need support or have a question? Our Discord members are there to help! + - name: Reddit + url: https://www.reddit.com/r/2fas_com/ + about: Get support and discuss 2FAS with our Reddit community! \ No newline at end of file diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..607ab3c --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,25 @@ +name: Go + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main", "develop/pass" ] + +jobs: + + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.22' + + - name: Unit Test + run: make unit-tests-ci + + - name: e2e Test + run: make ci-e2e diff --git a/Makefile b/Makefile index 8442fe4..2434b37 100644 --- a/Makefile +++ b/Makefile @@ -15,25 +15,33 @@ migration: ## create database migrations file migration-up: ## apply all available migrations docker compose run -u ${USERID}:${USERID} --rm api migrate up - +.PHONY: up up: ## run all applications in stack docker compose build docker compose up -d +.PHONY: unit-tests +unit-tests: ## run unit tests without e2e-tests directory. + go test -race -count=1 `go list ./... | grep -v e2e-tests` -test: ## run unit tests - go test ./internal/... +.PHONY: unit-tests-ci +unit-tests-ci: ## run unit tests without e2e-tests directory (multiple times to find race conditions). + go test -race -count=50 -failfast `go list ./... | grep -v e2e-tests` +.PHONY: ci-e2e +ci-e2e: up + go run ./e2e-tests/scripts/wait-ready/main.go -addr=':80;:8081;:8082' + @$(MAKE) tests-e2e +.PHONY: tests-e2e tests-e2e: ## run end to end tests ## There is some race condition when running tests as go test -count=1 ./tests/... Come back at some point and fix it - go test ./tests/browser_extension/... -count=1 - go test ./tests/icons/... -count=1 - go test ./tests/mobile/... -count=1 - go test ./tests/support/... -count=1 - go test ./tests/system/... -count=1 - go test ./tests/pass/... -count=1 - + go test ./e2e-tests/browser_extension/... -count=1 + go test ./e2e-tests/icons/... -count=1 + go test ./e2e-tests/mobile/... -count=1 + go test ./e2e-tests/support/... -count=1 + go test ./e2e-tests/system/... -count=1 + PASS_ADDR="localhost:8088" go test ./e2e-tests/pass/... -count=1 vendor-licenses: ## report vendor licenses go-licenses report ./cmd/api --template licenses.tpl > licenses.json 2> licenses-errors diff --git a/cmd/admin/main.go b/cmd/admin/main.go index a961a6a..0bdbb47 100644 --- a/cmd/admin/main.go +++ b/cmd/admin/main.go @@ -10,7 +10,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "admin_api") + logging.Init(logging.Fields{"service_name": "admin_api"}) config.LoadConfiguration() diff --git a/cmd/api/main.go b/cmd/api/main.go index f80d997..cce7a87 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -9,7 +9,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "api") + logging.Init(logging.Fields{"service_name": "api"}) config.LoadConfiguration() diff --git a/cmd/pass/main.go b/cmd/pass/main.go index 3c770bd..9d91080 100644 --- a/cmd/pass/main.go +++ b/cmd/pass/main.go @@ -9,7 +9,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "pass") + logging.Init(logging.Fields{"service_name": "pass"}) var cfg config.PassConfig err := envconfig.Process("", &cfg) diff --git a/cmd/websocket/main.go b/cmd/websocket/main.go index 82e92f8..c737bcc 100644 --- a/cmd/websocket/main.go +++ b/cmd/websocket/main.go @@ -7,7 +7,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "websocket_api") + logging.Init(logging.Fields{"service_name": "websocket_api"}) config.LoadConfiguration() diff --git a/config/config.go b/config/config.go index 0d345ce..42da874 100644 --- a/config/config.go +++ b/config/config.go @@ -106,7 +106,7 @@ func initViper(configFilePath string) { err := viper.ReadInConfig() if err != nil { - logging.Fatal("failed to read the configuration file: %s", err) + logging.Fatalf("failed to read the configuration file: %s", err) } err = viper.Unmarshal(&Config) diff --git a/docker-compose.yml b/docker-compose.yml index 7c4cf75..0dac1df 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -89,10 +89,11 @@ services: group_add: - '1000' ports: - - "8084:8082" + - "8088:8082" environment: # overwrite AWS_ENDPOINT from .env file. One in env is used to running app from local also. AWS_ENDPOINT: http://localstack-main:4566 + AWS_REGION: us-east-1 env_file: - .env depends_on: @@ -113,7 +114,7 @@ services: timeout: 5s retries: 5 volumes: - - "./tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook + - "./e2e-tests/localstack_init.sh:/etc/localstack/init/ready.d/localstack_init.sh" # ready hook - "./data/localstack:/var/lib/localstack" - "/var/run/docker.sock:/var/run/docker.sock" diff --git a/e2e-tests/browser_extension/browser_extension_2fa_request_test.go b/e2e-tests/browser_extension/browser_extension_2fa_request_test.go new file mode 100644 index 0000000..34ae03d --- /dev/null +++ b/e2e-tests/browser_extension/browser_extension_2fa_request_test.go @@ -0,0 +1,121 @@ +package tests + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" +) + +func TestBrowserExtensionTwoFactorAuthTestSuite(t *testing.T) { + suite.Run(t, new(BrowserExtensionTwoFactorAuthTestSuite)) +} + +type BrowserExtensionTwoFactorAuthTestSuite struct { + suite.Suite +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) SetupTest() { + e2e_tests.RemoveAllMobileDevices(s.T()) + e2e_tests.RemoveAllBrowserExtensions(s.T()) + e2e_tests.RemoveAllBrowserExtensionsDevices(s.T()) +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestRequest2FaToken() { + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + + var tokenRequest *e2e_tests.AuthTokenRequestResponse + request2FaTokenPayload := []byte(`{"domain":"https://facebook.com/path/nested"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) + + assert.Equal(s.T(), browserExtension.Id, tokenRequest.ExtensionId) + + var tokenRequestById *e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &tokenRequestById) + assert.Equal(s.T(), tokenRequest.Id, tokenRequestById.Id) + assert.Equal(s.T(), "https://facebook.com", tokenRequestById.Domain) +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestFindAll2FaRequestsForBrowserExtension() { + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + + facebook2FaTokenRequest := []byte(`{"domain":"facebook.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", facebook2FaTokenRequest, nil) + + google2FaTokenRequest := []byte(`{"domain":"google.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", google2FaTokenRequest, nil) + + var tokenRequestsCollection []*e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests", &tokenRequestsCollection) + + assert.Len(s.T(), tokenRequestsCollection, 2) +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestClose2FaTokenRequest() { + var tokenRequest *e2e_tests.AuthTokenRequestResponse + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + tokenRequestPayload := []byte(`{"domain":"facebook.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", tokenRequestPayload, &tokenRequest) + closeTokenRequestPayload := []byte(`{"status":"completed"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) + + var closedTokenRequest *e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &closedTokenRequest) + assert.Equal(s.T(), "completed", closedTokenRequest.Status) +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestCloseNotExisting2FaTokenRequest() { + notExistingTokenRequestId := uuid.New() + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + + closeTokenRequestPayload := []byte(`{"status":"completed"}`) + e2e_tests.DoAPIPostAndAssertCode(s.T(), 404, "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+notExistingTokenRequestId.String()+"/commands/close_2fa_request", closeTokenRequestPayload, nil) + +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestDoNotReturnClosed2FaRequests() { + var tokenRequest *e2e_tests.AuthTokenRequestResponse + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + tokenRequestPayload := []byte(`{"domain":"facebook.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", tokenRequestPayload, &tokenRequest) + + closeTokenRequestPayload := []byte(`{"status":"completed"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) + + var response []*e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests", &response) + assert.Len(s.T(), response, 0) +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestTerminate2FaRequest() { + var tokenRequest *e2e_tests.AuthTokenRequestResponse + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + tokenRequestPayload := []byte(`{"domain":"facebook.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", tokenRequestPayload, &tokenRequest) + + closeTokenRequestPayload := []byte(`{"status":"terminated"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) + + var response *e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &response) + assert.Equal(s.T(), "terminated", response.Status) +} + +func (s *BrowserExtensionTwoFactorAuthTestSuite) TestClose2FaRequest() { + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "SM-955F", "fcm-token") + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) + + var tokenRequest *e2e_tests.AuthTokenRequestResponse + request2FaTokenPayload := []byte(`{"domain":"domain.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) + + closeTokenRequestPayload := []byte(`{"status":"completed"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) + + var closedTokenRequest *e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &closedTokenRequest) + assert.Equal(s.T(), "completed", closedTokenRequest.Status) +} diff --git a/tests/browser_extension/browser_extension_2fa_test.go b/e2e-tests/browser_extension/browser_extension_2fa_test.go similarity index 62% rename from tests/browser_extension/browser_extension_2fa_test.go rename to e2e-tests/browser_extension/browser_extension_2fa_test.go index a52109b..2ad15d3 100644 --- a/tests/browser_extension/browser_extension_2fa_test.go +++ b/e2e-tests/browser_extension/browser_extension_2fa_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" ) func TestTwoFactorAuthTestSuite(t *testing.T) { @@ -18,33 +18,33 @@ type TwoFactorAuthTestSuite struct { } func (s *TwoFactorAuthTestSuite) SetupTest() { - tests.RemoveAllMobileDevices(s.T()) - tests.RemoveAllBrowserExtensions(s.T()) - tests.RemoveAllBrowserExtensionsDevices(s.T()) + e2e_tests.RemoveAllMobileDevices(s.T()) + e2e_tests.RemoveAllBrowserExtensions(s.T()) + e2e_tests.RemoveAllBrowserExtensionsDevices(s.T()) } func (s *TwoFactorAuthTestSuite) TestBrowserExtensionAuthFullFlow() { - device, devicePubKey := tests.CreateDevice(s.T(), "SM-955F", "some-token") - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "SM-955F", "some-token") + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") - websocketTestListener := tests.NewWebsocketTestListener("browser_extensions/" + browserExtension.Id) + websocketTestListener := e2e_tests.NewWebsocketTestListener("browser_extensions/" + browserExtension.Id) websocketConnection := websocketTestListener.StartListening() defer websocketConnection.Close() - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) assertDeviceHasPairedExtension(s.T(), device, browserExtension) assertBrowserExtensionHasPairedDevice(s.T(), browserExtension, device) expectedPairingSuccessWebsocket := createPairingSuccessWebsocketMessage(browserExtension, device, devicePubKey) websocketTestListener.AssertMessageHasBeenReceived(s.T(), expectedPairingSuccessWebsocket) - tokenRequest := tests.Request2FaToken(s.T(), "facebook.com", browserExtension.Id) + tokenRequest := e2e_tests.Request2FaToken(s.T(), "facebook.com", browserExtension.Id) - extensionTokenRequestWebsocketListener := tests.NewWebsocketTestListener("browser_extensions/" + browserExtension.Id + "/2fa_requests/" + tokenRequest.Id) + extensionTokenRequestWebsocketListener := e2e_tests.NewWebsocketTestListener("browser_extensions/" + browserExtension.Id + "/2fa_requests/" + tokenRequest.Id) extensionTokenRequestWebsocketConnection := extensionTokenRequestWebsocketListener.StartListening() defer extensionTokenRequestWebsocketConnection.Close() - tests.Send2FaTokenToExtension(s.T(), browserExtension.Id, device.Id, tokenRequest.Id, "2fa-token") + e2e_tests.Send2FaTokenToExtension(s.T(), browserExtension.Id, device.Id, tokenRequest.Id, "2fa-token") expected2FaTokenWebsocket := createBrowserExtensionReceived2FaTokenMessage(browserExtension.Id, device.Id, tokenRequest.Id) extensionTokenRequestWebsocketListener.AssertMessageHasBeenReceived(s.T(), expected2FaTokenWebsocket) @@ -70,7 +70,7 @@ func createBrowserExtensionReceived2FaTokenMessage(extensionId, deviceId, reques return string(message) } -func createPairingSuccessWebsocketMessage(browserExtension *tests.BrowserExtensionResponse, device *tests.DeviceResponse, devicePubKey string) string { +func createPairingSuccessWebsocketMessage(browserExtension *e2e_tests.BrowserExtensionResponse, device *e2e_tests.DeviceResponse, devicePubKey string) string { expectedPairingWebsocketMessageRaw := &struct { Event string `json:"event"` BrowserExtensionId string `json:"browser_extension_id"` @@ -88,17 +88,17 @@ func createPairingSuccessWebsocketMessage(browserExtension *tests.BrowserExtensi return string(message) } -func assertBrowserExtensionHasPairedDevice(t *testing.T, browserExtension *tests.BrowserExtensionResponse, device *tests.DeviceResponse) { - var browserExtensionDevices []*tests.DeviceResponse - tests.DoAPISuccessGet(t, "browser_extensions/"+browserExtension.Id+"/devices", &browserExtensionDevices) +func assertBrowserExtensionHasPairedDevice(t *testing.T, browserExtension *e2e_tests.BrowserExtensionResponse, device *e2e_tests.DeviceResponse) { + var browserExtensionDevices []*e2e_tests.DeviceResponse + e2e_tests.DoAPISuccessGet(t, "browser_extensions/"+browserExtension.Id+"/devices", &browserExtensionDevices) assert.Len(t, browserExtensionDevices, 1) assert.Equal(t, device.Id, browserExtensionDevices[0].Id) } -func assertDeviceHasPairedExtension(t *testing.T, device *tests.DeviceResponse, browserExtension *tests.BrowserExtensionResponse) { - var deviceBrowserExtensions []*tests.BrowserExtensionResponse - tests.DoAPISuccessGet(t, "mobile/devices/"+device.Id+"/browser_extensions", &deviceBrowserExtensions) +func assertDeviceHasPairedExtension(t *testing.T, device *e2e_tests.DeviceResponse, browserExtension *e2e_tests.BrowserExtensionResponse) { + var deviceBrowserExtensions []*e2e_tests.BrowserExtensionResponse + e2e_tests.DoAPISuccessGet(t, "mobile/devices/"+device.Id+"/browser_extensions", &deviceBrowserExtensions) assert.Len(t, deviceBrowserExtensions, 1) assert.Equal(t, browserExtension.Id, deviceBrowserExtensions[0].Id) diff --git a/tests/browser_extension/browser_extension_log_test.go b/e2e-tests/browser_extension/browser_extension_log_test.go similarity index 62% rename from tests/browser_extension/browser_extension_log_test.go rename to e2e-tests/browser_extension/browser_extension_log_test.go index 44dc990..1a91762 100644 --- a/tests/browser_extension/browser_extension_log_test.go +++ b/e2e-tests/browser_extension/browser_extension_log_test.go @@ -5,11 +5,11 @@ import ( "testing" "github.com/google/uuid" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" ) func Test_BrowserExtensionLogging(t *testing.T) { - browserExtension := tests.CreateBrowserExtension(t, "go-ext") + browserExtension := e2e_tests.CreateBrowserExtension(t, "go-ext") log := &struct { Level string `json:"level"` @@ -20,7 +20,7 @@ func Test_BrowserExtensionLogging(t *testing.T) { } payload, _ := json.Marshal(log) - tests.DoAPISuccessPost(t, "/browser_extensions/"+browserExtension.Id+"/commands/store_log", payload, nil) + e2e_tests.DoAPISuccessPost(t, "/browser_extensions/"+browserExtension.Id+"/commands/store_log", payload, nil) } func Test_NotExistingBrowserExtensionLogging(t *testing.T) { @@ -35,5 +35,5 @@ func Test_NotExistingBrowserExtensionLogging(t *testing.T) { } payload, _ := json.Marshal(log) - tests.DoAPISuccessPost(t, "/browser_extensions/"+someId.String()+"/commands/store_log", payload, nil) + e2e_tests.DoAPISuccessPost(t, "/browser_extensions/"+someId.String()+"/commands/store_log", payload, nil) } diff --git a/e2e-tests/browser_extension/browser_extension_pairing_test.go b/e2e-tests/browser_extension/browser_extension_pairing_test.go new file mode 100644 index 0000000..8b2053c --- /dev/null +++ b/e2e-tests/browser_extension/browser_extension_pairing_test.go @@ -0,0 +1,207 @@ +package tests + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" +) + +func TestBrowserExtensionPairingTestSuite(t *testing.T) { + suite.Run(t, new(BrowserExtensionPairingTestSuite)) +} + +type BrowserExtensionPairingTestSuite struct { + suite.Suite +} + +func (s *BrowserExtensionPairingTestSuite) SetupTest() { + e2e_tests.RemoveAllBrowserExtensions(s.T()) + e2e_tests.RemoveAllBrowserExtensionsDevices(s.T()) +} + +func (s *BrowserExtensionPairingTestSuite) TestPairBrowserExtensionWithMobileDevice() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + _, err := uuid.Parse(browserExt.Id) + require.NoError(s.T(), err) + + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + _, err = uuid.Parse(device.Id) + require.NoError(s.T(), err) + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) + + var extensionDevice *e2e_tests.DevicePairedBrowserExtensionResponse + e2e_tests.DoAPISuccessGet(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id, &extensionDevice) + + assert.Equal(s.T(), extensionDevice.Id, device.Id) +} + +func (s *BrowserExtensionPairingTestSuite) TestDoNotFindNotPairedBrowserExtensionMobileDevice() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + _, err := uuid.Parse(browserExt.Id) + require.NoError(s.T(), err) + + device, _ := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + + response := e2e_tests.DoAPIGet(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id, nil) + + assert.Equal(s.T(), 404, response.StatusCode) +} + +func (s *BrowserExtensionPairingTestSuite) TestPairBrowserExtensionWithMultipleDevices() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + _, err := uuid.Parse(browserExt.Id) + require.NoError(s.T(), err) + + device1, devicePubKey1 := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") + device2, devicePubKey2 := e2e_tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt, device1) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt, device2) + + extensionDevices := e2e_tests.GetExtensionDevices(s.T(), browserExt.Id) + + assert.Len(s.T(), extensionDevices, 2) +} + +func (s *BrowserExtensionPairingTestSuite) TestRemoveBrowserExtensionPairedDevice() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + + device1, devicePubKey1 := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") + device2, devicePubKey2 := e2e_tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt, device1) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt, device2) + + extensionDevices := getExtensionPairedDevices(s.T(), browserExt) + assert.Len(s.T(), extensionDevices, 2) + + e2e_tests.DoAPISuccessDelete(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device1.Id) + + extensionDevices = getExtensionPairedDevices(s.T(), browserExt) + assert.Len(s.T(), extensionDevices, 1) + assert.Equal(s.T(), device2.Id, extensionDevices[0].Id) +} + +func (s *BrowserExtensionPairingTestSuite) TestRemoveBrowserExtensionPairedDeviceTwice() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) + + e2e_tests.DoAPISuccessDelete(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id) + response := e2e_tests.DoAPIRequest(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id, http.MethodDelete, nil /*payload*/, nil /*resp*/) + + assert.Equal(s.T(), 404, response.StatusCode) +} + +func (s *BrowserExtensionPairingTestSuite) TestRemoveAllBrowserExtensionPairedDevices() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + device1, devicePubKey1 := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id1") + device2, devicePubKey2 := e2e_tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id2") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt, device1) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt, device2) + + e2e_tests.DoAPISuccessDelete(s.T(), "/browser_extensions/"+browserExt.Id+"/devices") + + extensionDevices := e2e_tests.GetExtensionDevices(s.T(), browserExt.Id) + assert.Len(s.T(), extensionDevices, 0) +} + +func (s *BrowserExtensionPairingTestSuite) TestGetPairedDevicesWhichIDoNotOwn() { + browserExt1 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-1") + browserExt2 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-2") + + device1, devicePubKey1 := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") + device2, devicePubKey2 := e2e_tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt1, device1) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt2, device2) + + firstExtensionDevices := getExtensionPairedDevices(s.T(), browserExt1) + assert.Len(s.T(), firstExtensionDevices, 1) + assert.Equal(s.T(), device1.Id, firstExtensionDevices[0].Id) + + secondExtensionDevices := getExtensionPairedDevices(s.T(), browserExt2) + assert.Len(s.T(), secondExtensionDevices, 1) + assert.Equal(s.T(), device2.Id, secondExtensionDevices[0].Id) +} + +func (s *BrowserExtensionPairingTestSuite) TestGetPairedDevicesByInvalidExtensionId() { + browserExt1 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-1") + browserExt2 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-2") + + device1, devicePubKey1 := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") + device2, devicePubKey2 := e2e_tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt1, device1) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt2, device2) + + invalidResp := map[string]any{} + response := e2e_tests.DoAPIGet(s.T(), "/browser_extensions/some-invalid-id/devices/", &invalidResp) + assert.Equal(s.T(), 400, response.StatusCode) + assert.Contains(s.T(), invalidResp["Reason"], `Field validation for 'ExtensionId' failed on the 'uuid4'`) +} + +func (s *BrowserExtensionPairingTestSuite) TestGetPairedDevicesByNotExistingExtensionId() { + browserExt1 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-1") + browserExt2 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-2") + + device1, devicePubKey1 := e2e_tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") + device2, devicePubKey2 := e2e_tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt1, device1) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt2, device2) + + notExistingExtensionId := uuid.New() + var firstExtensionDevices []*e2e_tests.ExtensionPairedDeviceResponse + e2e_tests.DoAPISuccessGet(s.T(), "/browser_extensions/"+notExistingExtensionId.String()+"/devices/", &firstExtensionDevices) + assert.Len(s.T(), firstExtensionDevices, 0) +} + +func (s *BrowserExtensionPairingTestSuite) TestShareExtensionPublicKeyWithMobileDevice() { + browserExt := e2e_tests.CreateBrowserExtensionWithPublicKey(s.T(), "go-test", "b64-rsa-pub-key") + _, err := uuid.Parse(browserExt.Id) + require.NoError(s.T(), err) + + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + _, err = uuid.Parse(device.Id) + require.NoError(s.T(), err) + + result := e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) + assert.Equal(s.T(), "b64-rsa-pub-key", result.ExtensionPublicKey) +} + +func (s *BrowserExtensionPairingTestSuite) TestCannotPairSameDeviceAndExtensionTwice() { + browserExtension := e2e_tests.CreateBrowserExtensionWithPublicKey(s.T(), "go-test", "b64-rsa-pub-key") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) + + payload := struct { + ExtensionId string `json:"extension_id"` + DeviceName string `json:"device_name"` + DevicePublicKey string `json:"device_public_key"` + }{ + ExtensionId: browserExtension.Id, + DeviceName: device.Name, + DevicePublicKey: "device-pub-key", + } + + pairingResult := new(e2e_tests.PairingResultResponse) + payloadJson, _ := json.Marshal(payload) + + e2e_tests.DoAPIPostAndAssertCode(s.T(), 409, "/mobile/devices/"+device.Id+"/browser_extensions", payloadJson, pairingResult) +} + +func getExtensionPairedDevices(t *testing.T, browserExt *e2e_tests.BrowserExtensionResponse) []*e2e_tests.ExtensionPairedDeviceResponse { + var extensionDevices []*e2e_tests.ExtensionPairedDeviceResponse + e2e_tests.DoAPISuccessGet(t, "/browser_extensions/"+browserExt.Id+"/devices/", &extensionDevices) + return extensionDevices +} diff --git a/tests/browser_extension/browser_extension_test.go b/e2e-tests/browser_extension/browser_extension_test.go similarity index 69% rename from tests/browser_extension/browser_extension_test.go rename to e2e-tests/browser_extension/browser_extension_test.go index 04874d4..1052638 100644 --- a/tests/browser_extension/browser_extension_test.go +++ b/e2e-tests/browser_extension/browser_extension_test.go @@ -8,8 +8,8 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" "github.com/twofas/2fas-server/internal/common/crypto" - "github.com/twofas/2fas-server/tests" ) func TestBrowserExtensionTestSuite(t *testing.T) { @@ -21,7 +21,7 @@ type BrowserExtensionTestSuite struct { } func (s *BrowserExtensionTestSuite) SetupTest() { - tests.RemoveAllBrowserExtensions(s.T()) + e2e_tests.RemoveAllBrowserExtensions(s.T()) } func (s *BrowserExtensionTestSuite) TestCreateBrowserExtension() { @@ -47,13 +47,13 @@ func (s *BrowserExtensionTestSuite) TestCreateBrowserExtension() { } func (s *BrowserExtensionTestSuite) TestUpdateBrowserExtension() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") payload := []byte(`{"name": "updated-extension-name"}`) - tests.DoAPISuccessPut(s.T(), "/browser_extensions/"+browserExt.Id, payload, nil) + e2e_tests.DoAPISuccessPut(s.T(), "/browser_extensions/"+browserExt.Id, payload, nil) - var browserExtension *tests.BrowserExtensionResponse - tests.DoAPISuccessGet(s.T(), "/browser_extensions/"+browserExt.Id, &browserExtension) + var browserExtension *e2e_tests.BrowserExtensionResponse + e2e_tests.DoAPISuccessGet(s.T(), "/browser_extensions/"+browserExt.Id, &browserExtension) assert.Equal(s.T(), "updated-extension-name", browserExtension.Name) } @@ -62,16 +62,16 @@ func (s *BrowserExtensionTestSuite) TestUpdateNotExistingBrowserExtension() { id := uuid.New() payload := []byte(`{"name": "updated-extension-name"}`) - response := tests.DoAPIRequest(s.T(), "/browser_extensions/"+id.String(), http.MethodPut, payload, nil) + response := e2e_tests.DoAPIRequest(s.T(), "/browser_extensions/"+id.String(), http.MethodPut, payload, nil) assert.Equal(s.T(), 404, response.StatusCode) } func (s *BrowserExtensionTestSuite) TestUpdateBrowserExtensionSetEmptyName() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") payload := []byte(`{"name": ""}`) - response := tests.DoAPIRequest(s.T(), "/browser_extensions/"+browserExt.Id, http.MethodPut, payload, nil) + response := e2e_tests.DoAPIRequest(s.T(), "/browser_extensions/"+browserExt.Id, http.MethodPut, payload, nil) assert.Equal(s.T(), 400, response.StatusCode) } @@ -79,8 +79,8 @@ func (s *BrowserExtensionTestSuite) TestUpdateBrowserExtensionSetEmptyName() { func (s *BrowserExtensionTestSuite) TestDoNotFindNotExistingExtension() { notExistingId := uuid.New() - var browserExtension *tests.BrowserExtensionResponse - response := tests.DoAPIGet(s.T(), "/browser_extensions/"+notExistingId.String(), &browserExtension) + var browserExtension *e2e_tests.BrowserExtensionResponse + response := e2e_tests.DoAPIGet(s.T(), "/browser_extensions/"+notExistingId.String(), &browserExtension) assert.Equal(s.T(), 404, response.StatusCode) } @@ -92,6 +92,6 @@ func createBrowserExtension(t *testing.T, name string) *http.Response { payload := []byte(fmt.Sprintf(`{"name":"%s","browser_name":"go-browser","browser_version":"0.1","public_key":"%s"}`, name, pubKey)) - return tests.DoAPIRequest(t, "/browser_extensions", http.MethodPost, payload, nil) + return e2e_tests.DoAPIRequest(t, "/browser_extensions", http.MethodPost, payload, nil) } diff --git a/tests/helpers.go b/e2e-tests/helpers.go similarity index 99% rename from tests/helpers.go rename to e2e-tests/helpers.go index 4f9f8fd..5e6b788 100644 --- a/tests/helpers.go +++ b/e2e-tests/helpers.go @@ -1,4 +1,4 @@ -package tests +package e2e_tests import ( "encoding/json" diff --git a/tests/http.go b/e2e-tests/http.go similarity index 99% rename from tests/http.go rename to e2e-tests/http.go index 07db9d4..11c315f 100644 --- a/tests/http.go +++ b/e2e-tests/http.go @@ -1,4 +1,4 @@ -package tests +package e2e_tests import ( "bytes" diff --git a/tests/icons/icons_collection_test.go b/e2e-tests/icons/icons_collection_test.go similarity index 69% rename from tests/icons/icons_collection_test.go rename to e2e-tests/icons/icons_collection_test.go index c1ef64e..b2fb7ba 100644 --- a/tests/icons/icons_collection_test.go +++ b/e2e-tests/icons/icons_collection_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" ) type iconsCollectionResponse struct { @@ -26,7 +26,7 @@ type IconsCollectionsTestSuite struct { } func (s *IconsCollectionsTestSuite) SetupTest() { - tests.RemoveAllMobileIconsCollections(s.T()) + e2e_tests.RemoveAllMobileIconsCollections(s.T()) } func (s *IconsCollectionsTestSuite) TestCreateIconsCollection() { @@ -39,7 +39,7 @@ func (s *IconsCollectionsTestSuite) TestCreateIconsCollection() { `) var IconsCollection *iconsCollectionResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &IconsCollection) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &IconsCollection) assert.Equal(s.T(), "facebook", IconsCollection.Name) assert.Equal(s.T(), "desc", IconsCollection.Description) @@ -55,7 +55,7 @@ func (s *IconsCollectionsTestSuite) TestUpdateIconsCollection() { } `) var iconsCollection *iconsCollectionResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &iconsCollection) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &iconsCollection) updatePayload := []byte(` { @@ -65,7 +65,7 @@ func (s *IconsCollectionsTestSuite) TestUpdateIconsCollection() { `) var updatedIconsCollection *iconsCollectionResponse - tests.DoAdminSuccessPut(s.T(), "mobile/icons/collections/"+iconsCollection.Id, updatePayload, &updatedIconsCollection) + e2e_tests.DoAdminSuccessPut(s.T(), "mobile/icons/collections/"+iconsCollection.Id, updatePayload, &updatedIconsCollection) assert.Equal(s.T(), "meta", updatedIconsCollection.Name) assert.Equal(s.T(), []string{"icon-1", "icon-2"}, updatedIconsCollection.Icons) @@ -79,11 +79,11 @@ func (s *IconsCollectionsTestSuite) TestDeleteIconsCollection() { } `) var iconsCollection *iconsCollectionResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &iconsCollection) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &iconsCollection) - tests.DoAdminSuccessDelete(s.T(), "mobile/icons/collections/"+iconsCollection.Id) + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/icons/collections/"+iconsCollection.Id) - response := tests.DoAPIGet(s.T(), "mobile/icons/collections/"+iconsCollection.Id, nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/icons/collections/"+iconsCollection.Id, nil) assert.Equal(s.T(), 404, response.StatusCode) } @@ -94,7 +94,7 @@ func (s *IconsCollectionsTestSuite) TestFindAllIconsCollections() { "icons":["icon-1", "icon-2"] } `) - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, nil) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, nil) payload2 := []byte(` { @@ -103,10 +103,10 @@ func (s *IconsCollectionsTestSuite) TestFindAllIconsCollections() { "icons":["123e4567-e89b-12d3-a456-426614174000"] } `) - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload2, nil) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload2, nil) var IconsCollections []*iconsCollectionResponse - tests.DoAPISuccessGet(s.T(), "mobile/icons/collections", &IconsCollections) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/icons/collections", &IconsCollections) assert.Len(s.T(), IconsCollections, 2) } @@ -119,10 +119,10 @@ func (s *IconsCollectionsTestSuite) TestFindIconsCollection() { } `) var createdIconsCollection *iconsCollectionResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &createdIconsCollection) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/collections", payload, &createdIconsCollection) var IconsCollection *iconsCollectionResponse - tests.DoAPISuccessGet(s.T(), "mobile/icons/collections/"+createdIconsCollection.Id, &IconsCollection) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/icons/collections/"+createdIconsCollection.Id, &IconsCollection) assert.Equal(s.T(), "just-one", IconsCollection.Name) } diff --git a/tests/icons/icons_requests_test.go b/e2e-tests/icons/icons_requests_test.go similarity index 74% rename from tests/icons/icons_requests_test.go rename to e2e-tests/icons/icons_requests_test.go index 4f972b4..021800f 100644 --- a/tests/icons/icons_requests_test.go +++ b/e2e-tests/icons/icons_requests_test.go @@ -9,8 +9,8 @@ import ( "github.com/jaswdr/faker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" "github.com/twofas/2fas-server/internal/api/icons/app/queries" - "github.com/twofas/2fas-server/tests" ) func TestIconsRequestsTestSuite(t *testing.T) { @@ -22,10 +22,10 @@ type IconsRequestsTestSuite struct { } func (s *IconsRequestsTestSuite) SetupTest() { - tests.RemoveAllMobileWebServices(s.T()) - tests.RemoveAllMobileIcons(s.T()) - tests.RemoveAllMobileIconsCollections(s.T()) - tests.RemoveAllMobileIconsRequests(s.T()) + e2e_tests.RemoveAllMobileWebServices(s.T()) + e2e_tests.RemoveAllMobileIcons(s.T()) + e2e_tests.RemoveAllMobileIconsCollections(s.T()) + e2e_tests.RemoveAllMobileIconsRequests(s.T()) } func (s *IconsRequestsTestSuite) TestCreateIconRequest() { @@ -58,15 +58,15 @@ func (s *IconsRequestsTestSuite) TestCreateIconRequestWithNotAllowedIconDimensio var iconRequest *queries.IconRequestPresenter - tests.DoAPIPostAndAssertCode(s.T(), 400, "mobile/icons/requests", payload, &iconRequest) + e2e_tests.DoAPIPostAndAssertCode(s.T(), 400, "mobile/icons/requests", payload, &iconRequest) } func (s *IconsRequestsTestSuite) TestDeleteIconRequest() { iconRequest := createIconRequest(s.T(), "service") - tests.DoAdminSuccessDelete(s.T(), "mobile/icons/requests/"+iconRequest.Id) + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/icons/requests/"+iconRequest.Id) - response := tests.DoAPIGet(s.T(), "mobile/icons/requests/"+iconRequest.Id, nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/icons/requests/"+iconRequest.Id, nil) assert.Equal(s.T(), 404, response.StatusCode) } @@ -75,7 +75,7 @@ func (s *IconsRequestsTestSuite) TestFindAllIconsRequests() { createIconRequest(s.T(), "service2") var iconsRequests []*queries.IconRequestPresenter - tests.DoAPISuccessGet(s.T(), "mobile/icons/requests", &iconsRequests) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/icons/requests", &iconsRequests) assert.Len(s.T(), iconsRequests, 2) } @@ -84,7 +84,7 @@ func (s *IconsRequestsTestSuite) TestFindIconRequest() { iconRequest := createIconRequest(s.T(), "service") var searchResult *queries.IconPresenter - tests.DoAdminSuccessGet(s.T(), "mobile/icons/requests/"+iconRequest.Id, &searchResult) + e2e_tests.DoAdminSuccessGet(s.T(), "mobile/icons/requests/"+iconRequest.Id, &searchResult) assert.Equal(s.T(), "service", searchResult.Name) } @@ -93,7 +93,7 @@ func (s *IconsRequestsTestSuite) TestTransformIconRequestIntoWebService() { iconRequest := createIconRequest(s.T(), "service") var result *queries.WebServicePresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/requests/"+iconRequest.Id+"/commands/transform_to_web_service", nil, &result) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/requests/"+iconRequest.Id+"/commands/transform_to_web_service", nil, &result) assert.Equal(s.T(), "service", result.Name) } @@ -103,10 +103,10 @@ func (s *IconsRequestsTestSuite) TestTransformSingleIconRequestsIntoWebServiceFr createIconRequest(s.T(), "service") var result *queries.WebServicePresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/requests/"+iconRequest.Id+"/commands/transform_to_web_service", nil, &result) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/requests/"+iconRequest.Id+"/commands/transform_to_web_service", nil, &result) var icons []*queries.IconPresenter - tests.DoAPIGet(s.T(), "mobile/icons", &icons) + e2e_tests.DoAPIGet(s.T(), "mobile/icons", &icons) assert.Len(s.T(), icons, 1) } @@ -116,7 +116,7 @@ func (s *IconsRequestsTestSuite) TestTransformIconRequestWithAlreadyExistingWebS iconRequest := createIconRequest(s.T(), webService.Name) var result *queries.WebServicePresenter - tests.DoAdminPostAndAssertCode(s.T(), 409, "mobile/icons/requests/"+iconRequest.Id+"/commands/transform_to_web_service", nil, &result) + e2e_tests.DoAdminPostAndAssertCode(s.T(), 409, "mobile/icons/requests/"+iconRequest.Id+"/commands/transform_to_web_service", nil, &result) } func (s *IconsRequestsTestSuite) TestUpdateWebServiceFromIconRequest() { @@ -125,7 +125,7 @@ func (s *IconsRequestsTestSuite) TestUpdateWebServiceFromIconRequest() { var result *queries.WebServicePresenter payload := []byte(`{"web_service_id":"` + webService.Id + `"}`) - tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/requests/"+iconRequest.Id+"/commands/update_web_service", payload, &result) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/icons/requests/"+iconRequest.Id+"/commands/update_web_service", payload, &result) assert.Equal(s.T(), webService.Name, result.Name) @@ -163,7 +163,7 @@ func createIconRequest(t *testing.T, serviceName string) *queries.IconRequestPre var iconRequest *queries.IconRequestPresenter - tests.DoAPISuccessPost(t, "mobile/icons/requests", payload, &iconRequest) + e2e_tests.DoAPISuccessPost(t, "mobile/icons/requests", payload, &iconRequest) return iconRequest } diff --git a/tests/icons/icons_test.go b/e2e-tests/icons/icons_test.go similarity index 75% rename from tests/icons/icons_test.go rename to e2e-tests/icons/icons_test.go index f7b4d10..1806aa4 100644 --- a/tests/icons/icons_test.go +++ b/e2e-tests/icons/icons_test.go @@ -8,8 +8,8 @@ import ( "github.com/jaswdr/faker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" query "github.com/twofas/2fas-server/internal/api/icons/app/queries" - "github.com/twofas/2fas-server/tests" ) func createIcon(t *testing.T) *query.IconPresenter { @@ -34,7 +34,7 @@ func createIcon(t *testing.T) *query.IconPresenter { var icon *query.IconPresenter - tests.DoAdminAPISuccessPost(t, "mobile/icons", payload, &icon) + e2e_tests.DoAdminAPISuccessPost(t, "mobile/icons", payload, &icon) return icon } @@ -48,7 +48,7 @@ type IconsTestSuite struct { } func (s *IconsTestSuite) SetupTest() { - tests.DoAdminSuccessDelete(s.T(), "mobile/icons") + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/icons") } func (s *IconsTestSuite) TestCreateIcon() { @@ -67,7 +67,7 @@ func (s *IconsTestSuite) TestUpdateIcon() { `) var updatedIcon *query.IconPresenter - tests.DoAdminSuccessPut(s.T(), "mobile/icons/"+icon.Id, updatePayload, &updatedIcon) + e2e_tests.DoAdminSuccessPut(s.T(), "mobile/icons/"+icon.Id, updatePayload, &updatedIcon) assert.Equal(s.T(), "meta", updatedIcon.Name) } @@ -75,9 +75,9 @@ func (s *IconsTestSuite) TestUpdateIcon() { func (s *IconsTestSuite) TestDeleteIcon() { icon := createIcon(s.T()) - tests.DoAdminSuccessDelete(s.T(), "mobile/icons/"+icon.Id) + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/icons/"+icon.Id) - response := tests.DoAPIGet(s.T(), "mobile/icons/"+icon.Id, nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/icons/"+icon.Id, nil) assert.Equal(s.T(), 404, response.StatusCode) } @@ -86,7 +86,7 @@ func (s *IconsTestSuite) TestFindAllIcons() { createIcon(s.T()) var Icons []*query.IconPresenter - tests.DoAPISuccessGet(s.T(), "mobile/icons", &Icons) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/icons", &Icons) assert.Len(s.T(), Icons, 2) } @@ -95,7 +95,7 @@ func (s *IconsTestSuite) TestFindIcon() { icon := createIcon(s.T()) var searchResult *query.IconPresenter - tests.DoAPISuccessGet(s.T(), "mobile/icons/"+icon.Id, &searchResult) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/icons/"+icon.Id, &searchResult) assert.Equal(s.T(), "facebook", searchResult.Name) } diff --git a/tests/icons/web_services_dump_test.go b/e2e-tests/icons/web_services_dump_test.go similarity index 73% rename from tests/icons/web_services_dump_test.go rename to e2e-tests/icons/web_services_dump_test.go index 8537471..15a7a81 100644 --- a/tests/icons/web_services_dump_test.go +++ b/e2e-tests/icons/web_services_dump_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" ) func TestWebServicesDumpTestSuite(t *testing.T) { @@ -19,16 +19,16 @@ type WebServicesDumpTestSuite struct { } func (s *WebServicesDumpTestSuite) SetupTest() { - tests.RemoveAllMobileIcons(s.T()) - tests.RemoveAllMobileIconsCollections(s.T()) - tests.RemoveAllMobileWebServices(s.T()) + e2e_tests.RemoveAllMobileIcons(s.T()) + e2e_tests.RemoveAllMobileIconsCollections(s.T()) + e2e_tests.RemoveAllMobileWebServices(s.T()) } func (s *WebServicesDumpTestSuite) TestWebServicesDump() { createWebService(s.T()) createWebService(s.T()) - response := tests.DoAPIGet(s.T(), "mobile/web_services/dump", nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/web_services/dump", nil) assert.Equal(s.T(), 200, response.StatusCode) } @@ -48,7 +48,7 @@ func createWebService(t *testing.T) *webServiceResponse { var webService *webServiceResponse - tests.DoAdminAPISuccessPost(t, "mobile/web_services", payload, &webService) + e2e_tests.DoAdminAPISuccessPost(t, "mobile/web_services", payload, &webService) return webService } @@ -66,7 +66,7 @@ func createIconsCollection(t *testing.T) *iconsCollectionResponse { var createdIconsCollection *iconsCollectionResponse - tests.DoAdminAPISuccessPost(t, "mobile/icons/collections", payload, &createdIconsCollection) + e2e_tests.DoAdminAPISuccessPost(t, "mobile/icons/collections", payload, &createdIconsCollection) return createdIconsCollection } diff --git a/tests/icons/web_services_test.go b/e2e-tests/icons/web_services_test.go similarity index 78% rename from tests/icons/web_services_test.go rename to e2e-tests/icons/web_services_test.go index 12d09fa..2e23e57 100644 --- a/tests/icons/web_services_test.go +++ b/e2e-tests/icons/web_services_test.go @@ -5,8 +5,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" "github.com/twofas/2fas-server/internal/api/icons/app/command" - "github.com/twofas/2fas-server/tests" ) type webServiceResponse struct { @@ -30,7 +30,7 @@ type WebServicesTestSuite struct { } func (s *WebServicesTestSuite) SetupTest() { - tests.RemoveAllMobileWebServices(s.T()) + e2e_tests.RemoveAllMobileWebServices(s.T()) } func (s *WebServicesTestSuite) TestCreateWebService() { @@ -45,7 +45,7 @@ func (s *WebServicesTestSuite) TestCreateWebService() { `) var webService *webServiceResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) assert.Equal(s.T(), "facebook", webService.Name) assert.Equal(s.T(), "desc", webService.Description) @@ -65,8 +65,8 @@ func (s *WebServicesTestSuite) TestCreateWebServiceWithAlreadyExistingName() { } `) - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, nil) - tests.DoAdminPostAndAssertCode(s.T(), 409, "mobile/web_services", payload, nil) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, nil) + e2e_tests.DoAdminPostAndAssertCode(s.T(), 409, "mobile/web_services", payload, nil) } func (s *WebServicesTestSuite) TestCreateWebServiceWithMatchRules() { @@ -80,7 +80,7 @@ func (s *WebServicesTestSuite) TestCreateWebServiceWithMatchRules() { `) var webService *webServiceResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) assert.Equal(s.T(), []*command.MatchRule{{ Field: "label", @@ -101,7 +101,7 @@ func (s *WebServicesTestSuite) TestUpdateWebService() { } `) var webService *webServiceResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) updatePayload := []byte(`{ "name":"meta", @@ -112,7 +112,7 @@ func (s *WebServicesTestSuite) TestUpdateWebService() { `) var updatedWebService *webServiceResponse - tests.DoAdminSuccessPut(s.T(), "mobile/web_services/"+webService.Id, updatePayload, &updatedWebService) + e2e_tests.DoAdminSuccessPut(s.T(), "mobile/web_services/"+webService.Id, updatePayload, &updatedWebService) assert.Equal(s.T(), "meta", updatedWebService.Name) assert.Equal(s.T(), []string{"meta", "facebook"}, updatedWebService.Issuers) @@ -130,7 +130,7 @@ func (s *WebServicesTestSuite) TestUpdateWebServiceMatchRule() { } `) var webService *webServiceResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) updatePayload := []byte(`{ "name":"meta", @@ -141,7 +141,7 @@ func (s *WebServicesTestSuite) TestUpdateWebServiceMatchRule() { `) var updatedWebService *webServiceResponse - tests.DoAdminSuccessPut(s.T(), "mobile/web_services/"+webService.Id, updatePayload, &updatedWebService) + e2e_tests.DoAdminSuccessPut(s.T(), "mobile/web_services/"+webService.Id, updatePayload, &updatedWebService) assert.Equal(s.T(), "issuer", updatedWebService.MatchRules[0].Field) assert.Equal(s.T(), "facebook.pl", updatedWebService.MatchRules[0].Text) @@ -160,11 +160,11 @@ func (s *WebServicesTestSuite) TestDeleteWebService() { } `) var webService *webServiceResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &webService) - tests.DoAdminSuccessDelete(s.T(), "mobile/web_services/"+webService.Id) + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/web_services/"+webService.Id) - response := tests.DoAPIGet(s.T(), "mobile/web_services/"+webService.Id, nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/web_services/"+webService.Id, nil) assert.Equal(s.T(), 404, response.StatusCode) } @@ -178,7 +178,7 @@ func (s *WebServicesTestSuite) TestFindAllWebServices() { "icons_collections":["123e4567-e89b-12d3-a456-426614174000"] } `) - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, nil) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, nil) payload2 := []byte(` { @@ -189,10 +189,10 @@ func (s *WebServicesTestSuite) TestFindAllWebServices() { "icons_collections":["123e4567-e89b-12d3-a456-426614174000"] } `) - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload2, nil) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload2, nil) var webServices []*webServiceResponse - tests.DoAPISuccessGet(s.T(), "mobile/web_services", &webServices) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/web_services", &webServices) assert.Len(s.T(), webServices, 2) } @@ -207,10 +207,10 @@ func (s *WebServicesTestSuite) TestFindWebService() { } `) var createdWebService *webServiceResponse - tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &createdWebService) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/web_services", payload, &createdWebService) var webService *webServiceResponse - tests.DoAPISuccessGet(s.T(), "mobile/web_services/"+createdWebService.Id, &webService) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/web_services/"+createdWebService.Id, &webService) assert.Equal(s.T(), "just-one", webService.Name) } diff --git a/tests/localstack_init.sh b/e2e-tests/localstack_init.sh similarity index 100% rename from tests/localstack_init.sh rename to e2e-tests/localstack_init.sh diff --git a/e2e-tests/mobile/mobile_browser_extensions_2fa_requests_test.go b/e2e-tests/mobile/mobile_browser_extensions_2fa_requests_test.go new file mode 100644 index 0000000..b9f813e --- /dev/null +++ b/e2e-tests/mobile/mobile_browser_extensions_2fa_requests_test.go @@ -0,0 +1,54 @@ +package tests + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" +) + +func TestMobileDeviceExtensionIntegrationTestSuite(t *testing.T) { + suite.Run(t, new(MobileDeviceExtensionIntegrationTestSuite)) +} + +type MobileDeviceExtensionIntegrationTestSuite struct { + suite.Suite +} + +func (s *MobileDeviceExtensionIntegrationTestSuite) SetupTest() { + e2e_tests.RemoveAllMobileDevices(s.T()) + e2e_tests.RemoveAllBrowserExtensions(s.T()) + e2e_tests.RemoveAllBrowserExtensionsDevices(s.T()) +} + +func (s *MobileDeviceExtensionIntegrationTestSuite) TestGetPending2FaRequests() { + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "SM-955F", "fcm-token") + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) + + var tokenRequest *e2e_tests.AuthTokenRequestResponse + request2FaTokenPayload := []byte(`{"domain":"domain.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) + + var tokenRequestsCollection []*e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "mobile/devices/"+device.Id+"/browser_extensions/2fa_requests", &tokenRequestsCollection) + assert.Len(s.T(), tokenRequestsCollection, 1) +} + +func (s *MobileDeviceExtensionIntegrationTestSuite) TestDoNotReturnCompleted2FaRequests() { + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "SM-955F", "fcm-token") + browserExtension := e2e_tests.CreateBrowserExtension(s.T(), "go-ext") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) + + var tokenRequest *e2e_tests.AuthTokenRequestResponse + request2FaTokenPayload := []byte(`{"domain":"domain.com"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) + + closeTokenRequestPayload := []byte(`{"status":"completed"}`) + e2e_tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) + + var tokenRequestsCollection []*e2e_tests.AuthTokenRequestResponse + e2e_tests.DoAPISuccessGet(s.T(), "mobile/devices/"+device.Id+"/browser_extensions/2fa_requests", &tokenRequestsCollection) + assert.Len(s.T(), tokenRequestsCollection, 0) +} diff --git a/e2e-tests/mobile/mobile_device_extension_test.go b/e2e-tests/mobile/mobile_device_extension_test.go new file mode 100644 index 0000000..07052ed --- /dev/null +++ b/e2e-tests/mobile/mobile_device_extension_test.go @@ -0,0 +1,97 @@ +package tests + +import ( + "fmt" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" +) + +func TestMobileDeviceExtensionTestSuite(t *testing.T) { + suite.Run(t, new(MobileDeviceExtensionTestSuite)) +} + +type MobileDeviceExtensionTestSuite struct { + suite.Suite +} + +func (s *MobileDeviceExtensionTestSuite) SetupTest() { + e2e_tests.RemoveAllMobileDevices(s.T()) + e2e_tests.RemoveAllBrowserExtensions(s.T()) + e2e_tests.RemoveAllBrowserExtensionsDevices(s.T()) +} + +func (s *MobileDeviceExtensionTestSuite) TestDoNotFindExtensionsForNotExistingDevice() { + notExistingDeviceId := uuid.New() + + response := e2e_tests.DoAPIGet(s.T(), "/mobile/devices/"+notExistingDeviceId.String()+"/browser_extensions", nil) + + assert.Equal(s.T(), 404, response.StatusCode) +} + +func (s *MobileDeviceExtensionTestSuite) TestDoNotFindNotExistingMobileDeviceExtension() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) + + notExistingExtensionId := uuid.New() + response := e2e_tests.DoAPIGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+notExistingExtensionId.String(), nil) + + assert.Equal(s.T(), 404, response.StatusCode) +} + +func (s *MobileDeviceExtensionTestSuite) Test_FindExtensionForDevice() { + browserExt := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) + + var deviceBrowserExtension *e2e_tests.BrowserExtensionResponse + e2e_tests.DoAPISuccessGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt.Id, &deviceBrowserExtension) + + assert.Equal(s.T(), browserExt.Id, deviceBrowserExtension.Id) +} + +func (s *MobileDeviceExtensionTestSuite) Test_FindAllDeviceExtensions() { + browserExt1 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-1") + browserExt2 := e2e_tests.CreateBrowserExtension(s.T(), "go-test-2") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt1, device) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt2, device) + + var deviceBrowserExtensions []*e2e_tests.BrowserExtensionResponse + e2e_tests.DoAPISuccessGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/", &deviceBrowserExtensions) + + assert.Len(s.T(), deviceBrowserExtensions, 2) +} + +func (s *MobileDeviceExtensionTestSuite) Test_DisconnectExtensionFromDevice() { + browserExt1 := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + browserExt2 := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt1, device) + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt2, device) + + e2e_tests.DoAPISuccessDelete(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt1.Id) + + var deviceBrowserExtension1 *e2e_tests.BrowserExtensionResponse + response := e2e_tests.DoAPIGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt1.Id, &deviceBrowserExtension1) + assert.Equal(s.T(), 404, response.StatusCode) + + var deviceBrowserExtension2 *e2e_tests.BrowserExtensionResponse + e2e_tests.DoAPISuccessGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt2.Id, &deviceBrowserExtension2) + assert.Equal(s.T(), browserExt2.Id, deviceBrowserExtension2.Id) +} + +func (s *MobileDeviceExtensionTestSuite) TestExtensionHasAlreadyBeenConnected() { + extension := e2e_tests.CreateBrowserExtension(s.T(), "go-test") + device, devicePubKey := e2e_tests.CreateDevice(s.T(), "go-test-device", "some-device-id") + e2e_tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, extension, device) + + payload := []byte(fmt.Sprintf(`{"extension_id":"%s","device_name":"%s","device_public_key":"%s"}`, extension.Id, device.Name, devicePubKey)) + + e2e_tests.DoAPIPostAndAssertCode(s.T(), 409, "/mobile/devices/"+device.Id+"/browser_extensions", payload, nil) +} diff --git a/tests/mobile/mobile_device_test.go b/e2e-tests/mobile/mobile_device_test.go similarity index 89% rename from tests/mobile/mobile_device_test.go rename to e2e-tests/mobile/mobile_device_test.go index 2ccc92d..8929ed5 100644 --- a/tests/mobile/mobile_device_test.go +++ b/e2e-tests/mobile/mobile_device_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" ) func TestMobileDeviceTestSuite(t *testing.T) { @@ -19,7 +19,7 @@ type MobileDeviceTestSuite struct { } func (s *MobileDeviceTestSuite) SetupTest() { - tests.RemoveAllMobileDevices(s.T()) + e2e_tests.RemoveAllMobileDevices(s.T()) } func (s *MobileDeviceTestSuite) TestCreateMobileDevice() { @@ -49,5 +49,5 @@ func (s *MobileDeviceTestSuite) TestCreateMobileDevice() { func createDevice(t *testing.T, name, fcmToken string) *http.Response { payload := []byte(fmt.Sprintf(`{"name":"%s","platform":"android","fcm_token":"%s"}`, name, fcmToken)) - return tests.DoAPIRequest(t, "mobile/devices", http.MethodPost, payload, nil) + return e2e_tests.DoAPIRequest(t, "mobile/devices", http.MethodPost, payload, nil) } diff --git a/tests/mobile/mobile_notifications_test.go b/e2e-tests/mobile/mobile_notifications_test.go similarity index 72% rename from tests/mobile/mobile_notifications_test.go rename to e2e-tests/mobile/mobile_notifications_test.go index ae997a4..98f5a43 100644 --- a/tests/mobile/mobile_notifications_test.go +++ b/e2e-tests/mobile/mobile_notifications_test.go @@ -7,8 +7,8 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" query "github.com/twofas/2fas-server/internal/api/mobile/app/queries" - "github.com/twofas/2fas-server/tests" ) func TestMobileNotificationsTestSuite(t *testing.T) { @@ -20,7 +20,7 @@ type MobileNotificationsTestSuite struct { } func (s *MobileNotificationsTestSuite) SetupTest() { - tests.RemoveAllMobileNotifications(s.T()) + e2e_tests.RemoveAllMobileNotifications(s.T()) } func (s *MobileNotificationsTestSuite) TestCreateMobileNotification() { @@ -28,7 +28,7 @@ func (s *MobileNotificationsTestSuite) TestCreateMobileNotification() { var notification *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) assert.Equal(s.T(), "android", notification.Platform) assert.Equal(s.T(), "0.1", notification.Version) @@ -40,11 +40,11 @@ func (s *MobileNotificationsTestSuite) TestCreateMobileNotification() { func (s *MobileNotificationsTestSuite) TestUpdateMobileNotification() { payload := []byte(`{"icon":"features","platform":"android","link":"2fas.com","message":"demo","version":"0.1"}`) var notification *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) payload = []byte(`{"icon":"youtube","platform":"ios","link":"new-2fas.com","message":"new-demo","version":"1.1"}`) var updatedNotification *query.MobileNotificationPresenter - tests.DoAdminSuccessPut(s.T(), "mobile/notifications/"+notification.Id, payload, &updatedNotification) + e2e_tests.DoAdminSuccessPut(s.T(), "mobile/notifications/"+notification.Id, payload, &updatedNotification) assert.Equal(s.T(), "ios", updatedNotification.Platform) assert.Equal(s.T(), "1.1", updatedNotification.Version) @@ -56,18 +56,18 @@ func (s *MobileNotificationsTestSuite) TestUpdateMobileNotification() { func (s *MobileNotificationsTestSuite) TestDeleteMobileNotification() { payload := []byte(`{"icon":"features","platform":"android","link":"2fas.com","message":"demo","version":"0.1"}`) var notification *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) - tests.DoAdminSuccessDelete(s.T(), "mobile/notifications/"+notification.Id) + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/notifications/"+notification.Id) - response := tests.DoAPIGet(s.T(), "mobile/notifications/"+notification.Id, nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/notifications/"+notification.Id, nil) assert.Equal(s.T(), 404, response.StatusCode) } func (s *MobileNotificationsTestSuite) TestDeleteNotExistingMobileNotification() { id := uuid.New() - response := tests.DoAPIRequest(s.T(), "mobile/notifications/"+id.String(), http.MethodDelete, nil /*payload*/, nil /*resp*/) + response := e2e_tests.DoAPIRequest(s.T(), "mobile/notifications/"+id.String(), http.MethodDelete, nil /*payload*/, nil /*resp*/) assert.Equal(s.T(), 404, response.StatusCode) } @@ -75,14 +75,14 @@ func (s *MobileNotificationsTestSuite) TestDeleteNotExistingMobileNotification() func (s *MobileNotificationsTestSuite) TestFindAllNotifications() { payload1 := []byte(`{"icon":"features","platform":"android","link":"2fas.com","message":"demo","version":"0.1"}`) var notification1 *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload1, ¬ification1) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload1, ¬ification1) payload2 := []byte(`{"icon":"youtube","platform":"android","link":"2fas.com","message":"demo2","version":"1.1"}`) var notification2 *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload2, ¬ification2) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload2, ¬ification2) var collection []*query.MobileNotificationPresenter - tests.DoAPISuccessGet(s.T(), "mobile/notifications", &collection) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/notifications", &collection) assert.Len(s.T(), collection, 2) } @@ -90,7 +90,7 @@ func (s *MobileNotificationsTestSuite) TestFindAllNotifications() { func (s *MobileNotificationsTestSuite) TestDoNotFindNotifications() { var collection []*query.MobileNotificationPresenter - tests.DoAPISuccessGet(s.T(), "mobile/notifications", &collection) + e2e_tests.DoAPISuccessGet(s.T(), "mobile/notifications", &collection) assert.Len(s.T(), collection, 0) } @@ -98,10 +98,10 @@ func (s *MobileNotificationsTestSuite) TestDoNotFindNotifications() { func (s *MobileNotificationsTestSuite) TestPublishNotification() { payload := []byte(`{"icon":"features","platform":"android","link":"2fas.com","message":"demo","version":"0.1"}`) var notification *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications", payload, ¬ification) var publishedNotification *query.MobileNotificationPresenter - tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications/"+notification.Id+"/commands/publish", payload, &publishedNotification) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/notifications/"+notification.Id+"/commands/publish", payload, &publishedNotification) assert.NotEmpty(s.T(), "published_at", notification.PublishedAt) } diff --git a/tests/mobile/mobile_security_test.go b/e2e-tests/mobile/mobile_security_test.go similarity index 88% rename from tests/mobile/mobile_security_test.go rename to e2e-tests/mobile/mobile_security_test.go index 689df09..283b4d2 100644 --- a/tests/mobile/mobile_security_test.go +++ b/e2e-tests/mobile/mobile_security_test.go @@ -6,7 +6,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" "golang.org/x/sync/errgroup" ) @@ -21,7 +21,7 @@ func Test_MobileApiBandwidthAbuse(t *testing.T) { eg.SetLimit(noOfWorkers) for i := 0; i < noOfRequest; i++ { eg.Go(func() error { - resp := tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil) + resp := e2e_tests.DoAPIGet(t, "/mobile/devices/"+someId.String()+"/browser_extensions", nil) responseCh <- resp.StatusCode @@ -59,7 +59,7 @@ func Test_BrowserExtensionApiBandwidthAbuse(t *testing.T) { eg.SetLimit(noOfWorkers) for i := 0; i < noOfRequest; i++ { eg.Go(func() error { - resp := tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil) + resp := e2e_tests.DoAPIGet(t, "/browser_extensions/"+someId.String(), nil) responseCh <- resp.StatusCode diff --git a/tests/pass/http.go b/e2e-tests/pass/http.go similarity index 98% rename from tests/pass/http.go rename to e2e-tests/pass/http.go index 2668b40..3dd78e0 100644 --- a/tests/pass/http.go +++ b/e2e-tests/pass/http.go @@ -106,7 +106,7 @@ func getMobileToken(fcm string) (string, error) { MobileSyncConfirmToken string `json:"mobile_sync_confirm_token"` } if err := request("GET", fmt.Sprintf("/mobile/sync/%s/token", fcm), "", nil, &resp); err != nil { - return "", fmt.Errorf("failed to get mobile token") + return "", fmt.Errorf("failed to get mobile token: %w", err) } return resp.MobileSyncConfirmToken, nil diff --git a/tests/pass/kms_test.go b/e2e-tests/pass/kms_test.go similarity index 100% rename from tests/pass/kms_test.go rename to e2e-tests/pass/kms_test.go diff --git a/tests/pass/lib.go b/e2e-tests/pass/lib.go similarity index 100% rename from tests/pass/lib.go rename to e2e-tests/pass/lib.go diff --git a/tests/pass/pair_test.go b/e2e-tests/pass/pair_test.go similarity index 100% rename from tests/pass/pair_test.go rename to e2e-tests/pass/pair_test.go diff --git a/tests/pass/sync_test.go b/e2e-tests/pass/sync_test.go similarity index 100% rename from tests/pass/sync_test.go rename to e2e-tests/pass/sync_test.go diff --git a/tests/pass/ws.go b/e2e-tests/pass/ws.go similarity index 100% rename from tests/pass/ws.go rename to e2e-tests/pass/ws.go diff --git a/tests/responses.go b/e2e-tests/responses.go similarity index 98% rename from tests/responses.go rename to e2e-tests/responses.go index ec2c26f..2f24e7a 100644 --- a/tests/responses.go +++ b/e2e-tests/responses.go @@ -1,4 +1,4 @@ -package tests +package e2e_tests type DeviceResponse struct { Id string `json:"id"` diff --git a/e2e-tests/scripts/wait-ready/main.go b/e2e-tests/scripts/wait-ready/main.go new file mode 100644 index 0000000..52107cc --- /dev/null +++ b/e2e-tests/scripts/wait-ready/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "flag" + "log" + "net" + "strings" + "time" +) + +func main() { + addrFlag := flag.String("addr", ":80;:8081;:8082", "list of addresses to check sep by ;") + flag.Parse() + + addresses := strings.Split(*addrFlag, ";") + if len(addresses) < 1 { + log.Fatal("-addr value not provided") + } + for _, address := range addresses { + running := waitForApp(address, 30*time.Second) + if !running { + log.Fatal("App not running on addr: ", address) + } + } +} + +// waitForApp returns true if app is listening on provided address. +// If it cannot connect up to specified timeout, it returns false. +func waitForApp(address string, timeout time.Duration) bool { + done := make(chan struct{}) + + go func() { + for { + _, err := net.DialTimeout("tcp", address, time.Second) + if err != nil { + time.Sleep(time.Second) + continue + } + close(done) + return + } + }() + timeoutCh := time.After(timeout) + select { + case <-done: + return true + case <-timeoutCh: + return false + } +} diff --git a/tests/support/mobile_debug_logs_test.go b/e2e-tests/support/mobile_debug_logs_test.go similarity index 85% rename from tests/support/mobile_debug_logs_test.go rename to e2e-tests/support/mobile_debug_logs_test.go index f416179..d273d91 100644 --- a/tests/support/mobile_debug_logs_test.go +++ b/e2e-tests/support/mobile_debug_logs_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/twofas/2fas-server/e2e-tests" query "github.com/twofas/2fas-server/internal/api/support/app/queries" - "github.com/twofas/2fas-server/tests" ) func TestDebugLogsAuditTestSuite(t *testing.T) { @@ -25,7 +25,7 @@ type DebugLogsAuditTestSuite struct { } func (s *DebugLogsAuditTestSuite) SetupTest() { - tests.DoAdminSuccessDelete(s.T(), "mobile/support/debug_logs/audit") + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/support/debug_logs/audit") } func (s *DebugLogsAuditTestSuite) TestCreateDebugLogsAuditClaim() { @@ -33,7 +33,7 @@ func (s *DebugLogsAuditTestSuite) TestCreateDebugLogsAuditClaim() { auditClaim := new(query.DebugLogsAuditPresenter) - tests.DoAdminAPISuccessPost(s.T(), "mobile/support/debug_logs/audit/claim", payload, auditClaim) + e2e_tests.DoAdminAPISuccessPost(s.T(), "mobile/support/debug_logs/audit/claim", payload, auditClaim) assert.Equal(s.T(), "app-user", auditClaim.Username) assert.Equal(s.T(), "some description", auditClaim.Description) @@ -44,7 +44,7 @@ func (s *DebugLogsAuditTestSuite) TestUpdateDebugLogsAuditClaim() { var updatedAuditClaim *query.DebugLogsAuditPresenter updatePayload := []byte(`{"username": "app-user-1", "description": "another description"}`) - tests.DoAdminSuccessPut(s.T(), "mobile/support/debug_logs/audit/claim/"+auditClaim.Id, updatePayload, &updatedAuditClaim) + e2e_tests.DoAdminSuccessPut(s.T(), "mobile/support/debug_logs/audit/claim/"+auditClaim.Id, updatePayload, &updatedAuditClaim) assert.Equal(s.T(), "app-user-1", updatedAuditClaim.Username) assert.Equal(s.T(), "another description", updatedAuditClaim.Description) @@ -121,7 +121,7 @@ func (s *DebugLogsAuditTestSuite) TestGetDebugLogsAudit() { auditClaim := createDebugLogsAuditClaim(s.T(), "user1", "desc1") audit := new(query.DebugLogsAuditPresenter) - tests.DoAdminSuccessGet(s.T(), "mobile/support/debug_logs/audit/"+auditClaim.Id, audit) + e2e_tests.DoAdminSuccessGet(s.T(), "mobile/support/debug_logs/audit/"+auditClaim.Id, audit) assert.Equal(s.T(), auditClaim.Id, audit.Id) assert.Equal(s.T(), "user1", audit.Username) @@ -131,9 +131,9 @@ func (s *DebugLogsAuditTestSuite) TestGetDebugLogsAudit() { func (s *DebugLogsAuditTestSuite) TestDeleteDebugLogsAudit() { auditClaim := createDebugLogsAuditClaim(s.T(), "user1", "desc1") - tests.DoAdminSuccessDelete(s.T(), "mobile/support/debug_logs/audit/"+auditClaim.Id) + e2e_tests.DoAdminSuccessDelete(s.T(), "mobile/support/debug_logs/audit/"+auditClaim.Id) - response := tests.DoAPIGet(s.T(), "mobile/support/debug_logs/audit/"+auditClaim.Id, nil) + response := e2e_tests.DoAPIGet(s.T(), "mobile/support/debug_logs/audit/"+auditClaim.Id, nil) assert.Equal(s.T(), 404, response.StatusCode) } @@ -142,7 +142,7 @@ func (s *DebugLogsAuditTestSuite) TestFindAllDebugLogsAudit() { createDebugLogsAuditClaim(s.T(), "user2", "desc2") var audits []*query.DebugLogsAuditPresenter - tests.DoAdminSuccessGet(s.T(), "mobile/support/debug_logs/audit", &audits) + e2e_tests.DoAdminSuccessGet(s.T(), "mobile/support/debug_logs/audit", &audits) assert.Len(s.T(), audits, 2) } @@ -151,7 +151,7 @@ func createDebugLogsAuditClaim(t *testing.T, username, description string) *quer payload := []byte(`{"username": "` + username + `", "description": "` + description + `"}`) auditClaim := new(query.DebugLogsAuditPresenter) - tests.DoAdminAPISuccessPost(t, "mobile/support/debug_logs/audit/claim", payload, auditClaim) + e2e_tests.DoAdminAPISuccessPost(t, "mobile/support/debug_logs/audit/claim", payload, auditClaim) return auditClaim } diff --git a/tests/system/api_test.go b/e2e-tests/system/api_test.go similarity index 79% rename from tests/system/api_test.go rename to e2e-tests/system/api_test.go index c5b9e2d..dc45d4a 100644 --- a/tests/system/api_test.go +++ b/e2e-tests/system/api_test.go @@ -6,11 +6,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/twofas/2fas-server/tests" + "github.com/twofas/2fas-server/e2e-tests" ) func Test_Default404Response(t *testing.T) { - response := tests.DoAPIGet(t, "some/not/existing/endpoint", nil) + response := e2e_tests.DoAPIGet(t, "some/not/existing/endpoint", nil) rawBody, err := io.ReadAll(response.Body) require.NoError(t, err) diff --git a/tests/websocket.go b/e2e-tests/websocket.go similarity index 98% rename from tests/websocket.go rename to e2e-tests/websocket.go index b8ed2d0..282faea 100644 --- a/tests/websocket.go +++ b/e2e-tests/websocket.go @@ -1,4 +1,4 @@ -package tests +package e2e_tests import ( "github.com/gorilla/websocket" diff --git a/internal/api/browser_extension/app/command/request_2fa_token.go b/internal/api/browser_extension/app/command/request_2fa_token.go index f013bcc..ebd5f79 100644 --- a/internal/api/browser_extension/app/command/request_2fa_token.go +++ b/internal/api/browser_extension/app/command/request_2fa_token.go @@ -58,7 +58,8 @@ type Request2FaTokenHandler struct { Pusher push.Pusher } -func (h *Request2FaTokenHandler) Handle(cmd *Request2FaToken) error { +func (h *Request2FaTokenHandler) Handle(ctx context.Context, cmd *Request2FaToken) error { + log := logging.FromContext(ctx) extId, _ := uuid.Parse(cmd.ExtensionId) browserExtension, err := h.BrowserExtensionsRepository.FindById(extId) @@ -87,7 +88,7 @@ func (h *Request2FaTokenHandler) Handle(cmd *Request2FaToken) error { for _, device := range pairedDevices { if device.FcmToken == "" { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "extension_id": extId.String(), "device_id": device.Id.String(), "token_request_id": cmd.Id, @@ -117,7 +118,7 @@ func (h *Request2FaTokenHandler) Handle(cmd *Request2FaToken) error { ) if err != nil && !messaging.IsUnregistered(err) { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "extension_id": extId.String(), "device_id": device.Id.String(), "token_request_id": cmd.Id, diff --git a/internal/api/browser_extension/app/command/store_log_event.go b/internal/api/browser_extension/app/command/store_log_event.go index ac3f9b6..5cb78d6 100644 --- a/internal/api/browser_extension/app/command/store_log_event.go +++ b/internal/api/browser_extension/app/command/store_log_event.go @@ -1,7 +1,9 @@ package command import ( + "context" "encoding/json" + "github.com/google/uuid" "github.com/twofas/2fas-server/internal/api/browser_extension/domain" "github.com/twofas/2fas-server/internal/common/logging" @@ -18,7 +20,7 @@ type StoreLogEventHandler struct { BrowserExtensionsRepository domain.BrowserExtensionRepository } -func (h *StoreLogEventHandler) Handle(cmd *StoreLogEvent) { +func (h *StoreLogEventHandler) Handle(ctx context.Context, cmd *StoreLogEvent) { extId, _ := uuid.Parse(cmd.ExtensionId) _, err := h.BrowserExtensionsRepository.FindById(extId) @@ -35,12 +37,12 @@ func (h *StoreLogEventHandler) Handle(cmd *StoreLogEvent) { switch cmd.Level { case "info": - logging.WithFields(context).Info(cmd.Message) + logging.FromContext(ctx).WithFields(context).Info(cmd.Message) case "warning": - logging.WithFields(context).Warning(cmd.Message) + logging.FromContext(ctx).WithFields(context).Warning(cmd.Message) case "error": - logging.WithFields(context).Error(cmd.Message) + logging.FromContext(ctx).WithFields(context).Error(cmd.Message) case "debug": - logging.WithFields(context).Debug(cmd.Message) + logging.FromContext(ctx).WithFields(context).Debug(cmd.Message) } } diff --git a/internal/api/browser_extension/app/security/middleware.go b/internal/api/browser_extension/app/security/middleware.go index 3b35fee..54c054d 100644 --- a/internal/api/browser_extension/app/security/middleware.go +++ b/internal/api/browser_extension/app/security/middleware.go @@ -33,7 +33,7 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter limitReached := rateLimiter.Test(c, key, rate) if limitReached { - logging.WithFields(logging.Fields{ + logging.FromContext(c.Request.Context()).WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "browser_extension_id": extensionId, diff --git a/internal/api/browser_extension/ports/http.go b/internal/api/browser_extension/ports/http.go index 7c095c3..6f7617b 100644 --- a/internal/api/browser_extension/ports/http.go +++ b/internal/api/browser_extension/ports/http.go @@ -43,7 +43,7 @@ func (r *RoutesHandler) Log(c *gin.Context) { return } - r.cqrs.Commands.StoreLogEvent.Handle(cmd) + r.cqrs.Commands.StoreLogEvent.Handle(c.Request.Context(), cmd) c.JSON(200, api.NewOk("Log has been stored")) } @@ -311,7 +311,7 @@ func (r *RoutesHandler) Request2FaToken(c *gin.Context) { return } - err = r.cqrs.Commands.Request2FaToken.Handle(cmd) + err = r.cqrs.Commands.Request2FaToken.Handle(c.Request.Context(), cmd) if err != nil { c.JSON(500, api.NewInternalServerError(err)) diff --git a/internal/api/mobile/app/command/send_2fa_token.go b/internal/api/mobile/app/command/send_2fa_token.go index d4f2834..d87346d 100644 --- a/internal/api/mobile/app/command/send_2fa_token.go +++ b/internal/api/mobile/app/command/send_2fa_token.go @@ -1,9 +1,12 @@ package command import ( + "context" "fmt" + "github.com/avast/retry-go/v4" "github.com/google/uuid" + "github.com/twofas/2fas-server/internal/api/browser_extension/domain" "github.com/twofas/2fas-server/internal/api/mobile/adapters" "github.com/twofas/2fas-server/internal/common/logging" @@ -41,15 +44,14 @@ type Send2FaTokenHandler struct { WebsocketClient *websocket.WebsocketApiClient } -func (h *Send2FaTokenHandler) Handle(cmd *Send2FaToken) error { +func (h *Send2FaTokenHandler) Handle(ctx context.Context, cmd *Send2FaToken) error { extId, _ := uuid.Parse(cmd.ExtensionId) - - logging.WithFields(logging.Fields{ + log := logging.FromContext(ctx).WithFields(logging.Fields{ "browser_extension_id": cmd.ExtensionId, "device_id": cmd.DeviceId, - "token": cmd.Token, "token_request_id": cmd.TokenRequestId, - }).Info("Start command `Send2FaToken`") + }) + log.Info("Start command `Send2FaToken`") browserExtension, err := h.BrowserExtensionsRepository.FindById(extId) @@ -70,9 +72,9 @@ func (h *Send2FaTokenHandler) Handle(cmd *Send2FaToken) error { ) if err != nil { - logging.WithFields(logging.Fields{ - "error": err.Error(), - "message": message, + log.WithFields(logging.Fields{ + "error": err.Error(), + "websocket_message": message, }).Error("Cannot send websocket message") } diff --git a/internal/api/mobile/app/security/middleware.go b/internal/api/mobile/app/security/middleware.go index eda05fd..ccf589d 100644 --- a/internal/api/mobile/app/security/middleware.go +++ b/internal/api/mobile/app/security/middleware.go @@ -38,7 +38,7 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitV limitReached := rateLimiter.Test(c, key, rate) if limitReached { - logging.WithFields(logging.Fields{ + logging.FromContext(c.Request.Context()).WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "device_id": deviceId, diff --git a/internal/api/mobile/ports/http.go b/internal/api/mobile/ports/http.go index 0734ddd..7155fa0 100644 --- a/internal/api/mobile/ports/http.go +++ b/internal/api/mobile/ports/http.go @@ -2,6 +2,7 @@ package ports import ( "errors" + "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" "github.com/google/uuid" @@ -289,7 +290,7 @@ func (r *RoutesHandler) Send2FaToken(c *gin.Context) { return } - err = r.cqrs.Commands.Send2FaToken.Handle(cmd) + err = r.cqrs.Commands.Send2FaToken.Handle(c.Request.Context(), cmd) if err != nil { c.JSON(500, api.NewInternalServerError(err)) diff --git a/internal/common/http/client.go b/internal/common/http/client.go deleted file mode 100644 index debc35b..0000000 --- a/internal/common/http/client.go +++ /dev/null @@ -1,147 +0,0 @@ -package http - -import ( - "bytes" - "context" - "encoding/json" - "github.com/twofas/2fas-server/internal/common/logging" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "time" -) - -var tunedHttpTransport = &http.Transport{ - MaxIdleConns: 1024, - MaxIdleConnsPerHost: 1024, - TLSHandshakeTimeout: 10 * time.Second, - DialContext: (&net.Dialer{ - Timeout: 60 * time.Second, - KeepAlive: 60 * time.Second, - }).DialContext, -} - -type HttpClient struct { - client *http.Client - baseUrl *url.URL - credentialsCallback func(r *http.Request) -} - -func (w *HttpClient) CredentialsProvider(credentialsCallback func(r *http.Request)) { - w.credentialsCallback = credentialsCallback -} - -func (w *HttpClient) Post(ctx context.Context, path string, result interface{}, data interface{}) error { - req, err := w.newJsonRequest("POST", path, data) - - if err != nil { - return err - } - - return w.executeRequest(ctx, req, result) -} - -func (w *HttpClient) newJsonRequest(method, path string, body interface{}) (*http.Request, error) { - var buf io.ReadWriter - - logging.WithFields(logging.Fields{ - "method": method, - "body": body, - }).Debug("HTTP Request") - - if body != nil { - buf = new(bytes.Buffer) - - encoder := json.NewEncoder(buf) - err := encoder.Encode(body) - - if err != nil { - return nil, err - } - } - - return w.newRequest(method, path, buf, "application/json") -} - -func (w *HttpClient) newRequest(method, path string, buf io.Reader, contentType string) (*http.Request, error) { - u, err := w.baseUrl.Parse(path) - - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, u.String(), buf) - - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", contentType) - - return req, nil -} - -func (w *HttpClient) executeRequest(ctx context.Context, req *http.Request, v interface{}) error { - req = req.WithContext(ctx) - - if w.credentialsCallback != nil { - w.credentialsCallback(req) - } - - resp, err := w.client.Do(req) - - if err != nil { - return err - } - - defer resp.Body.Close() - - responseData, err := w.checkError(resp) - - if err != nil { - return err - } - - if v == nil { - return nil - } - - responseDataReader := bytes.NewReader(responseData) - - err = json.NewDecoder(responseDataReader).Decode(v) - - return err -} - -func (w *HttpClient) checkError(r *http.Response) ([]byte, error) { - errorResponse := &ErrorResponse{} - - responseData, err := ioutil.ReadAll(r.Body) - - if err == nil && responseData != nil { - json.Unmarshal(responseData, errorResponse) - } - - if httpCode := r.StatusCode; 200 <= httpCode && httpCode <= 300 { - return responseData, nil - } - - errorResponse.Status = r.StatusCode - - return responseData, errorResponse -} - -func NewHttpClient(baseUrl string) *HttpClient { - clientBaseUrl, err := url.Parse(baseUrl) - - if err != nil { - panic(err) - } - - return &HttpClient{ - client: &http.Client{Transport: tunedHttpTransport}, - baseUrl: clientBaseUrl, - } -} diff --git a/internal/common/http/log.go b/internal/common/http/log.go index 7480d2e..c6edadf 100644 --- a/internal/common/http/log.go +++ b/internal/common/http/log.go @@ -5,9 +5,37 @@ import ( "io" "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/twofas/2fas-server/internal/common/logging" ) +const ( + CorrelationIdHeader = "X-Correlation-ID" +) + +var ( + RequestId string + CorrelationId string +) + +func LoggingMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + requestId := uuid.New().String() + correlationId := c.Request.Header.Get(CorrelationIdHeader) + if correlationId == "" { + correlationId = uuid.New().String() + } + + ctxWithLog := logging.AddToContext(c.Request.Context(), logging.WithFields(map[string]any{ + "correlation_id": correlationId, + "request_id": requestId, + })) + c.Request = c.Request.WithContext(ctxWithLog) + + } +} + func RequestJsonLogger() gin.HandlerFunc { return func(c *gin.Context) { var buf bytes.Buffer @@ -17,22 +45,19 @@ func RequestJsonLogger() gin.HandlerFunc { c.Request.Body = io.NopCloser(&buf) - logging.WithFields(logging.Fields{ - "method": c.Request.Method, - "path": c.Request.URL.Path, - "headers": c.Request.Header, - "body": string(body), - "request_id": c.GetString(RequestIdKey), - "correlation_id": c.GetString(CorrelationIdKey), + log := logging.FromContext(c.Request.Context()) + + log.WithFields(logging.Fields{ + "method": c.Request.Method, + "path": c.Request.URL.Path, + "body": string(body), }).Info("Request") c.Next() - logging.WithFields(logging.Fields{ - "method": c.Request.Method, - "path": c.Request.URL.Path, - "request_id": c.GetString(RequestIdKey), - "correlation_id": c.GetString(CorrelationIdKey), + log.WithFields(logging.Fields{ + "method": c.Request.Method, + "path": c.Request.URL.Path, }).Info("Response") } } diff --git a/internal/common/http/request.go b/internal/common/http/request.go index 7f517ff..20f8588 100644 --- a/internal/common/http/request.go +++ b/internal/common/http/request.go @@ -1,50 +1,11 @@ package http import ( - "github.com/gin-gonic/gin" - "github.com/google/uuid" - "github.com/twofas/2fas-server/internal/common/logging" "net/http" + + "github.com/gin-gonic/gin" ) -const ( - RequestIdKey = "request_id" - - CorrelationIdKey = "correlation_id" - CorrelationIdHeader = "X-Correlation-ID" -) - -var ( - RequestId string - CorrelationId string -) - -func RequestIdMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - RequestId = uuid.New().String() - - c.Set(RequestIdKey, RequestId) - - logging.WithDefaultField(RequestIdKey, RequestId) - } -} - -func CorrelationIdMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - c.Set(CorrelationIdKey, uuid.New().String()) - - CorrelationId = c.Request.Header.Get(CorrelationIdHeader) - - if CorrelationId == "" { - CorrelationId = uuid.New().String() - } - - logging.WithDefaultField(CorrelationIdKey, CorrelationId) - - c.Set(CorrelationIdKey, CorrelationId) - } -} - func BodySizeLimitMiddleware(requestBytesLimit int64) gin.HandlerFunc { return func(c *gin.Context) { var w http.ResponseWriter = c.Writer diff --git a/internal/common/http/server.go b/internal/common/http/server.go index 6fbe1b6..f8ea649 100644 --- a/internal/common/http/server.go +++ b/internal/common/http/server.go @@ -9,8 +9,7 @@ func RunHttpServer(addr string, init func(engine *gin.Engine)) { router.Use(gin.Recovery()) router.Use(corsMiddleware()) - router.Use(RequestIdMiddleware()) - router.Use(CorrelationIdMiddleware()) + router.Use(LoggingMiddleware()) router.Use(RequestJsonLogger()) init(router) diff --git a/internal/common/logging/logger.go b/internal/common/logging/logger.go index 457b55a..b73736b 100644 --- a/internal/common/logging/logger.go +++ b/internal/common/logging/logger.go @@ -1,6 +1,7 @@ package logging import ( + "context" "encoding/json" "reflect" "sync" @@ -8,104 +9,92 @@ import ( "github.com/sirupsen/logrus" ) -// TODO: do not log reuse on every request. -type Fields map[string]interface{} +type Fields = logrus.Fields -var ( - customLogger = New() - defaultFields = logrus.Fields{} - defaultFieldsMutex = sync.RWMutex{} -) - -func New() *logrus.Logger { - logger := logrus.New() - - logger.SetFormatter(&logrus.JSONFormatter{ - FieldMap: logrus.FieldMap{ - logrus.FieldKeyTime: "timestamp", - logrus.FieldKeyLevel: "level", - logrus.FieldKeyMsg: "message", - }, - }) - - logger.SetLevel(logrus.InfoLevel) - - return logger +type FieldLogger interface { + logrus.FieldLogger } -func WithDefaultField(key, value string) *logrus.Logger { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() +var ( + log logrus.FieldLogger + once sync.Once +) - defaultFields[key] = value +// Init initialize global instance of logging library. +func Init(fields Fields) FieldLogger { + once.Do(func() { + logger := logrus.New() - return customLogger + logger.SetFormatter(&logrus.JSONFormatter{ + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "timestamp", + logrus.FieldKeyLevel: "level", + logrus.FieldKeyMsg: "message", + }, + }) + + logger.SetLevel(logrus.InfoLevel) + + log = logger.WithFields(fields) + }) + return log +} + +type ctxKey int + +const ( + loggerKey ctxKey = iota +) + +func AddToContext(ctx context.Context, log FieldLogger) context.Context { + return context.WithValue(ctx, loggerKey, log) +} + +func FromContext(ctx context.Context) FieldLogger { + log, ok := ctx.Value(loggerKey).(FieldLogger) + if ok { + return log + } + return log +} + +func WithFields(fields Fields) FieldLogger { + return log.WithFields(fields) +} + +func WithField(key string, value any) FieldLogger { + return log.WithField(key, value) } func Info(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Info(args...) + log.Info(args...) } func Infof(format string, args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Infof(format, args...) + log.Infof(format, args...) } func Error(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Error(args...) + log.Error(args...) } func Errorf(format string, args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Errorf(format, args...) + log.Errorf(format, args...) } func Warning(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Warning(args...) + log.Warning(args...) } func Fatal(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Fatal(args...) + log.Fatal(args...) } -func WithField(key string, value interface{}) *logrus.Entry { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - return customLogger. - WithFields(defaultFields). - WithField(key, value) -} - -func WithFields(fields Fields) *logrus.Entry { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - return customLogger. - WithFields(logrus.Fields(fields)). - WithFields(defaultFields) +func Fatalf(format string, args ...interface{}) { + log.Fatalf(format, args...) } func LogCommand(command interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - context, _ := json.Marshal(command) commandName := reflect.TypeOf(command).Elem().Name() @@ -113,8 +102,7 @@ func LogCommand(command interface{}) { var commandAsFields logrus.Fields json.Unmarshal(context, &commandAsFields) - customLogger. - WithFields(defaultFields). + log. WithFields(logrus.Fields{ "command_name": commandName, "command": commandAsFields, @@ -122,13 +110,8 @@ func LogCommand(command interface{}) { } func LogCommandFailed(command interface{}, err error) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - commandName := reflect.TypeOf(command).Elem().Name() - - customLogger. - WithFields(defaultFields). + log. WithFields(logrus.Fields{ "reason": err.Error(), }).Info("Command failed" + commandName) diff --git a/internal/common/rate_limit/redis_rate_limit.go b/internal/common/rate_limit/redis_rate_limit.go index 18c2882..bfe9036 100644 --- a/internal/common/rate_limit/redis_rate_limit.go +++ b/internal/common/rate_limit/redis_rate_limit.go @@ -38,7 +38,7 @@ func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool { Period: rate.TimeUnit, }) if err != nil { - logging.WithFields(logging.Fields{ + logging.FromContext(ctx).WithFields(logging.Fields{ "type": "security", }).Warnf("Could not check rate limit: %v", err) diff --git a/internal/common/recovery/gin.go b/internal/common/recovery/gin.go index b07b16c..a2c8750 100644 --- a/internal/common/recovery/gin.go +++ b/internal/common/recovery/gin.go @@ -3,10 +3,11 @@ package recovery import ( "bytes" "fmt" - "github.com/gin-gonic/gin" - "github.com/twofas/2fas-server/internal/common/logging" "io/ioutil" "runtime" + + "github.com/gin-gonic/gin" + "github.com/twofas/2fas-server/internal/common/logging" ) func RecoveryMiddleware() gin.HandlerFunc { diff --git a/internal/common/security/middleware.go b/internal/common/security/middleware.go index f02dc99..b1a35a4 100644 --- a/internal/common/security/middleware.go +++ b/internal/common/security/middleware.go @@ -30,7 +30,7 @@ func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue i limitReached := rateLimiter.Test(c, key, rate) if limitReached { - logging.WithFields(logging.Fields{ + logging.FromContext(c.Request.Context()).WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "ip": c.ClientIP(), diff --git a/internal/common/websocket/gorilla_websocket_client.go b/internal/common/websocket/gorilla_websocket_client.go index 5d2682b..e7903b5 100644 --- a/internal/common/websocket/gorilla_websocket_client.go +++ b/internal/common/websocket/gorilla_websocket_client.go @@ -2,12 +2,15 @@ package websocket import ( "encoding/json" - "github.com/gorilla/websocket" - app_http "github.com/twofas/2fas-server/internal/common/http" - "github.com/twofas/2fas-server/internal/common/logging" + "fmt" "net/http" "net/url" "path" + + "github.com/gorilla/websocket" + + app_http "github.com/twofas/2fas-server/internal/common/http" + "github.com/twofas/2fas-server/internal/common/logging" ) type WebsocketApiClient struct { @@ -21,13 +24,15 @@ func NewWebsocketApiClient(websocketApiUrl string) *WebsocketApiClient { } func (ws *WebsocketApiClient) SendMessage(uri string, message interface{}) error { - u, _ := url.Parse(ws.wsAddr) + u, err := url.Parse(ws.wsAddr) + if err != nil { + return fmt.Errorf("failed to parse %q: %w", ws.wsAddr, err) + } u.Path = path.Join(u.Path, uri) msg, err := json.Marshal(message) - if err != nil { - return err + return fmt.Errorf("failed to marshal message: %w", err) } logging.WithFields(logging.Fields{ @@ -40,23 +45,20 @@ func (ws *WebsocketApiClient) SendMessage(uri string, message interface{}) error } c, _, err := websocket.DefaultDialer.Dial(u.String(), requestHeaders) - if err != nil { - return err + return fmt.Errorf("failed to dial: %q: %w", u.String(), err) } err = c.WriteMessage(websocket.TextMessage, msg) - if err != nil { logging.WithField("error", err.Error()).Error("Cannot send websocket message") - return err + return fmt.Errorf("failed to write message to the conection: %w", err) } err = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) - if err != nil { logging.WithField("error", err.Error()).Error("Cannot close websocket connection") - return err + return fmt.Errorf("failed to write close message to the conection: %w", err) } return nil diff --git a/internal/pass/pairing/pairing.go b/internal/pass/pairing/pairing.go index dde520b..eb54b03 100644 --- a/internal/pass/pairing/pairing.go +++ b/internal/pass/pairing/pairing.go @@ -7,8 +7,6 @@ import ( "time" "github.com/gorilla/websocket" - "github.com/sirupsen/logrus" - "github.com/twofas/2fas-server/internal/common/logging" "github.com/twofas/2fas-server/internal/pass/connection" "github.com/twofas/2fas-server/internal/pass/sign" @@ -128,7 +126,7 @@ func (p *Pairing) ServePairingWS(w http.ResponseWriter, r *http.Request, extID s } } -func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log *logrus.Entry) (PairingInfo, bool) { +func (p *Pairing) isExtensionPaired(ctx context.Context, extID string, log logging.FieldLogger) (PairingInfo, bool) { pairingInfo, err := p.store.GetPairingInfo(ctx, extID) if err != nil { log.Warn("Failed to get pairing info") diff --git a/internal/pass/server.go b/internal/pass/server.go index 3f83373..b3db1f5 100644 --- a/internal/pass/server.go +++ b/internal/pass/server.go @@ -59,9 +59,7 @@ func NewServer(cfg config.PassConfig) *Server { router := gin.New() router.Use(recovery.RecoveryMiddleware()) - router.Use(httphelpers.RequestIdMiddleware()) - router.Use(httphelpers.CorrelationIdMiddleware()) - // TODO: don't log auth headers. + router.Use(httphelpers.LoggingMiddleware()) router.Use(httphelpers.RequestJsonLogger()) router.GET("/health", func(context *gin.Context) { diff --git a/internal/pass/sign/lib_test.go b/internal/pass/sign/lib_test.go index 7eabe6f..8f95c29 100644 --- a/internal/pass/sign/lib_test.go +++ b/internal/pass/sign/lib_test.go @@ -8,10 +8,6 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" ) @@ -70,16 +66,6 @@ func createTestService(t *testing.T) Service { } func TestSignAndVerify(t *testing.T) { - sess, err := session.NewSession(&aws.Config{ - Region: aws.String("us-east-1"), - Credentials: credentials.NewStaticCredentials("test", "test", ""), - S3ForcePathStyle: aws.Bool(true), - Endpoint: aws.String("http://localhost:4566"), - }) - if err != nil { - t.Fatal(err) - } - kmsClient := kms.New(sess) srv := createTestService(t) now := time.Now() @@ -128,18 +114,7 @@ func TestSignAndVerify(t *testing.T) { { name: "invalid signature", tokenFn: func() string { - resp, err := kmsClient.CreateKey(&kms.CreateKeyInput{ - KeySpec: aws.String("ECC_NIST_P256"), - KeyUsage: aws.String("SIGN_VERIFY"), - }) - if err != nil { - t.Fatal(err) - } - serviceWithAnotherKey, err := NewService(*resp.KeyMetadata.KeyId, kmsClient) - if err != nil { - t.Fatal(err) - } - + serviceWithAnotherKey := createTestService(t) token, err := serviceWithAnotherKey.SignAndEncode(Message{ ConnectionID: uuid.New().String(), ExpiresAt: now.Add(-time.Hour), diff --git a/internal/pass/sync/sync.go b/internal/pass/sync/sync.go index 980aea2..fb02dba 100644 --- a/internal/pass/sync/sync.go +++ b/internal/pass/sync/sync.go @@ -75,7 +75,7 @@ func (s *Syncing) ServeSyncingRequestWS(w http.ResponseWriter, r *http.Request, if err := s.sendTokenAndCloseConn(fcmToken, conn); err != nil { log.Errorf("Failed to send token: %v", err) } - log.Infof("Paring ws finished") + log.Infof("Sync ws finished") return nil } @@ -89,7 +89,7 @@ func (s *Syncing) ServeSyncingRequestWS(w http.ResponseWriter, r *http.Request, for { select { case <-maxWaitC: - log.Info("Closing paring ws after timeout") + log.Info("Closing sync ws after timeout") return nil case <-connectedCheckTicker.C: if syncConfirmed := s.isSyncConfirmed(r.Context(), fcmToken); syncConfirmed { @@ -97,7 +97,7 @@ func (s *Syncing) ServeSyncingRequestWS(w http.ResponseWriter, r *http.Request, log.Errorf("Failed to send token: %v", err) return nil } - log.Infof("Paring ws finished") + log.Infof("Sync ws finished") return nil } } diff --git a/internal/websocket/app.go b/internal/websocket/app.go index f81787f..95b161f 100644 --- a/internal/websocket/app.go +++ b/internal/websocket/app.go @@ -19,8 +19,7 @@ func NewServer(addr string) *Server { router := gin.New() router.Use(recovery.RecoveryMiddleware()) - router.Use(http.RequestIdMiddleware()) - router.Use(http.CorrelationIdMiddleware()) + router.Use(http.LoggingMiddleware()) router.Use(http.RequestJsonLogger()) connectionHandler := common.NewConnectionHandler() diff --git a/internal/websocket/common/client.go b/internal/websocket/common/client.go index d1af2b6..db1ffe2 100644 --- a/internal/websocket/common/client.go +++ b/internal/websocket/common/client.go @@ -2,9 +2,11 @@ package common import ( "bytes" + "sync" "time" "github.com/gorilla/websocket" + "github.com/twofas/2fas-server/internal/common/logging" ) @@ -43,6 +45,8 @@ type Client struct { // Buffered channel of outbound messages. send chan []byte + + sendMtx *sync.Mutex } // readPump pumps messages from the websocket connection to the hub. @@ -50,7 +54,7 @@ type Client struct { // The application runs readPump in a per-connection goroutine. The application // ensures that there is at most one reader on a connection by executing all // reads from this goroutine. -func (c *Client) readPump() { +func (c *Client) readPump(log logging.FieldLogger) { defer func() { c.hub.unregisterClient(c) c.conn.Close() @@ -69,11 +73,11 @@ func (c *Client) readPump() { if err != nil { if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "reason": err.Error(), }).Error("Websocket connection closed unexpected") } else { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "reason": err.Error(), }).Info("Connection closed") } @@ -133,3 +137,27 @@ func (c *Client) writePump() { } } } + +func (c *Client) sendMsg(bb []byte) bool { + c.sendMtx.Lock() + defer c.sendMtx.Unlock() + + if c.send == nil { + return false + } + + c.send <- bb + return true +} + +func (c *Client) close() { + c.sendMtx.Lock() + defer c.sendMtx.Unlock() + + if c.send == nil { + return + } + + close(c.send) + c.send = nil +} diff --git a/internal/websocket/common/handler.go b/internal/websocket/common/handler.go index 7085283..965623b 100644 --- a/internal/websocket/common/handler.go +++ b/internal/websocket/common/handler.go @@ -42,18 +42,18 @@ func (h *ConnectionHandler) Handler() gin.HandlerFunc { return func(c *gin.Context) { channel := c.Request.URL.Path - logging.WithDefaultField("channel", channel) + log := logging.FromContext(c.Request.Context()).WithField("channel", channel) - logging.Info("New channel subscriber") + log.Info("New channel subscriber") - h.serveWs(c.Writer, c.Request, channel) + h.serveWs(c.Writer, c.Request, channel, log) } } -func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, channel string) { +func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, channel string, log logging.FieldLogger) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - logging.Errorf("Failed to upgrade connection: %v", err) + log.Errorf("Failed to upgrade connection: %v", err) w.WriteHeader(http.StatusInternalServerError) return } @@ -65,7 +65,7 @@ func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, chan }) go recovery.DoNotPanic(func() { - client.readPump() + client.readPump(log) }) go recovery.DoNotPanic(func() { @@ -73,7 +73,7 @@ func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, chan timeout := time.After(disconnectAfter) <-timeout - logging.Info("Connection closed after", disconnectAfter) + log.Info("Connection closed after", disconnectAfter) client.hub.unregisterClient(client) client.conn.Close() diff --git a/internal/websocket/common/hub.go b/internal/websocket/common/hub.go index fb4ec31..54eb01a 100644 --- a/internal/websocket/common/hub.go +++ b/internal/websocket/common/hub.go @@ -28,7 +28,7 @@ func (h *Hub) unregisterClient(c *Client) { if !ok { return } - close(c.send) + c.close() if h.isEmpty() { h.onHubHasNoClients(h.id) } @@ -39,9 +39,8 @@ func (h *Hub) sendToClient(c *Client, msg []byte) { if !ok { return } - select { - case c.send <- msg: - default: + ok = c.sendMsg(msg) + if !ok { h.unregisterClient(c) } } diff --git a/internal/websocket/common/hub_pool.go b/internal/websocket/common/hub_pool.go index 10b6958..70b5244 100644 --- a/internal/websocket/common/hub_pool.go +++ b/internal/websocket/common/hub_pool.go @@ -29,7 +29,7 @@ func (h *hubPool) registerClient(channel string, conn *websocket.Conn) (*Client, defer h.mtx.Unlock() hub := h.getOrCreateHub(channel) - client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} + client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), sendMtx: &sync.Mutex{}} hub.registerClient(client) // handler (caller of this method) isn't really interested in hub, diff --git a/internal/websocket/common/hub_pool_test.go b/internal/websocket/common/hub_pool_test.go index 403fd82..5a0154b 100644 --- a/internal/websocket/common/hub_pool_test.go +++ b/internal/websocket/common/hub_pool_test.go @@ -50,27 +50,44 @@ func TestCreateRemoveConcurrently(t *testing.T) { hp := newHubPool() const channelsNo = 100 const clientsPerChannel = 1000 + const messagesSentToEachHub = 100 hubs := &sync.Map{} wg := sync.WaitGroup{} + // First we create `channelsNo` goroutines. Each of them creates `clientsPerChannel` sub-goroutines. + // This gives us `channelsNo*clientsPerChannel` sub go-routines and `channelsNo` parent goroutines. + // Each of them will call `wg.Done() once and we can't progress until all of them are done. + wg.Add(channelsNo*clientsPerChannel + channelsNo) + // We will close `channelsNo*clientsPerChannel + channelsNo` clients. We create fakeReadPump for each of them and + // wait for it to finish. wg.Add(channelsNo * clientsPerChannel) + for i := 0; i < channelsNo; i++ { channelID := fmt.Sprintf("channel-%d", i) + + c, h := hp.registerClient(channelID, &websocket.Conn{}) + hubs.Store(h, struct{}{}) + go fakeReadPump(c.send, &wg) go func() { + for i := 0; i < messagesSentToEachHub; i++ { + h.broadcastMsg([]byte("test")) + } + }() + + go func() { + defer wg.Done() for j := 0; j < clientsPerChannel; j++ { c, h := hp.registerClient(channelID, &websocket.Conn{}) - hubs.Store(h, struct{}{}) + go fakeReadPump(c.send, &wg) + go func() { h.unregisterClient(c) wg.Done() }() } - _, h := hp.registerClient(channelID, &websocket.Conn{}) - hubs.Store(h, struct{}{}) }() } - wg.Wait() for c, hub := range hp.hubs { @@ -89,3 +106,9 @@ func TestCreateRemoveConcurrently(t *testing.T) { return true }) } + +func fakeReadPump(c chan []byte, wg *sync.WaitGroup) { + defer wg.Done() + for range c { + } +} diff --git a/tests/browser_extension/browser_extension_2fa_request_test.go b/tests/browser_extension/browser_extension_2fa_request_test.go deleted file mode 100644 index 0b638d6..0000000 --- a/tests/browser_extension/browser_extension_2fa_request_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" -) - -func TestBrowserExtensionTwoFactorAuthTestSuite(t *testing.T) { - suite.Run(t, new(BrowserExtensionTwoFactorAuthTestSuite)) -} - -type BrowserExtensionTwoFactorAuthTestSuite struct { - suite.Suite -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) SetupTest() { - tests.RemoveAllMobileDevices(s.T()) - tests.RemoveAllBrowserExtensions(s.T()) - tests.RemoveAllBrowserExtensionsDevices(s.T()) -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestRequest2FaToken() { - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - - var tokenRequest *tests.AuthTokenRequestResponse - request2FaTokenPayload := []byte(`{"domain":"https://facebook.com/path/nested"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) - - assert.Equal(s.T(), browserExtension.Id, tokenRequest.ExtensionId) - - var tokenRequestById *tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &tokenRequestById) - assert.Equal(s.T(), tokenRequest.Id, tokenRequestById.Id) - assert.Equal(s.T(), "https://facebook.com", tokenRequestById.Domain) -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestFindAll2FaRequestsForBrowserExtension() { - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - - facebook2FaTokenRequest := []byte(`{"domain":"facebook.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", facebook2FaTokenRequest, nil) - - google2FaTokenRequest := []byte(`{"domain":"google.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", google2FaTokenRequest, nil) - - var tokenRequestsCollection []*tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests", &tokenRequestsCollection) - - assert.Len(s.T(), tokenRequestsCollection, 2) -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestClose2FaTokenRequest() { - var tokenRequest *tests.AuthTokenRequestResponse - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - tokenRequestPayload := []byte(`{"domain":"facebook.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", tokenRequestPayload, &tokenRequest) - closeTokenRequestPayload := []byte(`{"status":"completed"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) - - var closedTokenRequest *tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &closedTokenRequest) - assert.Equal(s.T(), "completed", closedTokenRequest.Status) -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestCloseNotExisting2FaTokenRequest() { - notExistingTokenRequestId := uuid.New() - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - - closeTokenRequestPayload := []byte(`{"status":"completed"}`) - tests.DoAPIPostAndAssertCode(s.T(), 404, "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+notExistingTokenRequestId.String()+"/commands/close_2fa_request", closeTokenRequestPayload, nil) - -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestDoNotReturnClosed2FaRequests() { - var tokenRequest *tests.AuthTokenRequestResponse - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - tokenRequestPayload := []byte(`{"domain":"facebook.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", tokenRequestPayload, &tokenRequest) - - closeTokenRequestPayload := []byte(`{"status":"completed"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) - - var response []*tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests", &response) - assert.Len(s.T(), response, 0) -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestTerminate2FaRequest() { - var tokenRequest *tests.AuthTokenRequestResponse - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - tokenRequestPayload := []byte(`{"domain":"facebook.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", tokenRequestPayload, &tokenRequest) - - closeTokenRequestPayload := []byte(`{"status":"terminated"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) - - var response *tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &response) - assert.Equal(s.T(), "terminated", response.Status) -} - -func (s *BrowserExtensionTwoFactorAuthTestSuite) TestClose2FaRequest() { - device, devicePubKey := tests.CreateDevice(s.T(), "SM-955F", "fcm-token") - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) - - var tokenRequest *tests.AuthTokenRequestResponse - request2FaTokenPayload := []byte(`{"domain":"domain.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) - - closeTokenRequestPayload := []byte(`{"status":"completed"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) - - var closedTokenRequest *tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id, &closedTokenRequest) - assert.Equal(s.T(), "completed", closedTokenRequest.Status) -} diff --git a/tests/browser_extension/browser_extension_pairing_test.go b/tests/browser_extension/browser_extension_pairing_test.go deleted file mode 100644 index 3b7bc44..0000000 --- a/tests/browser_extension/browser_extension_pairing_test.go +++ /dev/null @@ -1,207 +0,0 @@ -package tests - -import ( - "encoding/json" - "net/http" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" -) - -func TestBrowserExtensionPairingTestSuite(t *testing.T) { - suite.Run(t, new(BrowserExtensionPairingTestSuite)) -} - -type BrowserExtensionPairingTestSuite struct { - suite.Suite -} - -func (s *BrowserExtensionPairingTestSuite) SetupTest() { - tests.RemoveAllBrowserExtensions(s.T()) - tests.RemoveAllBrowserExtensionsDevices(s.T()) -} - -func (s *BrowserExtensionPairingTestSuite) TestPairBrowserExtensionWithMobileDevice() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - _, err := uuid.Parse(browserExt.Id) - require.NoError(s.T(), err) - - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - _, err = uuid.Parse(device.Id) - require.NoError(s.T(), err) - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) - - var extensionDevice *tests.DevicePairedBrowserExtensionResponse - tests.DoAPISuccessGet(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id, &extensionDevice) - - assert.Equal(s.T(), extensionDevice.Id, device.Id) -} - -func (s *BrowserExtensionPairingTestSuite) TestDoNotFindNotPairedBrowserExtensionMobileDevice() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - _, err := uuid.Parse(browserExt.Id) - require.NoError(s.T(), err) - - device, _ := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - - response := tests.DoAPIGet(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id, nil) - - assert.Equal(s.T(), 404, response.StatusCode) -} - -func (s *BrowserExtensionPairingTestSuite) TestPairBrowserExtensionWithMultipleDevices() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - _, err := uuid.Parse(browserExt.Id) - require.NoError(s.T(), err) - - device1, devicePubKey1 := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") - device2, devicePubKey2 := tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt, device1) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt, device2) - - extensionDevices := tests.GetExtensionDevices(s.T(), browserExt.Id) - - assert.Len(s.T(), extensionDevices, 2) -} - -func (s *BrowserExtensionPairingTestSuite) TestRemoveBrowserExtensionPairedDevice() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - - device1, devicePubKey1 := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") - device2, devicePubKey2 := tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt, device1) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt, device2) - - extensionDevices := getExtensionPairedDevices(s.T(), browserExt) - assert.Len(s.T(), extensionDevices, 2) - - tests.DoAPISuccessDelete(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device1.Id) - - extensionDevices = getExtensionPairedDevices(s.T(), browserExt) - assert.Len(s.T(), extensionDevices, 1) - assert.Equal(s.T(), device2.Id, extensionDevices[0].Id) -} - -func (s *BrowserExtensionPairingTestSuite) TestRemoveBrowserExtensionPairedDeviceTwice() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) - - tests.DoAPISuccessDelete(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id) - response := tests.DoAPIRequest(s.T(), "/browser_extensions/"+browserExt.Id+"/devices/"+device.Id, http.MethodDelete, nil /*payload*/, nil /*resp*/) - - assert.Equal(s.T(), 404, response.StatusCode) -} - -func (s *BrowserExtensionPairingTestSuite) TestRemoveAllBrowserExtensionPairedDevices() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - device1, devicePubKey1 := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id1") - device2, devicePubKey2 := tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id2") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt, device1) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt, device2) - - tests.DoAPISuccessDelete(s.T(), "/browser_extensions/"+browserExt.Id+"/devices") - - extensionDevices := tests.GetExtensionDevices(s.T(), browserExt.Id) - assert.Len(s.T(), extensionDevices, 0) -} - -func (s *BrowserExtensionPairingTestSuite) TestGetPairedDevicesWhichIDoNotOwn() { - browserExt1 := tests.CreateBrowserExtension(s.T(), "go-test-1") - browserExt2 := tests.CreateBrowserExtension(s.T(), "go-test-2") - - device1, devicePubKey1 := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") - device2, devicePubKey2 := tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt1, device1) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt2, device2) - - firstExtensionDevices := getExtensionPairedDevices(s.T(), browserExt1) - assert.Len(s.T(), firstExtensionDevices, 1) - assert.Equal(s.T(), device1.Id, firstExtensionDevices[0].Id) - - secondExtensionDevices := getExtensionPairedDevices(s.T(), browserExt2) - assert.Len(s.T(), secondExtensionDevices, 1) - assert.Equal(s.T(), device2.Id, secondExtensionDevices[0].Id) -} - -func (s *BrowserExtensionPairingTestSuite) TestGetPairedDevicesByInvalidExtensionId() { - browserExt1 := tests.CreateBrowserExtension(s.T(), "go-test-1") - browserExt2 := tests.CreateBrowserExtension(s.T(), "go-test-2") - - device1, devicePubKey1 := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") - device2, devicePubKey2 := tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt1, device1) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt2, device2) - - invalidResp := map[string]any{} - response := tests.DoAPIGet(s.T(), "/browser_extensions/some-invalid-id/devices/", &invalidResp) - assert.Equal(s.T(), 400, response.StatusCode) - assert.Contains(s.T(), invalidResp["Reason"], `Field validation for 'ExtensionId' failed on the 'uuid4'`) -} - -func (s *BrowserExtensionPairingTestSuite) TestGetPairedDevicesByNotExistingExtensionId() { - browserExt1 := tests.CreateBrowserExtension(s.T(), "go-test-1") - browserExt2 := tests.CreateBrowserExtension(s.T(), "go-test-2") - - device1, devicePubKey1 := tests.CreateDevice(s.T(), "go-test-device-1", "some-device-id-1") - device2, devicePubKey2 := tests.CreateDevice(s.T(), "go-test-device-2", "some-device-id-2") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey1, browserExt1, device1) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey2, browserExt2, device2) - - notExistingExtensionId := uuid.New() - var firstExtensionDevices []*tests.ExtensionPairedDeviceResponse - tests.DoAPISuccessGet(s.T(), "/browser_extensions/"+notExistingExtensionId.String()+"/devices/", &firstExtensionDevices) - assert.Len(s.T(), firstExtensionDevices, 0) -} - -func (s *BrowserExtensionPairingTestSuite) TestShareExtensionPublicKeyWithMobileDevice() { - browserExt := tests.CreateBrowserExtensionWithPublicKey(s.T(), "go-test", "b64-rsa-pub-key") - _, err := uuid.Parse(browserExt.Id) - require.NoError(s.T(), err) - - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - _, err = uuid.Parse(device.Id) - require.NoError(s.T(), err) - - result := tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) - assert.Equal(s.T(), "b64-rsa-pub-key", result.ExtensionPublicKey) -} - -func (s *BrowserExtensionPairingTestSuite) TestCannotPairSameDeviceAndExtensionTwice() { - browserExtension := tests.CreateBrowserExtensionWithPublicKey(s.T(), "go-test", "b64-rsa-pub-key") - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) - - payload := struct { - ExtensionId string `json:"extension_id"` - DeviceName string `json:"device_name"` - DevicePublicKey string `json:"device_public_key"` - }{ - ExtensionId: browserExtension.Id, - DeviceName: device.Name, - DevicePublicKey: "device-pub-key", - } - - pairingResult := new(tests.PairingResultResponse) - payloadJson, _ := json.Marshal(payload) - - tests.DoAPIPostAndAssertCode(s.T(), 409, "/mobile/devices/"+device.Id+"/browser_extensions", payloadJson, pairingResult) -} - -func getExtensionPairedDevices(t *testing.T, browserExt *tests.BrowserExtensionResponse) []*tests.ExtensionPairedDeviceResponse { - var extensionDevices []*tests.ExtensionPairedDeviceResponse - tests.DoAPISuccessGet(t, "/browser_extensions/"+browserExt.Id+"/devices/", &extensionDevices) - return extensionDevices -} diff --git a/tests/mobile/mobile_browser_extensions_2fa_requests_test.go b/tests/mobile/mobile_browser_extensions_2fa_requests_test.go deleted file mode 100644 index 04114fc..0000000 --- a/tests/mobile/mobile_browser_extensions_2fa_requests_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package tests - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" -) - -func TestMobileDeviceExtensionIntegrationTestSuite(t *testing.T) { - suite.Run(t, new(MobileDeviceExtensionIntegrationTestSuite)) -} - -type MobileDeviceExtensionIntegrationTestSuite struct { - suite.Suite -} - -func (s *MobileDeviceExtensionIntegrationTestSuite) SetupTest() { - tests.RemoveAllMobileDevices(s.T()) - tests.RemoveAllBrowserExtensions(s.T()) - tests.RemoveAllBrowserExtensionsDevices(s.T()) -} - -func (s *MobileDeviceExtensionIntegrationTestSuite) TestGetPending2FaRequests() { - device, devicePubKey := tests.CreateDevice(s.T(), "SM-955F", "fcm-token") - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) - - var tokenRequest *tests.AuthTokenRequestResponse - request2FaTokenPayload := []byte(`{"domain":"domain.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) - - var tokenRequestsCollection []*tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "mobile/devices/"+device.Id+"/browser_extensions/2fa_requests", &tokenRequestsCollection) - assert.Len(s.T(), tokenRequestsCollection, 1) -} - -func (s *MobileDeviceExtensionIntegrationTestSuite) TestDoNotReturnCompleted2FaRequests() { - device, devicePubKey := tests.CreateDevice(s.T(), "SM-955F", "fcm-token") - browserExtension := tests.CreateBrowserExtension(s.T(), "go-ext") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExtension, device) - - var tokenRequest *tests.AuthTokenRequestResponse - request2FaTokenPayload := []byte(`{"domain":"domain.com"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/commands/request_2fa_token", request2FaTokenPayload, &tokenRequest) - - closeTokenRequestPayload := []byte(`{"status":"completed"}`) - tests.DoAPISuccessPost(s.T(), "browser_extensions/"+browserExtension.Id+"/2fa_requests/"+tokenRequest.Id+"/commands/close_2fa_request", closeTokenRequestPayload, nil) - - var tokenRequestsCollection []*tests.AuthTokenRequestResponse - tests.DoAPISuccessGet(s.T(), "mobile/devices/"+device.Id+"/browser_extensions/2fa_requests", &tokenRequestsCollection) - assert.Len(s.T(), tokenRequestsCollection, 0) -} diff --git a/tests/mobile/mobile_device_extension_test.go b/tests/mobile/mobile_device_extension_test.go deleted file mode 100644 index 6de968c..0000000 --- a/tests/mobile/mobile_device_extension_test.go +++ /dev/null @@ -1,97 +0,0 @@ -package tests - -import ( - "fmt" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - "github.com/twofas/2fas-server/tests" -) - -func TestMobileDeviceExtensionTestSuite(t *testing.T) { - suite.Run(t, new(MobileDeviceExtensionTestSuite)) -} - -type MobileDeviceExtensionTestSuite struct { - suite.Suite -} - -func (s *MobileDeviceExtensionTestSuite) SetupTest() { - tests.RemoveAllMobileDevices(s.T()) - tests.RemoveAllBrowserExtensions(s.T()) - tests.RemoveAllBrowserExtensionsDevices(s.T()) -} - -func (s *MobileDeviceExtensionTestSuite) TestDoNotFindExtensionsForNotExistingDevice() { - notExistingDeviceId := uuid.New() - - response := tests.DoAPIGet(s.T(), "/mobile/devices/"+notExistingDeviceId.String()+"/browser_extensions", nil) - - assert.Equal(s.T(), 404, response.StatusCode) -} - -func (s *MobileDeviceExtensionTestSuite) TestDoNotFindNotExistingMobileDeviceExtension() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) - - notExistingExtensionId := uuid.New() - response := tests.DoAPIGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+notExistingExtensionId.String(), nil) - - assert.Equal(s.T(), 404, response.StatusCode) -} - -func (s *MobileDeviceExtensionTestSuite) Test_FindExtensionForDevice() { - browserExt := tests.CreateBrowserExtension(s.T(), "go-test") - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt, device) - - var deviceBrowserExtension *tests.BrowserExtensionResponse - tests.DoAPISuccessGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt.Id, &deviceBrowserExtension) - - assert.Equal(s.T(), browserExt.Id, deviceBrowserExtension.Id) -} - -func (s *MobileDeviceExtensionTestSuite) Test_FindAllDeviceExtensions() { - browserExt1 := tests.CreateBrowserExtension(s.T(), "go-test-1") - browserExt2 := tests.CreateBrowserExtension(s.T(), "go-test-2") - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt1, device) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt2, device) - - var deviceBrowserExtensions []*tests.BrowserExtensionResponse - tests.DoAPISuccessGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/", &deviceBrowserExtensions) - - assert.Len(s.T(), deviceBrowserExtensions, 2) -} - -func (s *MobileDeviceExtensionTestSuite) Test_DisconnectExtensionFromDevice() { - browserExt1 := tests.CreateBrowserExtension(s.T(), "go-test") - browserExt2 := tests.CreateBrowserExtension(s.T(), "go-test") - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt1, device) - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, browserExt2, device) - - tests.DoAPISuccessDelete(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt1.Id) - - var deviceBrowserExtension1 *tests.BrowserExtensionResponse - response := tests.DoAPIGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt1.Id, &deviceBrowserExtension1) - assert.Equal(s.T(), 404, response.StatusCode) - - var deviceBrowserExtension2 *tests.BrowserExtensionResponse - tests.DoAPISuccessGet(s.T(), "/mobile/devices/"+device.Id+"/browser_extensions/"+browserExt2.Id, &deviceBrowserExtension2) - assert.Equal(s.T(), browserExt2.Id, deviceBrowserExtension2.Id) -} - -func (s *MobileDeviceExtensionTestSuite) TestExtensionHasAlreadyBeenConnected() { - extension := tests.CreateBrowserExtension(s.T(), "go-test") - device, devicePubKey := tests.CreateDevice(s.T(), "go-test-device", "some-device-id") - tests.PairDeviceWithBrowserExtension(s.T(), devicePubKey, extension, device) - - payload := []byte(fmt.Sprintf(`{"extension_id":"%s","device_name":"%s","device_public_key":"%s"}`, extension.Id, device.Name, devicePubKey)) - - tests.DoAPIPostAndAssertCode(s.T(), 409, "/mobile/devices/"+device.Id+"/browser_extensions", payload, nil) -}