package main import ( "context" "database/sql" "fmt" "git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pkg/consul" "git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pkg/db/conn" "git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pkg/db/user" "github.com/hashicorp/consul/api" "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" "google.golang.org/grpc/codes" "google.golang.org/grpc/reflection" "google.golang.org/grpc/status" "log" "net" "net/http" "os" "os/signal" "syscall" "time" "git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pb" "go.uber.org/zap" "golang.org/x/sync/errgroup" "google.golang.org/grpc" ) func main() { // consul地址,此处为本地地址 consulAddress := "localhost:8500" // 服务名称 serviceName := "service-grpcdemo" logger := zap.NewExample() // 实现服务端逻辑 lis, err := net.Listen("tcp", ":5630") if err != nil { panic(err) } s := grpc.NewServer() pb.RegisterGreetServiceServer(s, &Server{logger: logger}) reflection.Register(s) logger.Info("grpc server start") // 启动相关服务 var gg errgroup.Group gg.Go(func() error { return conn.InitMySQLConn(&conn.MySQLConn{ Host: "localhost", Port: 3306, Username: "test1", Password: "test1", DBName: "grpcstudy", }) }) gg.Go(func() error { logger.Debug("grpc server start") go func() { if err := s.Serve(lis); err != nil { logger.Fatal("grpc server start failed", zap.Error(err)) } }() return nil }) gg.Go(func() error { if err := consul.InitConsulClient(consulAddress); err != nil { return err } logger.Debug("consul client start") return nil }) if err = gg.Wait(); err != nil { logger.Fatal("server start failed", zap.Error(err)) return } // 注册服务 // 健康检查 go func() { mux := http.NewServeMux() mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("ok")) }) err = http.ListenAndServe(":8081", mux) if err != nil { logger.Fatal("health check start failed", zap.Error(err)) } }() check := &api.AgentServiceCheck{ Interval: "3s", Timeout: "2s", DeregisterCriticalServiceAfter: "60s", HTTP: "http://127.0.0.1:8081/health", Method: "GET", } err = consul.GetConsulClient().Agent().ServiceRegister(&api.AgentServiceRegistration{ Name: serviceName, Port: 5630, Check: check, ID: "1001", Tags: []string{"grpc.port=5630"}, }) if err != nil { logger.Fatal("register service failed", zap.Error(err)) return } // 优雅关闭 ch := make(chan os.Signal, 1) // 监听信号 signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL, syscall.SIGHUP, syscall.SIGQUIT) // 阻塞等待信号 sig := <-ch logger.Info("receive signal", zap.Any("signal", sig)) // 关闭服务 s.GracefulStop() logger.Debug("server stop") } type Server struct { pb.UnimplementedGreetServiceServer logger *zap.Logger } func (s *Server) Greet(ctx context.Context, req *pb.GreetRequest) (resp *pb.GreetResponse, err error) { t0 := time.Now() defer func() { log.Println("Greet cost:", time.Since(t0)) }() resp = &pb.GreetResponse{ Result: "hello, " + req.Greeting + "!", } return } func (s *Server) GreetManyTimes(request *pb.GreetRequest, stream pb.GreetService_GreetManyTimesServer) (err error) { for i := 0; i < 10; i++ { err = stream.Send(&pb.GreetResponse{ Result: "hello, " + request.Greeting + "!", }) if err != nil { return } time.Sleep(time.Millisecond * 1) } return } // CreateUser 创建用户 func (s *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (resp *pb.CreateUserResponse, err error) { resp = &pb.CreateUserResponse{} // 获取密码hash值 passwordHash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 12) if err != nil { err = status.Error(codes.Internal, "password hash failed:"+err.Error()) return } queries := user.New(conn.GetMySQLConn()) result, err := queries.CreateUser(ctx, user.CreateUserParams{ Username: req.Username, Mobile: req.Mobile, Email: req.Email, Password: passwordHash, }) if err != nil { err = status.Error(codes.Internal, "create user failed:"+err.Error()) return } // 获取插入的id id, err := result.LastInsertId() if err != nil { err = status.Error(codes.Internal, "create user failed") return } resp.Id = id return } // GetUser 获取用户 func (s *Server) GetUser(ctx context.Context, req *pb.GetUserRequest) (resp *pb.GetUserResponse, err error) { resp = &pb.GetUserResponse{} defer func() { if err1 := recover(); err1 != nil { s.logger.Fatal("get user failed", zap.Any("err", err1)) err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", err1)) } if err != nil { s.logger.Error("get user failed", zap.Error(err)) } else { s.logger.Info("get user success") } }() queries := user.New(conn.GetMySQLConn()) result, err := queries.GetUserById(ctx, uint32(req.Id)) if err != nil { if !errors.Is(err, sql.ErrNoRows) { err = status.Error(codes.Internal, "get user failed") return } err = status.Error(codes.NotFound, "user not found") return } resp.Id = int64(result.ID) resp.Username = result.Username resp.Mobile = result.Mobile resp.Email = result.Email return } // CheckPassword 校验密码 func (s *Server) CheckPassword(ctx context.Context, req *pb.CheckPasswordRequest) (resp *pb.CheckPasswordResponse, err error) { resp = &pb.CheckPasswordResponse{} defer func() { if err1 := recover(); err1 != nil { s.logger.Fatal("check password failed", zap.Any("err", err1)) err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", err1)) } if err != nil { s.logger.Error("check password failed", zap.Error(err)) } else { s.logger.Info("check password success") } }() queries := user.New(conn.GetMySQLConn()) // 查询用户信息 result, err := queries.GetUserByUsername(ctx, req.Username) if err != nil { if !errors.Is(err, sql.ErrNoRows) { err = status.Error(codes.Internal, "check password failed:"+err.Error()) return } err = status.Error(codes.NotFound, "user not found") return } // 校验密码 if err = bcrypt.CompareHashAndPassword(result.Password, []byte(req.Password)); err != nil { err = status.Error(codes.Unauthenticated, "password incorrect") return } resp.Ok = true return }