Skip to content

Commit

Permalink
fix: check address when GetCallerAddress to prevent panic
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Feb 10, 2025
1 parent 047444c commit 6241434
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
5 changes: 2 additions & 3 deletions pkg/utils/kitexutil/kitexutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ func GetMethod(ctx context.Context) (string, bool) {

// GetCallerHandlerMethod is used to get the method name of caller.
// Only the caller is a Kitex server will have this method information by default, or you can set K_METHOD into context.Context then kitex will get it.
// Return false if failed to get the information.
func GetCallerHandlerMethod(ctx context.Context) (string, bool) {
defer func() { recover() }()

Expand Down Expand Up @@ -92,7 +91,7 @@ func GetCallerAddr(ctx context.Context) (net.Addr, bool) {
defer func() { recover() }()

ri := rpcinfo.GetRPCInfo(ctx)
if ri == nil {
if ri == nil || ri.From() == nil || ri.From().Address() == nil {
return nil, false
}
return ri.From().Address(), true
Expand All @@ -104,7 +103,7 @@ func GetCallerIP(ctx context.Context) (string, bool) {
defer func() { recover() }()

ri := rpcinfo.GetRPCInfo(ctx)
if ri == nil {
if ri == nil || ri.From() == nil || ri.From().Address() == nil {
return "", false
}
addrStr := ri.From().Address().String()
Expand Down
45 changes: 24 additions & 21 deletions pkg/utils/kitexutil/kitexutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ import (
)

var (
testRi rpcinfo.RPCInfo
testCtx context.Context
panicCtx context.Context
caller = "kitexutil.from.service"
callee = "kitexutil.to.service"
idlServiceName = "MockService"
fromAddr = utils.NewNetAddr("test", "127.0.0.1:12345")
fromMethod = "from_method"
method = "method"
tp = transport.TTHeader
testRi rpcinfo.RPCInfo
testCtx context.Context
panicCtx context.Context
testCaller = "kitexutil.from.service"
testCallee = "kitexutil.to.service"
testIdlServiceName = "MockService"
testFromAddr = utils.NewNetAddr("test", "127.0.0.1:12345")
testFromMethod = "from_method"
testMethod = "testMethod"
testTp = transport.TTHeader
)

func TestMain(m *testing.M) {
testRi = buildRPCInfo()
testRi = buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, testFromAddr, testTp)

testCtx = context.Background()
testCtx = rpcinfo.NewCtxWithRPCInfo(testCtx, testRi)
Expand All @@ -63,7 +63,7 @@ func TestGetCaller(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: caller, want1: true},
{name: "Success", args: args{testCtx}, want: testCaller, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand All @@ -90,7 +90,7 @@ func TestGetCallee(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: callee, want1: true},
{name: "Success", args: args{testCtx}, want: testCallee, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand All @@ -108,6 +108,8 @@ func TestGetCallee(t *testing.T) {
}

func TestGetCallerAddr(t *testing.T) {
riWithoutAddress := buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, nil, testTp)
ctxWithoutAddress := rpcinfo.NewCtxWithRPCInfo(context.Background(), riWithoutAddress)
type args struct {
ctx context.Context
}
Expand All @@ -117,8 +119,9 @@ func TestGetCallerAddr(t *testing.T) {
want net.Addr
want1 bool
}{
{name: "Success", args: args{testCtx}, want: fromAddr, want1: true},
{name: "Failure", args: args{context.Background()}, want: nil, want1: false},
{name: "Success", args: args{testCtx}, want: testFromAddr, want1: true},
{name: "Failure: nil rpcinfo", args: args{context.Background()}, want: nil, want1: false},
{name: "Failure: nil address", args: args{ctxWithoutAddress}, want: nil, want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: nil, want1: false},
}
for _, tt := range tests {
Expand All @@ -139,7 +142,7 @@ func TestGetCallerIP(t *testing.T) {
test.Assert(t, ok)
test.Assert(t, ip == "127.0.0.1", ip)

ri := buildRPCInfo()
ri := buildRPCInfo(testCaller, testFromMethod, testCallee, testMethod, testIdlServiceName, testFromAddr, testTp)
rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", "127.0.0.1"))
ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri))
test.Assert(t, ok)
Expand Down Expand Up @@ -169,7 +172,7 @@ func TestGetMethod(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: method, want1: true},
{name: "Success", args: args{testCtx}, want: testMethod, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
}
for _, tt := range tests {
Expand All @@ -195,7 +198,7 @@ func TestGetCallerHandlerMethod(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: fromMethod, want1: true},
{name: "Success", args: args{testCtx}, want: testFromMethod, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand All @@ -222,7 +225,7 @@ func TestGetIDLServiceName(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: idlServiceName, want1: true},
{name: "Success", args: args{testCtx}, want: testIdlServiceName, want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand Down Expand Up @@ -275,7 +278,7 @@ func TestGetCtxTransportProtocol(t *testing.T) {
want string
want1 bool
}{
{name: "Success", args: args{testCtx}, want: tp.String(), want1: true},
{name: "Success", args: args{testCtx}, want: testTp.String(), want1: true},
{name: "Failure", args: args{context.Background()}, want: "", want1: false},
{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
}
Expand Down Expand Up @@ -340,7 +343,7 @@ func TestGetRealResponse(t *testing.T) {
}
}

func buildRPCInfo() rpcinfo.RPCInfo {
func buildRPCInfo(caller, fromMethod, callee, method, idlServiceName string, fromAddr net.Addr, tp transport.Protocol) rpcinfo.RPCInfo {
from := rpcinfo.NewEndpointInfo(caller, fromMethod, fromAddr, nil)
to := rpcinfo.NewEndpointInfo(callee, method, nil, nil)
ink := rpcinfo.NewInvocation(idlServiceName, method)
Expand Down

0 comments on commit 6241434

Please # to comment.