From e6947aa509101e14ff4d3783261f6c854cec463d Mon Sep 17 00:00:00 2001 From: Tobiasz Heller <14020794+tobiaszheller@users.noreply.github.com> Date: Sat, 16 Mar 2024 19:05:21 +0100 Subject: [PATCH] feat: multiple fixes in logger (#32) --- cmd/admin/main.go | 2 +- cmd/api/main.go | 2 +- cmd/websocket/main.go | 2 +- .../app/command/request_2fa_token.go | 7 +- .../app/command/store_log_event.go | 12 +- .../app/security/middleware.go | 2 +- internal/api/browser_extension/ports/http.go | 4 +- .../api/mobile/app/command/send_2fa_token.go | 8 +- .../api/mobile/app/security/middleware.go | 2 +- internal/api/mobile/ports/http.go | 3 +- internal/common/http/client.go | 147 ------------------ internal/common/http/log.go | 49 ++++-- internal/common/http/request.go | 43 +---- internal/common/http/server.go | 3 +- internal/common/logging/logger.go | 144 +++++++---------- .../common/rate_limit/redis_rate_limit.go | 2 +- internal/common/recovery/gin.go | 5 +- internal/common/security/middleware.go | 2 +- .../websocket/gorilla_websocket_client.go | 7 +- internal/websocket/app.go | 3 +- internal/websocket/common/client.go | 6 +- internal/websocket/common/handler.go | 14 +- 22 files changed, 145 insertions(+), 324 deletions(-) delete mode 100644 internal/common/http/client.go diff --git a/cmd/admin/main.go b/cmd/admin/main.go index a961a6a..0bdbb47 100644 --- a/cmd/admin/main.go +++ b/cmd/admin/main.go @@ -10,7 +10,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "admin_api") + logging.Init(logging.Fields{"service_name": "admin_api"}) config.LoadConfiguration() diff --git a/cmd/api/main.go b/cmd/api/main.go index f80d997..cce7a87 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -9,7 +9,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "api") + logging.Init(logging.Fields{"service_name": "api"}) config.LoadConfiguration() diff --git a/cmd/websocket/main.go b/cmd/websocket/main.go index 82e92f8..c737bcc 100644 --- a/cmd/websocket/main.go +++ b/cmd/websocket/main.go @@ -7,7 +7,7 @@ import ( ) func main() { - logging.WithDefaultField("service_name", "websocket_api") + logging.Init(logging.Fields{"service_name": "websocket_api"}) config.LoadConfiguration() diff --git a/internal/api/browser_extension/app/command/request_2fa_token.go b/internal/api/browser_extension/app/command/request_2fa_token.go index f013bcc..ebd5f79 100644 --- a/internal/api/browser_extension/app/command/request_2fa_token.go +++ b/internal/api/browser_extension/app/command/request_2fa_token.go @@ -58,7 +58,8 @@ type Request2FaTokenHandler struct { Pusher push.Pusher } -func (h *Request2FaTokenHandler) Handle(cmd *Request2FaToken) error { +func (h *Request2FaTokenHandler) Handle(ctx context.Context, cmd *Request2FaToken) error { + log := logging.FromContext(ctx) extId, _ := uuid.Parse(cmd.ExtensionId) browserExtension, err := h.BrowserExtensionsRepository.FindById(extId) @@ -87,7 +88,7 @@ func (h *Request2FaTokenHandler) Handle(cmd *Request2FaToken) error { for _, device := range pairedDevices { if device.FcmToken == "" { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "extension_id": extId.String(), "device_id": device.Id.String(), "token_request_id": cmd.Id, @@ -117,7 +118,7 @@ func (h *Request2FaTokenHandler) Handle(cmd *Request2FaToken) error { ) if err != nil && !messaging.IsUnregistered(err) { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "extension_id": extId.String(), "device_id": device.Id.String(), "token_request_id": cmd.Id, diff --git a/internal/api/browser_extension/app/command/store_log_event.go b/internal/api/browser_extension/app/command/store_log_event.go index ac3f9b6..5cb78d6 100644 --- a/internal/api/browser_extension/app/command/store_log_event.go +++ b/internal/api/browser_extension/app/command/store_log_event.go @@ -1,7 +1,9 @@ package command import ( + "context" "encoding/json" + "github.com/google/uuid" "github.com/twofas/2fas-server/internal/api/browser_extension/domain" "github.com/twofas/2fas-server/internal/common/logging" @@ -18,7 +20,7 @@ type StoreLogEventHandler struct { BrowserExtensionsRepository domain.BrowserExtensionRepository } -func (h *StoreLogEventHandler) Handle(cmd *StoreLogEvent) { +func (h *StoreLogEventHandler) Handle(ctx context.Context, cmd *StoreLogEvent) { extId, _ := uuid.Parse(cmd.ExtensionId) _, err := h.BrowserExtensionsRepository.FindById(extId) @@ -35,12 +37,12 @@ func (h *StoreLogEventHandler) Handle(cmd *StoreLogEvent) { switch cmd.Level { case "info": - logging.WithFields(context).Info(cmd.Message) + logging.FromContext(ctx).WithFields(context).Info(cmd.Message) case "warning": - logging.WithFields(context).Warning(cmd.Message) + logging.FromContext(ctx).WithFields(context).Warning(cmd.Message) case "error": - logging.WithFields(context).Error(cmd.Message) + logging.FromContext(ctx).WithFields(context).Error(cmd.Message) case "debug": - logging.WithFields(context).Debug(cmd.Message) + logging.FromContext(ctx).WithFields(context).Debug(cmd.Message) } } diff --git a/internal/api/browser_extension/app/security/middleware.go b/internal/api/browser_extension/app/security/middleware.go index 3b35fee..54c054d 100644 --- a/internal/api/browser_extension/app/security/middleware.go +++ b/internal/api/browser_extension/app/security/middleware.go @@ -33,7 +33,7 @@ func BrowserExtensionBandwidthAuditMiddleware(rateLimiter rate_limit.RateLimiter limitReached := rateLimiter.Test(c, key, rate) if limitReached { - logging.WithFields(logging.Fields{ + logging.FromContext(c.Request.Context()).WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "browser_extension_id": extensionId, diff --git a/internal/api/browser_extension/ports/http.go b/internal/api/browser_extension/ports/http.go index 7c095c3..6f7617b 100644 --- a/internal/api/browser_extension/ports/http.go +++ b/internal/api/browser_extension/ports/http.go @@ -43,7 +43,7 @@ func (r *RoutesHandler) Log(c *gin.Context) { return } - r.cqrs.Commands.StoreLogEvent.Handle(cmd) + r.cqrs.Commands.StoreLogEvent.Handle(c.Request.Context(), cmd) c.JSON(200, api.NewOk("Log has been stored")) } @@ -311,7 +311,7 @@ func (r *RoutesHandler) Request2FaToken(c *gin.Context) { return } - err = r.cqrs.Commands.Request2FaToken.Handle(cmd) + err = r.cqrs.Commands.Request2FaToken.Handle(c.Request.Context(), cmd) if err != nil { c.JSON(500, api.NewInternalServerError(err)) diff --git a/internal/api/mobile/app/command/send_2fa_token.go b/internal/api/mobile/app/command/send_2fa_token.go index 741e437..fc50f8b 100644 --- a/internal/api/mobile/app/command/send_2fa_token.go +++ b/internal/api/mobile/app/command/send_2fa_token.go @@ -1,6 +1,7 @@ package command import ( + "context" "fmt" "github.com/avast/retry-go/v4" @@ -42,10 +43,11 @@ type Send2FaTokenHandler struct { WebsocketClient *websocket.WebsocketApiClient } -func (h *Send2FaTokenHandler) Handle(cmd *Send2FaToken) error { +func (h *Send2FaTokenHandler) Handle(ctx context.Context, cmd *Send2FaToken) error { extId, _ := uuid.Parse(cmd.ExtensionId) + log := logging.FromContext(ctx) - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "browser_extension_id": cmd.ExtensionId, "device_id": cmd.DeviceId, "token_request_id": cmd.TokenRequestId, @@ -70,7 +72,7 @@ func (h *Send2FaTokenHandler) Handle(cmd *Send2FaToken) error { ) if err != nil { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "error": err.Error(), "message": message, }).Error("Cannot send websocket message") diff --git a/internal/api/mobile/app/security/middleware.go b/internal/api/mobile/app/security/middleware.go index eda05fd..ccf589d 100644 --- a/internal/api/mobile/app/security/middleware.go +++ b/internal/api/mobile/app/security/middleware.go @@ -38,7 +38,7 @@ func MobileIpAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitV limitReached := rateLimiter.Test(c, key, rate) if limitReached { - logging.WithFields(logging.Fields{ + logging.FromContext(c.Request.Context()).WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "device_id": deviceId, diff --git a/internal/api/mobile/ports/http.go b/internal/api/mobile/ports/http.go index 0734ddd..7155fa0 100644 --- a/internal/api/mobile/ports/http.go +++ b/internal/api/mobile/ports/http.go @@ -2,6 +2,7 @@ package ports import ( "errors" + "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" "github.com/google/uuid" @@ -289,7 +290,7 @@ func (r *RoutesHandler) Send2FaToken(c *gin.Context) { return } - err = r.cqrs.Commands.Send2FaToken.Handle(cmd) + err = r.cqrs.Commands.Send2FaToken.Handle(c.Request.Context(), cmd) if err != nil { c.JSON(500, api.NewInternalServerError(err)) diff --git a/internal/common/http/client.go b/internal/common/http/client.go deleted file mode 100644 index debc35b..0000000 --- a/internal/common/http/client.go +++ /dev/null @@ -1,147 +0,0 @@ -package http - -import ( - "bytes" - "context" - "encoding/json" - "github.com/twofas/2fas-server/internal/common/logging" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "time" -) - -var tunedHttpTransport = &http.Transport{ - MaxIdleConns: 1024, - MaxIdleConnsPerHost: 1024, - TLSHandshakeTimeout: 10 * time.Second, - DialContext: (&net.Dialer{ - Timeout: 60 * time.Second, - KeepAlive: 60 * time.Second, - }).DialContext, -} - -type HttpClient struct { - client *http.Client - baseUrl *url.URL - credentialsCallback func(r *http.Request) -} - -func (w *HttpClient) CredentialsProvider(credentialsCallback func(r *http.Request)) { - w.credentialsCallback = credentialsCallback -} - -func (w *HttpClient) Post(ctx context.Context, path string, result interface{}, data interface{}) error { - req, err := w.newJsonRequest("POST", path, data) - - if err != nil { - return err - } - - return w.executeRequest(ctx, req, result) -} - -func (w *HttpClient) newJsonRequest(method, path string, body interface{}) (*http.Request, error) { - var buf io.ReadWriter - - logging.WithFields(logging.Fields{ - "method": method, - "body": body, - }).Debug("HTTP Request") - - if body != nil { - buf = new(bytes.Buffer) - - encoder := json.NewEncoder(buf) - err := encoder.Encode(body) - - if err != nil { - return nil, err - } - } - - return w.newRequest(method, path, buf, "application/json") -} - -func (w *HttpClient) newRequest(method, path string, buf io.Reader, contentType string) (*http.Request, error) { - u, err := w.baseUrl.Parse(path) - - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, u.String(), buf) - - if err != nil { - return nil, err - } - - req.Header.Set("Content-Type", contentType) - - return req, nil -} - -func (w *HttpClient) executeRequest(ctx context.Context, req *http.Request, v interface{}) error { - req = req.WithContext(ctx) - - if w.credentialsCallback != nil { - w.credentialsCallback(req) - } - - resp, err := w.client.Do(req) - - if err != nil { - return err - } - - defer resp.Body.Close() - - responseData, err := w.checkError(resp) - - if err != nil { - return err - } - - if v == nil { - return nil - } - - responseDataReader := bytes.NewReader(responseData) - - err = json.NewDecoder(responseDataReader).Decode(v) - - return err -} - -func (w *HttpClient) checkError(r *http.Response) ([]byte, error) { - errorResponse := &ErrorResponse{} - - responseData, err := ioutil.ReadAll(r.Body) - - if err == nil && responseData != nil { - json.Unmarshal(responseData, errorResponse) - } - - if httpCode := r.StatusCode; 200 <= httpCode && httpCode <= 300 { - return responseData, nil - } - - errorResponse.Status = r.StatusCode - - return responseData, errorResponse -} - -func NewHttpClient(baseUrl string) *HttpClient { - clientBaseUrl, err := url.Parse(baseUrl) - - if err != nil { - panic(err) - } - - return &HttpClient{ - client: &http.Client{Transport: tunedHttpTransport}, - baseUrl: clientBaseUrl, - } -} diff --git a/internal/common/http/log.go b/internal/common/http/log.go index 7480d2e..c6edadf 100644 --- a/internal/common/http/log.go +++ b/internal/common/http/log.go @@ -5,9 +5,37 @@ import ( "io" "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/twofas/2fas-server/internal/common/logging" ) +const ( + CorrelationIdHeader = "X-Correlation-ID" +) + +var ( + RequestId string + CorrelationId string +) + +func LoggingMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + requestId := uuid.New().String() + correlationId := c.Request.Header.Get(CorrelationIdHeader) + if correlationId == "" { + correlationId = uuid.New().String() + } + + ctxWithLog := logging.AddToContext(c.Request.Context(), logging.WithFields(map[string]any{ + "correlation_id": correlationId, + "request_id": requestId, + })) + c.Request = c.Request.WithContext(ctxWithLog) + + } +} + func RequestJsonLogger() gin.HandlerFunc { return func(c *gin.Context) { var buf bytes.Buffer @@ -17,22 +45,19 @@ func RequestJsonLogger() gin.HandlerFunc { c.Request.Body = io.NopCloser(&buf) - logging.WithFields(logging.Fields{ - "method": c.Request.Method, - "path": c.Request.URL.Path, - "headers": c.Request.Header, - "body": string(body), - "request_id": c.GetString(RequestIdKey), - "correlation_id": c.GetString(CorrelationIdKey), + log := logging.FromContext(c.Request.Context()) + + log.WithFields(logging.Fields{ + "method": c.Request.Method, + "path": c.Request.URL.Path, + "body": string(body), }).Info("Request") c.Next() - logging.WithFields(logging.Fields{ - "method": c.Request.Method, - "path": c.Request.URL.Path, - "request_id": c.GetString(RequestIdKey), - "correlation_id": c.GetString(CorrelationIdKey), + log.WithFields(logging.Fields{ + "method": c.Request.Method, + "path": c.Request.URL.Path, }).Info("Response") } } diff --git a/internal/common/http/request.go b/internal/common/http/request.go index 7f517ff..20f8588 100644 --- a/internal/common/http/request.go +++ b/internal/common/http/request.go @@ -1,50 +1,11 @@ package http import ( - "github.com/gin-gonic/gin" - "github.com/google/uuid" - "github.com/twofas/2fas-server/internal/common/logging" "net/http" + + "github.com/gin-gonic/gin" ) -const ( - RequestIdKey = "request_id" - - CorrelationIdKey = "correlation_id" - CorrelationIdHeader = "X-Correlation-ID" -) - -var ( - RequestId string - CorrelationId string -) - -func RequestIdMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - RequestId = uuid.New().String() - - c.Set(RequestIdKey, RequestId) - - logging.WithDefaultField(RequestIdKey, RequestId) - } -} - -func CorrelationIdMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - c.Set(CorrelationIdKey, uuid.New().String()) - - CorrelationId = c.Request.Header.Get(CorrelationIdHeader) - - if CorrelationId == "" { - CorrelationId = uuid.New().String() - } - - logging.WithDefaultField(CorrelationIdKey, CorrelationId) - - c.Set(CorrelationIdKey, CorrelationId) - } -} - func BodySizeLimitMiddleware(requestBytesLimit int64) gin.HandlerFunc { return func(c *gin.Context) { var w http.ResponseWriter = c.Writer diff --git a/internal/common/http/server.go b/internal/common/http/server.go index 6fbe1b6..f8ea649 100644 --- a/internal/common/http/server.go +++ b/internal/common/http/server.go @@ -9,8 +9,7 @@ func RunHttpServer(addr string, init func(engine *gin.Engine)) { router.Use(gin.Recovery()) router.Use(corsMiddleware()) - router.Use(RequestIdMiddleware()) - router.Use(CorrelationIdMiddleware()) + router.Use(LoggingMiddleware()) router.Use(RequestJsonLogger()) init(router) diff --git a/internal/common/logging/logger.go b/internal/common/logging/logger.go index 6e9f6d3..b73736b 100644 --- a/internal/common/logging/logger.go +++ b/internal/common/logging/logger.go @@ -1,6 +1,7 @@ package logging import ( + "context" "encoding/json" "reflect" "sync" @@ -8,111 +9,92 @@ import ( "github.com/sirupsen/logrus" ) -// TODO: do not log reuse on every request. -type Fields map[string]interface{} +type Fields = logrus.Fields -var ( - customLogger = New() - defaultFields = logrus.Fields{} - defaultFieldsMutex = sync.RWMutex{} -) - -func New() *logrus.Logger { - logger := logrus.New() - - logger.SetFormatter(&logrus.JSONFormatter{ - FieldMap: logrus.FieldMap{ - logrus.FieldKeyTime: "timestamp", - logrus.FieldKeyLevel: "level", - logrus.FieldKeyMsg: "message", - }, - }) - - logger.SetLevel(logrus.InfoLevel) - - return logger +type FieldLogger interface { + logrus.FieldLogger } -func WithDefaultField(key, value string) *logrus.Logger { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() +var ( + log logrus.FieldLogger + once sync.Once +) - defaultFields[key] = value +// Init initialize global instance of logging library. +func Init(fields Fields) FieldLogger { + once.Do(func() { + logger := logrus.New() - return customLogger + logger.SetFormatter(&logrus.JSONFormatter{ + FieldMap: logrus.FieldMap{ + logrus.FieldKeyTime: "timestamp", + logrus.FieldKeyLevel: "level", + logrus.FieldKeyMsg: "message", + }, + }) + + logger.SetLevel(logrus.InfoLevel) + + log = logger.WithFields(fields) + }) + return log +} + +type ctxKey int + +const ( + loggerKey ctxKey = iota +) + +func AddToContext(ctx context.Context, log FieldLogger) context.Context { + return context.WithValue(ctx, loggerKey, log) +} + +func FromContext(ctx context.Context) FieldLogger { + log, ok := ctx.Value(loggerKey).(FieldLogger) + if ok { + return log + } + return log +} + +func WithFields(fields Fields) FieldLogger { + return log.WithFields(fields) +} + +func WithField(key string, value any) FieldLogger { + return log.WithField(key, value) } func Info(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Info(args...) + log.Info(args...) } func Infof(format string, args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Infof(format, args...) + log.Infof(format, args...) } func Error(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Error(args...) + log.Error(args...) } func Errorf(format string, args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Errorf(format, args...) + log.Errorf(format, args...) } func Warning(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Warning(args...) + log.Warning(args...) } func Fatal(args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Fatal(args...) + log.Fatal(args...) } func Fatalf(format string, args ...interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - customLogger.WithFields(defaultFields).Fatalf(format, args...) -} - -func WithField(key string, value interface{}) *logrus.Entry { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - return customLogger. - WithFields(defaultFields). - WithField(key, value) -} - -func WithFields(fields Fields) *logrus.Entry { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - - return customLogger. - WithFields(logrus.Fields(fields)). - WithFields(defaultFields) + log.Fatalf(format, args...) } func LogCommand(command interface{}) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - context, _ := json.Marshal(command) commandName := reflect.TypeOf(command).Elem().Name() @@ -120,8 +102,7 @@ func LogCommand(command interface{}) { var commandAsFields logrus.Fields json.Unmarshal(context, &commandAsFields) - customLogger. - WithFields(defaultFields). + log. WithFields(logrus.Fields{ "command_name": commandName, "command": commandAsFields, @@ -129,13 +110,8 @@ func LogCommand(command interface{}) { } func LogCommandFailed(command interface{}, err error) { - defaultFieldsMutex.Lock() - defer defaultFieldsMutex.Unlock() - commandName := reflect.TypeOf(command).Elem().Name() - - customLogger. - WithFields(defaultFields). + log. WithFields(logrus.Fields{ "reason": err.Error(), }).Info("Command failed" + commandName) diff --git a/internal/common/rate_limit/redis_rate_limit.go b/internal/common/rate_limit/redis_rate_limit.go index 18c2882..bfe9036 100644 --- a/internal/common/rate_limit/redis_rate_limit.go +++ b/internal/common/rate_limit/redis_rate_limit.go @@ -38,7 +38,7 @@ func (r *RedisRateLimit) Test(ctx context.Context, key string, rate Rate) bool { Period: rate.TimeUnit, }) if err != nil { - logging.WithFields(logging.Fields{ + logging.FromContext(ctx).WithFields(logging.Fields{ "type": "security", }).Warnf("Could not check rate limit: %v", err) diff --git a/internal/common/recovery/gin.go b/internal/common/recovery/gin.go index b07b16c..a2c8750 100644 --- a/internal/common/recovery/gin.go +++ b/internal/common/recovery/gin.go @@ -3,10 +3,11 @@ package recovery import ( "bytes" "fmt" - "github.com/gin-gonic/gin" - "github.com/twofas/2fas-server/internal/common/logging" "io/ioutil" "runtime" + + "github.com/gin-gonic/gin" + "github.com/twofas/2fas-server/internal/common/logging" ) func RecoveryMiddleware() gin.HandlerFunc { diff --git a/internal/common/security/middleware.go b/internal/common/security/middleware.go index f02dc99..b1a35a4 100644 --- a/internal/common/security/middleware.go +++ b/internal/common/security/middleware.go @@ -30,7 +30,7 @@ func IPAbuseAuditMiddleware(rateLimiter rate_limit.RateLimiter, rateLimitValue i limitReached := rateLimiter.Test(c, key, rate) if limitReached { - logging.WithFields(logging.Fields{ + logging.FromContext(c.Request.Context()).WithFields(logging.Fields{ "type": "security", "uri": c.Request.URL.String(), "ip": c.ClientIP(), diff --git a/internal/common/websocket/gorilla_websocket_client.go b/internal/common/websocket/gorilla_websocket_client.go index 5d2682b..3e6fd27 100644 --- a/internal/common/websocket/gorilla_websocket_client.go +++ b/internal/common/websocket/gorilla_websocket_client.go @@ -2,12 +2,13 @@ package websocket import ( "encoding/json" - "github.com/gorilla/websocket" - app_http "github.com/twofas/2fas-server/internal/common/http" - "github.com/twofas/2fas-server/internal/common/logging" "net/http" "net/url" "path" + + "github.com/gorilla/websocket" + app_http "github.com/twofas/2fas-server/internal/common/http" + "github.com/twofas/2fas-server/internal/common/logging" ) type WebsocketApiClient struct { diff --git a/internal/websocket/app.go b/internal/websocket/app.go index 88007f8..f54c36a 100644 --- a/internal/websocket/app.go +++ b/internal/websocket/app.go @@ -18,8 +18,7 @@ func NewServer(addr string) *Server { router := gin.New() router.Use(recovery.RecoveryMiddleware()) - router.Use(http.RequestIdMiddleware()) - router.Use(http.CorrelationIdMiddleware()) + router.Use(http.LoggingMiddleware()) router.Use(http.RequestJsonLogger()) connectionHandler := common.NewConnectionHandler() diff --git a/internal/websocket/common/client.go b/internal/websocket/common/client.go index d1af2b6..bf35fd7 100644 --- a/internal/websocket/common/client.go +++ b/internal/websocket/common/client.go @@ -50,7 +50,7 @@ type Client struct { // The application runs readPump in a per-connection goroutine. The application // ensures that there is at most one reader on a connection by executing all // reads from this goroutine. -func (c *Client) readPump() { +func (c *Client) readPump(log logging.FieldLogger) { defer func() { c.hub.unregisterClient(c) c.conn.Close() @@ -69,11 +69,11 @@ func (c *Client) readPump() { if err != nil { if websocket.IsUnexpectedCloseError(err, acceptedCloseStatus...) { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "reason": err.Error(), }).Error("Websocket connection closed unexpected") } else { - logging.WithFields(logging.Fields{ + log.WithFields(logging.Fields{ "reason": err.Error(), }).Info("Connection closed") } diff --git a/internal/websocket/common/handler.go b/internal/websocket/common/handler.go index 7085283..965623b 100644 --- a/internal/websocket/common/handler.go +++ b/internal/websocket/common/handler.go @@ -42,18 +42,18 @@ func (h *ConnectionHandler) Handler() gin.HandlerFunc { return func(c *gin.Context) { channel := c.Request.URL.Path - logging.WithDefaultField("channel", channel) + log := logging.FromContext(c.Request.Context()).WithField("channel", channel) - logging.Info("New channel subscriber") + log.Info("New channel subscriber") - h.serveWs(c.Writer, c.Request, channel) + h.serveWs(c.Writer, c.Request, channel, log) } } -func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, channel string) { +func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, channel string, log logging.FieldLogger) { conn, err := upgrader.Upgrade(w, r, nil) if err != nil { - logging.Errorf("Failed to upgrade connection: %v", err) + log.Errorf("Failed to upgrade connection: %v", err) w.WriteHeader(http.StatusInternalServerError) return } @@ -65,7 +65,7 @@ func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, chan }) go recovery.DoNotPanic(func() { - client.readPump() + client.readPump(log) }) go recovery.DoNotPanic(func() { @@ -73,7 +73,7 @@ func (h *ConnectionHandler) serveWs(w http.ResponseWriter, r *http.Request, chan timeout := time.After(disconnectAfter) <-timeout - logging.Info("Connection closed after", disconnectAfter) + log.Info("Connection closed after", disconnectAfter) client.hub.unregisterClient(client) client.conn.Close()