grpc 服务端源码阅读

 : jank    :   : 430    : 2020-10-29 08:05  rpc

 一 、服务端初始化

   go grpc 服务端代码整体采用建造者模式, 服务初始化时,嵌入所有需要加载的接口服务配置,并加载 ,如 是否加密传输 :

// Create tls based credential.

creds, err := credentials.NewServerTLSFromFile

(testdata.Path("server1.pem"), testdata.Path("server1.key"))

if err != nil {

log.Fatalf("failed to create credentials: %v", err)

}

s := grpc.NewServer(grpc.Creds(creds))


服务初始化 NewServer(...)代码如下:

func NewServer(opt ...ServerOption) *Server {
    opts := defaultServerOptions //default server options
    for _, o := range opt {
        o.apply(&opts) //iterate through opt to implement opts
    }
    s := &Server{
        lis:    make(map[net.Listener]bool),
        opts:   opts,
        conns:  make(map[transport.ServerTransport]bool), //connect cache
        m:      make(map[string]*service),
        quit:   grpcsync.NewEvent(),
        done:   grpcsync.NewEvent(),
        czData: new(channelzData), //rpc_util.go
    }
    chainUnaryServerInterceptors(s)  //only zero(default) / one unary server intercetor, (./intercept.go)
    chainStreamServerInterceptors(s) //the same as above
    s.cv = sync.NewCond(&s.mu)
    if EnableTracing { //global config in ./trace.go use google.org/x/net/trace
        _, file, line, _ := runtime.Caller(1)
        //golang.org/x/net/trace/event.go
        s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
    }
 
    if s.opts.numServerWorkers > 0 { //default 0, multi goroutine( <= cpu kernel) to process
        s.initServerWorkers()
    }
 
    if channelz.IsOn() { //internal/channelz/funcs.go TurnOn()
        s.channelzID = channelz.RegisterServer(&channelzServer{s}, "") //id int64, default 1, ... add 1
    }
    return s
}



二、服务启动

     1.服务外层启动net 服务并传入grpc Serve中进行监听

2.初始化后获得Sever 结构,通过Sever 结构进行服务的配置启动相应的服务

3.在服务监听过程中如果出现Temporary类型的错误,则进行重试,否则返回结束服务接收

4.在循环处理中每次都启用一个新的goroutine 去运行新的客户端连接

关键代码如下:

 
// Serve accepts incoming connections on the listener lis, creating a new
// ServerTransport and service goroutine for each. The service goroutines
// read gRPC requests and then call the registered handlers to reply to them.
// Serve returns when lis.Accept fails with fatal errors.  lis will be closed when
// this method returns.
// Serve will return a non-nil error unless Stop or GracefulStop is called.
func (s *Server) Serve(lis net.Listener) error {

    ...

    //wait to stop server
    s.serveWG.Add(1)
    defer func() {
        s.serveWG.Done()
        if s.quit.HasFired() {
            // Stop or GracefulStop called; block until done and return nil.
            <-s.done.Done()
        }
    }()
 
    ls := &listenSocket{Listener: lis}
    s.lis[ls] = true //add to net.Listener map

    ...
    var tempDelay time.Duration // how long to sleep on accept failure
 
    //cycle accept client connect
    for { 
        rawConn, err := lis.Accept()
        if err != nil {
            if ne, ok := err.(interface {
                Temporary() bool
            }); ok && ne.Temporary() { //failed retry,if err type is Temporary
                ...
                ...
            }
                ...
            return err
        }
        tempDelay = 0
        s.serveWG.Add(1)
        go func() {
            s.handleRawConn(rawConn) 
            s.serveWG.Done()
        }()
    }
}


handleRawConn函数主要做了如下几件事,

1.设置连接超时时间,并校验连接

2.使用当前连接创建一个http2传输结构,并加入当前所有连接缓存中

3.把传输结构传入服务处理函数serveStreams(...) 中

handleRawConn 关键代码如下:

func (s *Server) handleRawConn(rawConn net.Conn) {
	if s.quit.HasFired() {
		rawConn.Close()
		return
	}
	//set connection timeout(if timeout close connection, ), default 120s
	rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout)) 
	//if NewServer() add creds, use authenticator transport
	//interface in credentials/credentials.go, such as tls creds implement in credentials/tls.go
	conn, authInfo, err := s.useTransportAuthenticator(rawConn)  
	if err != nil {
                ...
		rawConn.SetDeadline(time.Time{}) //reset deadline,close timer
		return
	}

	// Finish handshaking (HTTP2), new http2 transport (implement in internal/transport/http2_server.go)
	st := s.newHTTP2Transport(conn, authInfo)
	if st == nil {
		return
	}

	rawConn.SetDeadline(time.Time{}) //reset deadline 
	if !s.addConn(st) {              //add to conns map ( cache connect)
		return
	}
	go func() {
		s.serveStreams(st) 
		s.removeConn(st)  //remove connection
	}()
}

s.serverStreams(st), 通过调用st.HandleStreams嵌入handleStream() 流处理方法,及上下文函数,进行循环处理,嵌入外层函数运行的好处是服务底层不需要关心服务的具体实现,只需要在运行中调用即可。其中handleStream 的并行个数可在服务初始化时进行配置,默认只运行一个。

func (s *Server) serveStreams(st transport.ServerTransport) {
         ...
	//cycle handle stream and trace, it is embedded to process
	st.HandleStreams(func(stream *transport.Stream) {
		wg.Add(1)
		if s.opts.numServerWorkers > 0 { //default 0
			//if has multi server workers, and get a worker by modulus
                        ...
		} else {
			//default server worker handle
			go func() {
				defer wg.Done()
				s.handleStream(st, stream, s.traceInfo(st, stream))
			}()
		}
	}, func(ctx context.Context, method string) context.Context {
                ...
	})
	...
}


 HandleStreams() 是transport 接口的中的一个方法,其中http2的实现,通过循环读取framer中的frame结构,并对其进行分类处理,其中包括头帧、数据帧处理等

关键代码如下:

func (t *http2Server) HandleStreams(handle func(*Stream), traceCtx func(context.Context, string) context.Context) {
	defer close(t.readerDone)
	for {
		t.controlBuf.throttle()
		frame, err := t.framer.fr.ReadFrame()
                ...
		switch frame := frame.(type) {
		//handle header frame
		case *http2.MetaHeadersFrame:
			//new stream register
			if t.operateHeaders(frame, handle, traceCtx) {
				t.Close()
				break
			}
	        //handle data frame
		case *http2.DataFrame:
			t.handleData(frame)
			...
		}
	}
}

 t.operateHeaders() ,  处理头信息。

首先进行decode,并获取全部的头信息, 创建一个新的流处理结构体, 并丰富其字段.

如果客户端设置了超时,则在stream中设置带有相应超时时间的context

注册流信息,并调用handleStream() 方法处理


func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream), traceCtx func(context.Context, string) context.Context) (fatal bool) {
	//get stream id
	streamID := frame.Header().StreamID
        state := &decodeState{
            serverSide: true, //is server side
        }
	//decode header
	//// get state.data.encoding、method、contentSubtype and so on , ./http_util.go
	//if had grpc-timeout header, set state.data.timeoutSet = true
	if err := state.decodeHeader(frame); err != nil { 
		if se, ok := status.FromError(err); ok {
			t.controlBuf.put(&cleanupStream{
				streamID: streamID,
				rst:      true, //close stream
				rstCode:  statusCodeConvTab[se.Code()],
				onWrite:  func() {},
			})
		}
		return false
	}
        
        //create a new Stream struct
	buf := newRecvBuffer()
	s := &Stream{
		id:             streamID,
		st:             t,
		buf:            buf,
	        recvCompress:   state.data.encoding,
		method:         state.data.method,
		contentSubtype: state.data.contentSubtype,
	}
        ...
        
        //if client set timeout, set stream context with timeout  
        if state.data.timeoutSet {
		s.ctx, s.cancel = context.WithTimeout(t.ctx, state.data.timeout)
	} else {
		s.ctx, s.cancel = context.WithCancel(t.ctx)
	}

        
        //peer/peer.go, set get remote addr and auth info 
	pr := &peer.Peer{
		Addr: t.remoteAddr,
	}
	// Attach Auth info if there is any.
	if t.authInfo != nil {
		pr.AuthInfo = t.authInfo
	}
	
	//set get context by peer
	s.ctx = peer.NewContext(s.ctx, pr)
	// Attach the received metadata to the context.
	...
	//flow control, implement by userself at NewServer(...)
	if t.inTapHandle != nil {
                ...
		s.ctx, err = t.inTapHandle(s.ctx, info)
	        ...
	}
        ...
        
        //control limit max active stream 
	if uint32(len(t.activeStreams)) >= t.maxStreams {
                ...
		return false
	}
	//set max streamID now
	t.maxStreamID = streamID 
	
	//register streamID
	t.activeStreams[streamID] = s
        ...
        
        //call traceCtx(),new trace and attach context value
	s.ctx = traceCtx(s.ctx, s.method)
	
	// if config, for statistical information,and attach context value
	if t.stats != nil {
		s.ctx = t.stats.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
		inHeader := &stats.InHeader{
			FullMethod:  s.method,
			RemoteAddr:  t.remoteAddr,
			LocalAddr:   t.localAddr,
			Compression: s.recvCompress,
			WireLength:  int(frame.Header().Length),
			Header:      metadata.MD(state.data.mdata).Copy(),
		}
		t.stats.HandleRPC(s.ctx, inHeader)
	}
	s.ctxDone = s.ctx.Done()
	s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
	
	//set transport reader 
	s.trReader = &transportReader{
	        //implement io.Reader interface to read from s.buf(*recvBuffer)
		reader: &recvBufferReader{
			ctx:        s.ctx,
			ctxDone:    s.ctxDone,
			recv:       s.buf,
			freeBuffer: t.bufferPool.put,
		},
		windowHandler: func(n int) {
			t.updateWindow(s, uint32(n))
		},
	}
	// Register the stream with loopy (implement by list link)
	// async set l.estdStreams *outSteam map
	t.controlBuf.put(&registerStream{
		streamID: s.id,
		wq:       s.wq,
	})
	
	//call s.hanldeStream()
	handle(s)
	return false
}


func (t *http2Server) handleData(f *http2.DataFrame) {
        size := f.Header().Length
        ...
	// Select the right stream to dispatch.
	s, ok := t.getStream(f)
	if !ok {
		return
	}
	if size > 0 {
                ...
		if len(f.Data()) > 0 {
			buffer := t.bufferPool.get()
			buffer.Reset()
			buffer.Write(f.Data())
			s.write(recvMsg{buffer: buffer}) //received the stream data and append the s.buf.backlog
		}
	}
	if f.Header().Flags.Has(http2.FlagDataEndStream) {
		// Received the end of stream from the client.
		//set s.state = streamReadDone
		s.compareAndSwapState(streamActive, streamReadDone)
		s.write(recvMsg{err: io.EOF}) //received over
	}
}

handleStream(), grpc服务业务层处理方法,在这个方法中会根据stream中的信息获取对应的服务和方法,

    1.先从当前已注册的单元处理方法中查询对应的方法, 并通过调用processUnaryRPC 函数来处理并返回。

    2.从当前已注册的流处理方法中查询对应的方法,并通过调用processStreamingRPC 函数俩处理并返回。

    3.没有该方法,则进行相应的错误处理,并返回结束 

//st is http2 transport, stream is from http2
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
	sm := stream.Method()
        ...
	pos := strings.LastIndex(sm, "/")
	...
	//get service name
	service := sm[:pos] 
	//get method name
	method := sm[pos+1:] 
        
        //process by method, process unary rpc or streaming rpc
	srv, knownService := s.m[service]
	if knownService { //if exist this service
		if md, ok := srv.md[method]; ok { //if it is simple rpc
			s.processUnaryRPC(t, stream, srv, md, trInfo)
			return
		}
		if sd, ok := srv.sd[method]; ok { //if it is stream rpc
			s.processStreamingRPC(t, stream, srv, sd, trInfo)
			return
		}
	}
	// Unknown service, or known server unknown method. handle
	...
}

        processUnaryRPC(...)

        单元处理方法,grpc 服务普通调用处理的核心函数。

        1)在该函数中会进行统计信息、日志记录

        2)根据头信息获取对应的压缩解压方法并对当前stream中的包内容进行解压

        3)调用该rpc请求对应的用户自定义的服务方法,并获取函数返回结果

        4)调用rpc 服务返回函数,把调用结果压缩返回客户端。

        5)最后在写完body内容后,需要发送处理成功的状态码

//st is http2 transport, stream is from http2, srv is user defined service, md is user defined method
func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.Stream, srv *service, md *MethodDesc, trInfo *traceInfo) (err error) { //stream is from http2
	sh := s.opts.statsHandler 
	//if add stats handler, implement by userself, 
	//./stats/handler.go
	if sh != nil || trInfo != nil || channelz.IsOn() {
            ...
	}
	//binlog, default init at internal/binarylog/binarylog.go
	binlog := binarylog.GetMethodLogger(stream.Method()) //internal/binarylog/binarylog.go, default get nil, unless user set method at first
	//record and write method log
	if binlog != nil {
            ...
	}

	// comp and cp are used for compression.  decomp and dc are used for
	// decompression.  If comp and decomp are both set, they are the same;
	// however they are kept separate to ensure that at most one of the
	// compressor/decompressor variable pairs are set for use later.
	var comp, decomp encoding.Compressor
	var cp Compressor
	var dc Decompressor
	...

	//rpc_util.go, get decompress data d
	d, err := recvAndDecompress(&parser{r: stream}, stream, dc, s.opts.maxReceiveMessageSize, payInfo, decomp)
        ...
	df := func(v interface{}) error {
		if err := s.getCodec(stream.ContentSubtype()).Unmarshal(d, v); err != nil {
			return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err)
		}
	        ...
	        //record stats and log
		return nil
	}
	ctx := NewContextWithServerTransportStream(stream.Context(), stream)
	reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt) //run user defined source method func, if s.opts.unaryInt(user defined unary interceptor) not nil, run it at first
	...
	opts := &transport.Options{Last: true}

	//response to client
	if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
		if err == io.EOF {
			// The entire stream is done (for unary RPC only).
			return err
		}
		//handle err
		...
		return err
	}
        ...
	err = t.WriteStatus(stream, statusOK) //writer header, after writer data
        ...
	return err
}

            recvAndDecompress(...) 函数主要是读取包内容,并根据包头信息获取当前的包内容是否加密,如果加密则解密返回,否则直接返回包内容。下面主要看看读取包内容的过程,通过实现io.Reader 接口读取包内容,采用tlv读取方式,先读取头5个字节,第一个字节是type类型加密/不加密, 后4个字节是包的长度,后面的全部字节即包的内容,核心代码如下:

func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
	if _, err := p.r.Read(p.header[:]); err != nil { //read header and is 5 byte
		return 0, nil, err
	}

	pf = payloadFormat(p.header[0])                 //0 none compression, 1 compression
	length := binary.BigEndian.Uint32(p.header[1:]) // get data length
        ...
	msg = make([]byte, int(length))
	if _, err := p.r.Read(msg); err != nil { //read all data/msg
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return 0, nil, err
	}
	return pf, msg, nil
}


           s.sendResponse(...)  encode and compress response body, 调用底层写入方法 t.Write(...) , 至此grpc服务端接收到发送简单过程结束

func (t *http2Server) Write(s *Stream, hdr []byte, data []byte, opts *Options) error {
	if !s.isHeaderSent() { // Headers haven't been written yet., transport.go
	        //write header
		if err := t.WriteHeader(s, nil); err != nil {
		        ...
			return status.Errorf(codes.Internal, "transport: %v", err)
		}
	} else {
		// Writing headers checks for this condition.
		if s.getState() == streamDone { //when fininsh stream - fininshStrem() , or close stream - closeStream()
	                ...
	                return ...
		}
	}
        ...
        //append header 
	hdr = append(hdr, data[:emptyLen]...)
	data = data[emptyLen:]
	
	//create data frame
	df := &dataFrame{
		streamID:    s.id,
		h:           hdr,
		d:           data,
		onEachWrite: t.setResetPingStrikes,
	}
         ...
        //put data to list tail, and send a chan to wakeup consumer list head     
	return t.controlBuf.put(df) 
}



   

备案编号:赣ICP备15011386号

联系方式:qq:1150662577    邮箱:1150662577@qq.com