grpcexercises/consuldemo/server/main.go

324 lines
6.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"context"
"database/sql"
"fmt"
"git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pkg/consul"
"git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pkg/db/conn"
"git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pkg/db/user"
"github.com/hashicorp/consul/api"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/reflection"
"google.golang.org/grpc/status"
"log"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"syscall"
"time"
"git.gqnotes.com/guoqiang/grpcexercises/consuldemo/pb"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
)
// allocDebug allocates a 1 MiB byte slice and discards it immediately.
// It exists only to generate heap activity for the pprof endpoint; the
// slice is never read or retained.
func allocDebug() {
	const oneMiB = 1 << 20
	_ = make([]byte, oneMiB)
}
// main starts the demo gRPC service and registers it in Consul.
//
// Startup sequence:
//  1. listen on :5630 for gRPC and start a pprof HTTP server on :6060;
//  2. concurrently open the MySQL connection, start serving gRPC, and create
//     the Consul client (any error aborts the process);
//  3. serve /health on :8081 and register the service in Consul with an HTTP
//     check against that endpoint;
//  4. block until SIGTERM/SIGINT/SIGHUP/SIGQUIT, then deregister from Consul
//     and stop the gRPC server gracefully.
func main() {
	const (
		// consulAddress is the Consul agent address (a local agent here).
		consulAddress = "localhost:8500"
		// serviceName is the name this service is registered under in Consul.
		serviceName = "service-grpcdemo"
		// serviceID identifies this instance in Consul; it is needed again to
		// deregister on shutdown.
		serviceID  = "1001"
		grpcPort   = 5630
		pprofAddr  = ":6060"
		healthAddr = ":8081"
		healthURL  = "http://127.0.0.1:8081/health"
	)
	logger := zap.NewExample()

	// Server-side gRPC setup.
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", grpcPort))
	if err != nil {
		panic(err)
	}
	s := grpc.NewServer()
	pb.RegisterGreetServiceServer(s, &Server{logger: logger})
	reflection.Register(s)
	logger.Info("grpc server start")

	// The blank net/http/pprof import registers its handlers on
	// http.DefaultServeMux, which a nil handler serves. The error variable is
	// local to this goroutine: assigning main's err from here would race with
	// main's own use of it.
	go func() {
		if err := http.ListenAndServe(pprofAddr, nil); err != nil {
			logger.Fatal("pprof start failed", zap.Error(err))
		}
	}()

	// Test code: allocate 1 MiB every second so heap activity shows up in pprof.
	go func() {
		for {
			allocDebug()
			time.Sleep(time.Second)
		}
	}()

	// Start dependent services concurrently; the first error aborts startup.
	var gg errgroup.Group
	gg.Go(func() error {
		return conn.InitMySQLConn(&conn.MySQLConn{
			Host:     "localhost",
			Port:     3306,
			Username: "test1",
			Password: "test1",
			DBName:   "grpcstudy",
		})
	})
	gg.Go(func() error {
		logger.Debug("grpc server start")
		// Serve blocks until the server is stopped, so it gets its own
		// goroutine instead of holding up gg.Wait.
		go func() {
			if err := s.Serve(lis); err != nil {
				logger.Fatal("grpc server start failed", zap.Error(err))
			}
		}()
		return nil
	})
	gg.Go(func() error {
		if err := consul.InitConsulClient(consulAddress); err != nil {
			return err
		}
		logger.Debug("consul client start")
		return nil
	})
	if err = gg.Wait(); err != nil {
		// Fatal exits the process; nothing after it runs.
		logger.Fatal("server start failed", zap.Error(err))
	}

	// Health-check endpoint polled by Consul. As above, err is goroutine-local.
	go func() {
		mux := http.NewServeMux()
		mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
			_, _ = w.Write([]byte("ok"))
		})
		if err := http.ListenAndServe(healthAddr, mux); err != nil {
			logger.Fatal("health check start failed", zap.Error(err))
		}
	}()

	// Register the service together with its HTTP health check.
	check := &api.AgentServiceCheck{
		Interval:                       "3s",
		Timeout:                        "2s",
		DeregisterCriticalServiceAfter: "60s",
		HTTP:                           healthURL,
		Method:                         "GET",
	}
	err = consul.GetConsulClient().Agent().ServiceRegister(&api.AgentServiceRegistration{
		Name:  serviceName,
		Port:  grpcPort,
		Check: check,
		ID:    serviceID,
		Tags:  []string{fmt.Sprintf("grpc.port=%d", grpcPort)},
	})
	if err != nil {
		logger.Fatal("register service failed", zap.Error(err))
	}

	// Graceful shutdown. SIGKILL is deliberately not listed: it cannot be
	// caught, so passing it to signal.Notify has no effect.
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP, syscall.SIGQUIT)
	sig := <-ch
	logger.Info("receive signal", zap.Any("signal", sig))

	// Deregister first so Consul stops routing clients here while in-flight
	// RPCs drain; a failure is logged but does not block the shutdown.
	if err := consul.GetConsulClient().Agent().ServiceDeregister(serviceID); err != nil {
		logger.Error("deregister service failed", zap.Error(err))
	}
	s.GracefulStop()
	logger.Debug("server stop")
	_ = logger.Sync() // best-effort flush; nothing useful to do on failure at exit
}
type Server struct {
pb.UnimplementedGreetServiceServer
logger *zap.Logger
}
// Greet answers a unary greeting with "hello, <greeting>!".
// The time spent in the handler is printed through the standard-library logger
// once the response has been built.
func (s *Server) Greet(ctx context.Context, req *pb.GreetRequest) (resp *pb.GreetResponse, err error) {
	start := time.Now()
	defer func() {
		log.Println("Greet cost:", time.Since(start))
	}()

	greeting := "hello, " + req.Greeting + "!"
	return &pb.GreetResponse{Result: greeting}, nil
}
// GreetManyTimes streams the greeting "hello, <greeting>!" to the client ten
// times, pausing one millisecond after each message.
//
// The pause watches the stream's context. If the client cancels or the
// deadline expires, the handler returns the context error right away instead
// of sleeping and sending into a stream nobody is reading. A Send failure is
// returned as-is.
func (s *Server) GreetManyTimes(request *pb.GreetRequest, stream pb.GreetService_GreetManyTimesServer) (err error) {
	const (
		sendCount = 10
		sendPause = time.Millisecond
	)

	// Every message is identical, so build it once. It is never modified
	// after being passed to Send.
	msg := &pb.GreetResponse{
		Result: "hello, " + request.Greeting + "!",
	}
	ctx := stream.Context()

	for i := 0; i < sendCount; i++ {
		if err = stream.Send(msg); err != nil {
			return err
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(sendPause):
		}
	}
	return nil
}
// CreateUser creates a user and returns its auto-generated ID.
//
// The password is stored only as a bcrypt hash (cost 12).
//
// Errors returned to the client:
//   - codes.InvalidArgument if the password is empty or longer than bcrypt's
//     72-byte input limit;
//   - codes.Internal if hashing, the insert, or reading the new ID fails.
//
// The underlying cause of an Internal error is logged, not sent to the client.
//
// NOTE(review): a duplicate username/email presumably surfaces as a driver
// error and is reported as Internal; codes.AlreadyExists may be more accurate.
func (s *Server) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (resp *pb.CreateUserResponse, err error) {
	const (
		// bcryptCost is the work factor for new password hashes.
		bcryptCost = 12
		// bcryptMaxPasswordLen is bcrypt's input limit in bytes. Recent
		// golang.org/x/crypto versions reject longer inputs and older ones
		// silently truncate them, so they are refused up front.
		bcryptMaxPasswordLen = 72
	)
	resp = &pb.CreateUserResponse{}

	switch {
	case req.Password == "":
		return resp, status.Error(codes.InvalidArgument, "password is required")
	case len(req.Password) > bcryptMaxPasswordLen:
		return resp, status.Error(codes.InvalidArgument, "password must be at most 72 bytes")
	}

	// Compute the password hash.
	passwordHash, hashErr := bcrypt.GenerateFromPassword([]byte(req.Password), bcryptCost)
	if hashErr != nil {
		s.logger.Error("create user: hash password", zap.Error(hashErr))
		return resp, status.Error(codes.Internal, "password hash failed")
	}

	queries := user.New(conn.GetMySQLConn())
	result, dbErr := queries.CreateUser(ctx, user.CreateUserParams{
		Username: req.Username,
		Mobile:   req.Mobile,
		Email:    req.Email,
		Password: passwordHash,
	})
	if dbErr != nil {
		s.logger.Error("create user: insert", zap.Error(dbErr))
		return resp, status.Error(codes.Internal, "create user failed")
	}

	// Fetch the ID assigned to the inserted row.
	id, idErr := result.LastInsertId()
	if idErr != nil {
		s.logger.Error("create user: read insert id", zap.Error(idErr))
		return resp, status.Error(codes.Internal, "create user failed")
	}
	resp.Id = id
	return resp, nil
}
// GetUser returns the profile (ID, username, mobile, email) of the user with
// the given ID.
//
// Errors returned to the client:
//   - codes.InvalidArgument if req.Id is negative or does not fit in the
//     uint32 the GetUserById query takes (previously such IDs were silently
//     truncated and could return a different user);
//   - codes.NotFound if no user has that ID;
//   - codes.Internal for any other query failure or a recovered panic.
//
// Every outcome is logged, including the underlying cause of Internal errors,
// which is not sent to the client.
func (s *Server) GetUser(ctx context.Context, req *pb.GetUserRequest) (resp *pb.GetUserResponse, err error) {
	// maxUserID is the largest value representable by the query's uint32 ID.
	const maxUserID = 1<<32 - 1

	resp = &pb.GetUserResponse{}
	// cause is the error behind an Internal status, kept for logging only.
	var cause error
	defer func() {
		if r := recover(); r != nil {
			// Error, not Fatal: zap's Fatal calls os.Exit, which would take the
			// whole server down and make this recover pointless.
			s.logger.Error("get user failed", zap.Any("err", r), zap.Stack("stack"))
			err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", r))
		}
		if err != nil {
			s.logger.Error("get user failed", zap.Error(err), zap.NamedError("cause", cause))
		} else {
			s.logger.Info("get user success")
		}
	}()

	// Widen to int64 first so out-of-range values are detected rather than
	// wrapped by the uint32 conversion below.
	id := int64(req.Id)
	if id < 0 || id > maxUserID {
		return resp, status.Error(codes.InvalidArgument, "invalid user id")
	}

	queries := user.New(conn.GetMySQLConn())
	result, qErr := queries.GetUserById(ctx, uint32(id))
	if qErr != nil {
		if errors.Is(qErr, sql.ErrNoRows) {
			return resp, status.Error(codes.NotFound, "user not found")
		}
		cause = qErr
		return resp, status.Error(codes.Internal, "get user failed")
	}

	resp.Id = int64(result.ID)
	resp.Username = result.Username
	resp.Mobile = result.Mobile
	resp.Email = result.Email
	return resp, nil
}
// CheckPassword verifies req.Password against the stored bcrypt hash of the
// user named req.Username. On success resp.Ok is true and err is nil.
//
// Errors returned to the client:
//   - codes.NotFound if no user has that username;
//   - codes.Unauthenticated if the password does not match, or the hash
//     comparison fails for another reason (that reason is logged);
//   - codes.Internal for any other query failure or a recovered panic.
//
// Every outcome is logged, including the underlying cause of failures, which
// is not sent to the client.
//
// NOTE(review): answering NotFound vs Unauthenticated lets callers enumerate
// usernames; consider a single Unauthenticated answer if this endpoint is
// reachable by untrusted clients.
func (s *Server) CheckPassword(ctx context.Context, req *pb.CheckPasswordRequest) (resp *pb.CheckPasswordResponse, err error) {
	resp = &pb.CheckPasswordResponse{}
	// cause is the underlying error behind a failure status, kept for logging only.
	var cause error
	defer func() {
		if r := recover(); r != nil {
			// Error, not Fatal: zap's Fatal calls os.Exit, which would take the
			// whole server down and make this recover pointless.
			s.logger.Error("check password failed", zap.Any("err", r), zap.Stack("stack"))
			err = status.Error(codes.Internal, fmt.Sprintf("panic: %v", r))
		}
		if err != nil {
			s.logger.Error("check password failed", zap.Error(err), zap.NamedError("cause", cause))
		} else {
			s.logger.Info("check password success")
		}
	}()

	queries := user.New(conn.GetMySQLConn())
	// Look up the user, including the stored password hash.
	result, qErr := queries.GetUserByUsername(ctx, req.Username)
	if qErr != nil {
		if errors.Is(qErr, sql.ErrNoRows) {
			return resp, status.Error(codes.NotFound, "user not found")
		}
		cause = qErr
		return resp, status.Error(codes.Internal, "check password failed")
	}

	// Verify the password against the hash.
	if cmpErr := bcrypt.CompareHashAndPassword(result.Password, []byte(req.Password)); cmpErr != nil {
		if !errors.Is(cmpErr, bcrypt.ErrMismatchedHashAndPassword) {
			// Not a plain mismatch (e.g. a malformed stored hash): the client
			// still gets Unauthenticated, but the reason is recorded.
			cause = cmpErr
		}
		return resp, status.Error(codes.Unauthenticated, "password incorrect")
	}
	resp.Ok = true
	return resp, nil
}