diff --git a/alias/alias.go b/alias/alias.go index dc234bf..bb13d97 100644 --- a/alias/alias.go +++ b/alias/alias.go @@ -28,10 +28,13 @@ type Alias struct { SshAgent string `toml:"ssh-agent"` Timeout string `toml:"timeout"` SshConfig string `toml:"config"` + Rpc bool `toml:"rpc"` + RpcAddress string `toml:"rpc-address"` } +// String parses a Alias object to a string representation. func (a Alias) String() string { - return fmt.Sprintf("[verbose: %t, insecure: %t, detach: %t, source: %s, destination: %s, server: %s, key: %s, keep-alive-interval: %s, connection-retries: %d, wait-and-retry: %s, ssh-agent: %s, timeout: %s, config: %s]", + return fmt.Sprintf("[verbose: %t, insecure: %t, detach: %t, source: %s, destination: %s, server: %s, key: %s, keep-alive-interval: %s, connection-retries: %d, wait-and-retry: %s, ssh-agent: %s, timeout: %s, config: %s, rpc: %t, rpc-address: %s]", a.Verbose, a.Insecure, a.Detach, @@ -45,12 +48,14 @@ func (a Alias) String() string { a.SshAgent, a.Timeout, a.SshConfig, + a.Rpc, + a.RpcAddress, ) } // Add persists an tunnel alias to the disk func Add(alias *Alias) error { - mp, err := createDir() + mp, err := fsutils.CreateHomeDir() if err != nil { return err } @@ -174,24 +179,6 @@ func Get(aliasName string) (*Alias, error) { return a, nil } -func createDir() (string, error) { - mp, err := fsutils.Dir() - if err != nil { - return "", err - } - - if _, err := os.Stat(mp); !os.IsNotExist(err) { - return mp, nil - } - - err = os.MkdirAll(mp, os.ModePerm) - if err != nil { - return "", err - } - - return mp, nil -} - //FIXME terrible struct name. Change it. type aliases struct { Aliases map[string]*Alias `toml:"aliases"` diff --git a/alias/alias_test.go b/alias/alias_test.go index a230a91..73c0ce9 100644 --- a/alias/alias_test.go +++ b/alias/alias_test.go @@ -67,6 +67,10 @@ func TestShow(t *testing.T) { expected := string(expectedBytes) output, err := alias.Show(id) + if err != nil { + t.Errorf("error showing alias %s: %v", id, err) + } + if output != expected { t.Errorf("output doesn't match. Failing the test.") } diff --git a/alias/testdata/example.toml b/alias/testdata/example.toml index 7d2f0b9..bba780a 100644 --- a/alias/testdata/example.toml +++ b/alias/testdata/example.toml @@ -12,3 +12,5 @@ wait-and-retry = "3s" ssh-agent = "" timeout = "3s" config = "" +rpc = true +rpc-address = "127.0.0.1:0" diff --git a/alias/testdata/show.alias.fixture b/alias/testdata/show.alias.fixture index a418009..e8cba06 100644 --- a/alias/testdata/show.alias.fixture +++ b/alias/testdata/show.alias.fixture @@ -15,6 +15,8 @@ ssh-agent = "" timeout = "3s" config = "" + rpc = true + rpc-address = "127.0.0.1:0" [aliases.test-env] name = "test-env" type = "local" @@ -31,3 +33,5 @@ ssh-agent = "" timeout = "3s" config = "" + rpc = true + rpc-address = "127.0.0.1:0" diff --git a/alias/testdata/show.alias.test-env.fixture b/alias/testdata/show.alias.test-env.fixture index 5892cc7..6d879b3 100644 --- a/alias/testdata/show.alias.test-env.fixture +++ b/alias/testdata/show.alias.test-env.fixture @@ -13,3 +13,5 @@ wait-and-retry = "3s" ssh-agent = "" timeout = "3s" config = "" +rpc = true +rpc-address = "127.0.0.1:0" diff --git a/alias/testdata/test-env.toml b/alias/testdata/test-env.toml index a449187..1d45152 100644 --- a/alias/testdata/test-env.toml +++ b/alias/testdata/test-env.toml @@ -12,3 +12,5 @@ wait-and-retry = "3s" ssh-agent = "" timeout = "3s" config = "" +rpc = true +rpc-address = "127.0.0.1:0" diff --git a/cmd/root.go b/cmd/root.go index 30c31c2..c6b60ec 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -47,6 +47,10 @@ provide 0 to never give up or a negative number to disable`) cmd.Flags().DurationVarP(&conf.WaitAndRetry, "retry-wait", "w", 3*time.Second, "time to wait before trying to reconnect to ssh server") cmd.Flags().StringVarP(&conf.SshAgent, "ssh-agent", "A", "", "unix socket to communicate with a ssh agent") cmd.Flags().DurationVarP(&conf.Timeout, "timeout", "t", 3*time.Second, "ssh server connection timeout") + cmd.Flags().BoolVarP(&conf.Rpc, "rpc", "", false, "enable the rpc server") + cmd.Flags().StringVarP(&conf.RpcAddress, "rpc-address", "", "127.0.0.1:0", `set the network address of the rpc server. +The default value uses a random free port to listen for requests. +The full address is kept on $HOME/.mole/.`) err := cmd.MarkFlagRequired("server") if err != nil { diff --git a/fsutils/fsutils.go b/fsutils/fsutils.go index 1b14e6a..b5449df 100644 --- a/fsutils/fsutils.go +++ b/fsutils/fsutils.go @@ -17,3 +17,42 @@ func Dir() (string, error) { return mp, nil } + +// CreateHomeDir creates then returns the location where all mole related files +// are persisted, including alias configuration and log files. +func CreateHomeDir() (string, error) { + + home, err := Dir() + if err != nil { + return "", err + } + + if _, err := os.Stat(home); os.IsNotExist(err) { + err := os.MkdirAll(home, 0755) + if err != nil { + return "", err + } + } + + return home, err +} + +// CreateInstanceDir creates and then returns the location where all files +// related to a specific mole instance are persisted. +func CreateInstanceDir(appId string) (string, error) { + home, err := Dir() + if err != nil { + return "", err + } + + d := filepath.Join(home, appId) + + if _, err := os.Stat(d); os.IsNotExist(err) { + err := os.MkdirAll(d, 0755) + if err != nil { + return "", err + } + } + + return d, nil +} diff --git a/go.mod b/go.mod index 87e3d81..d2b0e0c 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/prometheus/common v0.10.0 github.com/sevlyar/go-daemon v0.1.5 github.com/sirupsen/logrus v1.6.0 + github.com/sourcegraph/jsonrpc2 v0.0.0-20200429184054-15c2290dcb37 github.com/spf13/cobra v0.0.5 github.com/spf13/pflag v1.0.5 github.com/spf13/viper v1.3.2 diff --git a/go.sum b/go.sum index 1f0af7d..c3f0318 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,7 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/google/go-dap v0.2.0 h1:whjIGQRumwbR40qRU7CEKuFLmePUUc2s4Nt9DoXXxWk= github.com/google/go-dap v0.2.0/go.mod h1:5q8aYQFnHOAZEMP+6vmq25HKYAEwE+LF5yh7JKrrhSQ= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= @@ -98,6 +99,8 @@ github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/sourcegraph/jsonrpc2 v0.0.0-20200429184054-15c2290dcb37 h1:marA1XQDC7N870zmSFIoHZpIUduK80USeY0Rkuflgp4= +github.com/sourcegraph/jsonrpc2 v0.0.0-20200429184054-15c2290dcb37/go.mod h1:ZafdZgk/axhT1cvZAPOhw+95nz2I/Ra5qMlU4gTRwIo= github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cobra v0.0.0-20170417170307-b6cb39589372/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= diff --git a/mole/app.go b/mole/app.go index e18c338..3a076a4 100644 --- a/mole/app.go +++ b/mole/app.go @@ -8,7 +8,6 @@ import ( "strconv" "github.com/davrodpin/mole/fsutils" - "github.com/gofrs/uuid" "github.com/hpcloud/tail" ) @@ -34,26 +33,13 @@ type DetachedInstance struct { // NewDetachedInstance returns a new instance of DetachedInstance, making sure // the application instance directory is created. func NewDetachedInstance(id string) (*DetachedInstance, error) { - instanceDir, err := fsutils.Dir() - if err != nil { - return nil, err - } - if id == "" { - u, err := uuid.NewV4() - if err != nil { - return nil, fmt.Errorf("could not auto generate app instance id: %v", err) - } - id = u.String()[:8] + return nil, fmt.Errorf("application instance id can't be empty") } - home := filepath.Join(instanceDir, id) - - if _, err := os.Stat(home); os.IsNotExist(err) { - err := os.MkdirAll(home, 0755) - if err != nil { - return nil, err - } + _, err := fsutils.CreateInstanceDir(id) + if err != nil { + return nil, err } pfl, err := GetPidFileLocation(id) diff --git a/mole/app_test.go b/mole/app_test.go index e638a15..16c55c4 100644 --- a/mole/app_test.go +++ b/mole/app_test.go @@ -43,18 +43,6 @@ func TestDetachedInstanceFileLocations(t *testing.T) { } -func TestDetachedInstanceGeneratedId(t *testing.T) { - - di, err := mole.NewDetachedInstance("") - if err != nil { - t.Errorf("error creating a new detached instance: %v", err) - } - - if di.Id == "" { - t.Errorf("detached instance id is empty") - } -} - func TestDetachedInstanceAlreadyRunning(t *testing.T) { id := "TestDetachedInstanceAlreadyRunning" diff --git a/mole/mole.go b/mole/mole.go index d7cea77..db4bb68 100644 --- a/mole/mole.go +++ b/mole/mole.go @@ -2,15 +2,21 @@ package mole import ( "fmt" + "io/ioutil" "os" + "path/filepath" + "strconv" "syscall" "time" "github.com/davrodpin/mole/alias" + "github.com/davrodpin/mole/fsutils" + "github.com/davrodpin/mole/rpc" "github.com/davrodpin/mole/tunnel" - "github.com/sevlyar/go-daemon" "github.com/awnumar/memguard" + "github.com/gofrs/uuid" + "github.com/sevlyar/go-daemon" log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh/terminal" ) @@ -31,6 +37,8 @@ type Configuration struct { SshAgent string Timeout time.Duration SshConfig string + Rpc bool + RpcAddress string } // ParseAlias translates a Configuration object to an Alias object. @@ -51,6 +59,8 @@ func (c Configuration) ParseAlias(name string) *alias.Alias { SshAgent: c.SshAgent, Timeout: c.Timeout.String(), SshConfig: c.SshConfig, + Rpc: c.Rpc, + RpcAddress: c.RpcAddress, } } @@ -74,6 +84,14 @@ func (c *Client) Start() error { if c.Conf.Detach { var err error + if c.Conf.Id == "" { + u, err := uuid.NewV4() + if err != nil { + return fmt.Errorf("could not auto generate app instance id: %v", err) + } + c.Conf.Id = u.String()[:8] + } + ic, err := NewDetachedInstance(c.Conf.Id) if err != nil { log.WithError(err).Errorf("error while creating directory to store mole instance related files") @@ -90,10 +108,44 @@ func (c *Client) Start() error { } } + if c.Conf.Id == "" { + c.Conf.Id = strconv.Itoa(os.Getpid()) + } + if c.Conf.Verbose { log.SetLevel(log.DebugLevel) } + d, err := fsutils.CreateInstanceDir(c.Conf.Id) + if err != nil { + log.WithFields(log.Fields{ + "id": c.Conf.Id, + }).WithError(err).Error("error creating directory for mole instance") + + return err + } + + log.Infof(">>> %t %s", c.Conf.Rpc, c.Conf.RpcAddress) + if c.Conf.Rpc { + addr, err := rpc.Start(c.Conf.RpcAddress) + if err != nil { + return err + } + + rd := filepath.Join(d, "rpc") + + err = ioutil.WriteFile(rd, []byte(addr.String()), 0644) + if err != nil { + log.WithFields(log.Fields{ + "id": c.Conf.Id, + }).WithError(err).Error("error creating file with rpc address") + + return err + } + + log.Infof("rpc server address saved on %s", rd) + } + s, err := tunnel.NewServer(c.Conf.Server.User, c.Conf.Server.Address(), c.Conf.Key, c.Conf.SshAgent, c.Conf.SshConfig) if err != nil { log.Errorf("error processing server options: %v\n", err) @@ -142,7 +194,8 @@ func (c *Client) Start() error { //TODO need to find a way to require the attributes below to be always set // since they are not optional (functionality will break if they are not // set and CLI parsing is the one setting the default values). - // That could be done by make them required in the constructor's signature + // That could be done by make them required in the constructor's signature or + // by creating a configuration struct for a tunnel object. t.ConnectionRetries = c.Conf.ConnectionRetries t.WaitAndRetry = c.Conf.WaitAndRetry t.KeepAliveInterval = c.Conf.KeepAliveInterval @@ -267,6 +320,10 @@ func (c *Configuration) Merge(al *alias.Alias, givenFlags []string) error { c.SshConfig = al.SshConfig + c.Rpc = al.Rpc + + c.RpcAddress = al.RpcAddress + return nil } diff --git a/mole/mole_test.go b/mole/mole_test.go index 5f3f3ec..6072656 100644 --- a/mole/mole_test.go +++ b/mole/mole_test.go @@ -8,8 +8,6 @@ import ( ) func TestAliasMerge(t *testing.T) { - //keepAliveInterval, _ := time.ParseDuration("5s") - tests := []struct { alias *alias.Alias givenFlags []string diff --git a/rpc/doc.go b/rpc/doc.go new file mode 100644 index 0000000..4ab5bc1 --- /dev/null +++ b/rpc/doc.go @@ -0,0 +1,10 @@ +/* +Package rpc implements a JSON-RPC 2.0 server that is used to call registered +procedures. + +For more information about JSON-RPC 2.0, please visit: + +https://www.jsonrpc.org/specification +*/ + +package rpc diff --git a/rpc/rpc.go b/rpc/rpc.go new file mode 100644 index 0000000..a1510e5 --- /dev/null +++ b/rpc/rpc.go @@ -0,0 +1,158 @@ +package rpc + +import ( + "context" + "fmt" + "net" + "sync" + + "encoding/json" + + log "github.com/sirupsen/logrus" + "github.com/sourcegraph/jsonrpc2" +) + +var registeredMethods = sync.Map{} +var listener net.Listener + +const ( + DefaultAddress = "127.0.0.1:0" +) + +// Start initializes the jsonrpc 2.0 server which will be waiting for +// connections on a random port. +func Start(address string) (net.Addr, error) { + var err error + + if address == "" { + address = DefaultAddress + } + + listener, err = net.Listen("tcp", address) + if err != nil { + return nil, err + } + + ctx := context.Background() + h := &Handler{} + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.WithError(err).Warnf("error establishing connection with rpc client.") + } + stream := jsonrpc2.NewBufferedStream(conn, jsonrpc2.VarintObjectCodec{}) + jsonrpc2.NewConn(ctx, stream, h) + } + }() + + return listener.Addr(), nil +} + +// Handler handles JSON-RPC requests and notifications. +type Handler struct{} + +// Handle manages JSON-RPC requests and notifications, executing the requested +// method and responding back to the client when needed. +func (h *Handler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + log.WithFields(log.Fields{ + "notification": req.Notif, + "method": req.Method, + "id": req.ID, + }).Info("rpc request received") + + if _, ok := registeredMethods.Load(req.Method); !ok { + log.Errorf("rpc request method %s not supported", req.Method) + + if !req.Notif { + resp := &jsonrpc2.Response{} + resp.SetResult(jsonrpc2.Error{ + Code: jsonrpc2.CodeMethodNotFound, + Message: fmt.Sprintf("method %s not found", req.Method), + }) + + sendResponse(ctx, conn, req, resp) + } + + return + } + + m, _ := registeredMethods.Load(req.Method) + rm, err := m.(Method)() + if err != nil { + log.WithFields(log.Fields{ + "notification": req.Notif, + "method": req.Method, + "id": req.ID, + }).WithError(err).Warn("error executing rpc method.") + + if !req.Notif { + resp := &jsonrpc2.Response{} + resp.SetResult(jsonrpc2.Error{ + Code: jsonrpc2.CodeInternalError, + Message: fmt.Sprintf("error executing rpc method %s", req.Method), + }) + + sendResponse(ctx, conn, req, resp) + + return + } + } + + if !req.Notif { + resp := &jsonrpc2.Response{ID: req.ID, Result: &rm} + + sendResponse(ctx, conn, req, resp) + } +} + +// Register adds a new method that can be called remotely. +func Register(name string, method Method) { + registeredMethods.Store(name, method) +} + +// Method represents a procedure that can be called remotely. +type Method func() (json.RawMessage, error) + +// Call initiates a JSON-RPC call using the specified method and waits for the +// response. +func Call(ctx context.Context, method string) (map[string]interface{}, error) { + tc, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, err + } + + stream := jsonrpc2.NewBufferedStream(tc, jsonrpc2.VarintObjectCodec{}) + h := &Handler{} + conn := jsonrpc2.NewConn(ctx, stream, h) + + var r map[string]interface{} + err = conn.Call(ctx, method, nil, &r) + if err != nil { + return nil, err + } + + return r, nil +} + +func sendResponse(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request, resp *jsonrpc2.Response) error { + if err := conn.SendResponse(ctx, resp); err != nil { + log.WithFields(log.Fields{ + "notification": req.Notif, + "method": req.Method, + "id": req.ID, + }).WithError(err).Warn("error sending rpc response.") + + return err + } + + log.WithFields(log.Fields{ + "notification": req.Notif, + "method": req.Method, + "id": req.ID, + }).Info("rpc response sent.") + + return nil + +} diff --git a/rpc/rpc_test.go b/rpc/rpc_test.go new file mode 100644 index 0000000..47c9291 --- /dev/null +++ b/rpc/rpc_test.go @@ -0,0 +1,86 @@ +package rpc_test + +import ( + "context" + "encoding/json" + "fmt" + "os" + "testing" + + "github.com/davrodpin/mole/rpc" +) + +func TestHandler(t *testing.T) { + method := "test" + expectedResponse := `{"message":"test"}` + + rpc.Register(method, func() (json.RawMessage, error) { + return json.RawMessage(expectedResponse), nil + }) + + response, err := rpc.Call(context.Background(), method) + if err != nil { + t.Errorf("error while calling remote procedure: %v", err) + } + + json, err := json.Marshal(response) + if err != nil { + t.Errorf("error while parsing response to string: response: %s, err: %v", response, err) + } + + if expectedResponse != string(json) { + t.Errorf("unexpected response for remote procedure call: want: %s, got: %s", expectedResponse, string(json)) + } +} + +func TestMethodNotRegistered(t *testing.T) { + method := "methodnotregistered" + expectedResponse := fmt.Sprintf(`{"code":-32601,"data":null,"message":"method %s not found"}`, method) + + response, err := rpc.Call(context.Background(), method) + if err != nil { + t.Errorf("error while calling remote procedure: %v", err) + } + + json, err := json.Marshal(response) + if err != nil { + t.Errorf("error while parsing response to string: response: %s, err: %v", response, err) + } + + if expectedResponse != string(json) { + t.Errorf("unexpected response for remote procedure call: want: %s, got: %s", expectedResponse, string(json)) + } +} + +func TestMethodWithError(t *testing.T) { + method := "testwitherror" + expectedResponse := fmt.Sprintf(`{"code":-32603,"data":null,"message":"error executing rpc method %s"}`, method) + + rpc.Register(method, func() (json.RawMessage, error) { + return nil, fmt.Errorf("error") + }) + + response, err := rpc.Call(context.Background(), method) + if err != nil { + t.Errorf("error while calling remote procedure: %v", err) + } + + json, err := json.Marshal(response) + if err != nil { + t.Errorf("error while parsing response to string: response: %s, err: %v", response, err) + } + + if expectedResponse != string(json) { + t.Errorf("unexpected response for remote procedure call: want: %s, got: %s", expectedResponse, string(json)) + } +} + +func TestMain(m *testing.M) { + _, err := rpc.Start() + if err != nil { + fmt.Printf("error initializing rpc server: %v", err) + os.Exit(1) + } + + os.Exit(m.Run()) +}