From 3ecfdccd9a0065365a00e7c8b60de7ee2df4e40b Mon Sep 17 00:00:00 2001 From: gary rong Date: Mon, 22 Feb 2021 21:33:11 +0800 Subject: [PATCH] les: clean up server handler (#22357) --- les/server_handler.go | 185 +++++++++++++++++++++++------------------ les/server_requests.go | 16 ++-- 2 files changed, 114 insertions(+), 87 deletions(-) diff --git a/les/server_handler.go b/les/server_handler.go index b6e8b050b1..7651d03cab 100644 --- a/les/server_handler.go +++ b/les/server_handler.go @@ -204,6 +204,90 @@ func (h *serverHandler) handle(p *clientPeer) error { } } +// beforeHandle will do a series of prechecks before handling message. +func (h *serverHandler) beforeHandle(p *clientPeer, reqID, responseCount uint64, msg p2p.Msg, reqCnt uint64, maxCount uint64) (*servingTask, uint64) { + // Ensure that the request sent by client peer is valid + inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0) + if reqCnt == 0 || reqCnt > maxCount { + p.fcClient.OneTimeCost(inSizeCost) + return nil, 0 + } + // Ensure that the client peer complies with the flow control + // rules agreed by both sides. + if p.isFrozen() { + p.fcClient.OneTimeCost(inSizeCost) + return nil, 0 + } + maxCost := p.fcCosts.getMaxCost(msg.Code, reqCnt) + accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost) + if !accepted { + p.freeze() + p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) + p.fcClient.OneTimeCost(inSizeCost) + return nil, 0 + } + // Create a multi-stage task, estimate the time it takes for the task to + // execute, and cache it in the request service queue. + factor := h.server.costTracker.globalFactor() + if factor < 0.001 { + factor = 1 + p.Log().Error("Invalid global cost factor", "factor", factor) + } + maxTime := uint64(float64(maxCost) / factor) + task := h.server.servingQueue.newTask(p, maxTime, priority) + if !task.start() { + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost) + return nil, 0 + } + return task, maxCost +} + +// Afterhandle will perform a series of operations after message handling, +// such as updating flow control data, sending reply, etc. +func (h *serverHandler) afterHandle(p *clientPeer, reqID, responseCount uint64, msg p2p.Msg, maxCost uint64, reqCnt uint64, task *servingTask, reply *reply) { + if reply != nil { + task.done() + } + p.responseLock.Lock() + defer p.responseLock.Unlock() + + // Short circuit if the client is already frozen. + if p.isFrozen() { + realCost := h.server.costTracker.realCost(task.servingTime, msg.Size, 0) + p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) + return + } + // Positive correction buffer value with real cost. + var replySize uint32 + if reply != nil { + replySize = reply.size() + } + var realCost uint64 + if h.server.costTracker.testing { + realCost = maxCost // Assign a fake cost for testing purpose + } else { + realCost = h.server.costTracker.realCost(task.servingTime, msg.Size, replySize) + if realCost > maxCost { + realCost = maxCost + } + } + bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) + if reply != nil { + // Feed cost tracker request serving statistic. + h.server.costTracker.updateStats(msg.Code, reqCnt, task.servingTime, realCost) + // Reduce priority "balance" for the specific peer. + p.balance.RequestServed(realCost) + p.queueSend(func() { + if err := reply.send(bv); err != nil { + select { + case p.errCh <- err: + default: + } + } + }) + } +} + // handleMsg is invoked whenever an inbound message is received from a remote // peer. The remote connection is torn down upon returning any error. func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { @@ -221,9 +305,8 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { } defer msg.Discard() - p.responseCount++ - responseCount := p.responseCount - + // Lookup the request handler table, ensure it's supported + // message type by the protocol. req, ok := Les3[msg.Code] if !ok { p.Log().Trace("Received invalid message", "code", msg.Code) @@ -232,98 +315,42 @@ func (h *serverHandler) handleMsg(p *clientPeer, wg *sync.WaitGroup) error { } p.Log().Trace("Received " + req.Name) + // Decode the p2p message, resolve the concrete handler for it. serve, reqID, reqCnt, err := req.Handle(msg) if err != nil { clientErrorMeter.Mark(1) return errResp(ErrDecode, "%v: %v", msg, err) } - if metrics.EnabledExpensive { req.InPacketsMeter.Mark(1) req.InTrafficMeter.Mark(int64(msg.Size)) } + p.responseCount++ + responseCount := p.responseCount - // Short circuit if the peer is already frozen or the request is invalid. - inSizeCost := h.server.costTracker.realCost(0, msg.Size, 0) - if p.isFrozen() || reqCnt == 0 || reqCnt > req.MaxCount { - p.fcClient.OneTimeCost(inSizeCost) - return nil - } - // Prepaid max cost units before request been serving. - maxCost := p.fcCosts.getMaxCost(msg.Code, reqCnt) - accepted, bufShort, priority := p.fcClient.AcceptRequest(reqID, responseCount, maxCost) - if !accepted { - p.freeze() - p.Log().Error("Request came too early", "remaining", common.PrettyDuration(time.Duration(bufShort*1000000/p.fcParams.MinRecharge))) - p.fcClient.OneTimeCost(inSizeCost) + // First check this client message complies all rules before + // handling it and return a processor if all checks are passed. + task, maxCost := h.beforeHandle(p, reqID, responseCount, msg, reqCnt, req.MaxCount) + if task == nil { return nil } - // Create a multi-stage task, estimate the time it takes for the task to - // execute, and cache it in the request service queue. - factor := h.server.costTracker.globalFactor() - if factor < 0.001 { - factor = 1 - p.Log().Error("Invalid global cost factor", "factor", factor) - } - maxTime := uint64(float64(maxCost) / factor) - task := h.server.servingQueue.newTask(p, maxTime, priority) - if task.start() { - wg.Add(1) - go func() { - defer wg.Done() - reply := serve(h, p, task.waitOrStop) - if reply != nil { - task.done() - } + wg.Add(1) + go func() { + defer wg.Done() - p.responseLock.Lock() - defer p.responseLock.Unlock() + reply := serve(h, p, task.waitOrStop) + h.afterHandle(p, reqID, responseCount, msg, maxCost, reqCnt, task, reply) - // Short circuit if the client is already frozen. - if p.isFrozen() { - realCost := h.server.costTracker.realCost(task.servingTime, msg.Size, 0) - p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) - return - } - // Positive correction buffer value with real cost. - var replySize uint32 + if metrics.EnabledExpensive { + size := uint32(0) if reply != nil { - replySize = reply.size() - } - var realCost uint64 - if h.server.costTracker.testing { - realCost = maxCost // Assign a fake cost for testing purpose - } else { - realCost = h.server.costTracker.realCost(task.servingTime, msg.Size, replySize) - if realCost > maxCost { - realCost = maxCost - } + size = reply.size() } - bv := p.fcClient.RequestProcessed(reqID, responseCount, maxCost, realCost) - if reply != nil { - // Feed cost tracker request serving statistic. - h.server.costTracker.updateStats(msg.Code, reqCnt, task.servingTime, realCost) - // Reduce priority "balance" for the specific peer. - p.balance.RequestServed(realCost) - p.queueSend(func() { - if err := reply.send(bv); err != nil { - select { - case p.errCh <- err: - default: - } - } - }) - if metrics.EnabledExpensive { - req.OutPacketsMeter.Mark(1) - req.OutTrafficMeter.Mark(int64(replySize)) - req.ServingTimeMeter.Update(time.Duration(task.servingTime)) - } - } - }() - } else { - p.fcClient.RequestProcessed(reqID, responseCount, maxCost, inSizeCost) - } - + req.OutPacketsMeter.Mark(1) + req.OutTrafficMeter.Mark(int64(size)) + req.ServingTimeMeter.Update(time.Duration(task.servingTime)) + } + }() // If the client has made too much invalid request(e.g. request a non-existent data), // reject them to prevent SPAM attack. if p.getInvalid() > maxRequestErrors { diff --git a/les/server_requests.go b/les/server_requests.go index d4af8006f0..07f30b1b73 100644 --- a/les/server_requests.go +++ b/les/server_requests.go @@ -65,7 +65,7 @@ type serveRequestFn func(backend serverBackend, peer *clientPeer, waitOrStop fun // Les3 contains the request types supported by les/2 and les/3 var Les3 = map[uint64]RequestType{ - GetBlockHeadersMsg: RequestType{ + GetBlockHeadersMsg: { Name: "block header request", MaxCount: MaxHeaderFetch, InPacketsMeter: miscInHeaderPacketsMeter, @@ -75,7 +75,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeHeaderTimer, Handle: handleGetBlockHeaders, }, - GetBlockBodiesMsg: RequestType{ + GetBlockBodiesMsg: { Name: "block bodies request", MaxCount: MaxBodyFetch, InPacketsMeter: miscInBodyPacketsMeter, @@ -85,7 +85,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeBodyTimer, Handle: handleGetBlockBodies, }, - GetCodeMsg: RequestType{ + GetCodeMsg: { Name: "code request", MaxCount: MaxCodeFetch, InPacketsMeter: miscInCodePacketsMeter, @@ -95,7 +95,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeCodeTimer, Handle: handleGetCode, }, - GetReceiptsMsg: RequestType{ + GetReceiptsMsg: { Name: "receipts request", MaxCount: MaxReceiptFetch, InPacketsMeter: miscInReceiptPacketsMeter, @@ -105,7 +105,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeReceiptTimer, Handle: handleGetReceipts, }, - GetProofsV2Msg: RequestType{ + GetProofsV2Msg: { Name: "les/2 proofs request", MaxCount: MaxProofsFetch, InPacketsMeter: miscInTrieProofPacketsMeter, @@ -115,7 +115,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeTrieProofTimer, Handle: handleGetProofs, }, - GetHelperTrieProofsMsg: RequestType{ + GetHelperTrieProofsMsg: { Name: "helper trie proof request", MaxCount: MaxHelperTrieProofsFetch, InPacketsMeter: miscInHelperTriePacketsMeter, @@ -125,7 +125,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeHelperTrieTimer, Handle: handleGetHelperTrieProofs, }, - SendTxV2Msg: RequestType{ + SendTxV2Msg: { Name: "new transactions", MaxCount: MaxTxSend, InPacketsMeter: miscInTxsPacketsMeter, @@ -135,7 +135,7 @@ var Les3 = map[uint64]RequestType{ ServingTimeMeter: miscServingTimeTxTimer, Handle: handleSendTx, }, - GetTxStatusMsg: RequestType{ + GetTxStatusMsg: { Name: "transaction status query request", MaxCount: MaxTxStatus, InPacketsMeter: miscInTxStatusPacketsMeter,