Skip to content

Commit a0d2cc9

Browse files
committed
Add wush cp for single file transfers without rsync
1 parent 32f2838 commit a0d2cc9

File tree

9 files changed

+393
-13
lines changed

9 files changed

+393
-13
lines changed

cmd/wush/cp.go

+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
package main
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"io"
7+
"log/slog"
8+
"net/http"
9+
"net/http/httputil"
10+
"net/netip"
11+
"os"
12+
"path/filepath"
13+
14+
"github.com/charmbracelet/huh"
15+
"github.com/coder/serpent"
16+
"github.com/coder/wush/cliui"
17+
"github.com/coder/wush/overlay"
18+
"github.com/coder/wush/tsserver"
19+
"github.com/schollz/progressbar/v3"
20+
"tailscale.com/net/netns"
21+
)
22+
23+
func cpCmd() *serpent.Command {
24+
var (
25+
authID string
26+
waitP2P bool
27+
stunAddrOverride string
28+
stunAddrOverrideIP netip.Addr
29+
)
30+
return &serpent.Command{
31+
Use: "cp <file>",
32+
Short: "Transfer files.",
33+
Long: "Transfer files to a " + cliui.Code("wush") + " peer. ",
34+
Middleware: serpent.Chain(
35+
serpent.RequireNArgs(1),
36+
),
37+
Handler: func(inv *serpent.Invocation) error {
38+
ctx := inv.Context()
39+
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
40+
logF := func(str string, args ...any) {
41+
fmt.Fprintf(inv.Stderr, str+"\n", args...)
42+
}
43+
44+
if authID == "" {
45+
err := huh.NewInput().
46+
Title("Enter your Auth ID:").
47+
Value(&authID).
48+
Run()
49+
if err != nil {
50+
return fmt.Errorf("get auth id: %w", err)
51+
}
52+
}
53+
54+
dm, err := tsserver.DERPMapTailscale(ctx)
55+
if err != nil {
56+
return err
57+
}
58+
59+
if stunAddrOverride != "" {
60+
stunAddrOverrideIP, err = netip.ParseAddr(stunAddrOverride)
61+
if err != nil {
62+
return fmt.Errorf("parse stun addr override: %w", err)
63+
}
64+
}
65+
66+
send := overlay.NewSendOverlay(logger, dm)
67+
send.STUNIPOverride = stunAddrOverrideIP
68+
69+
err = send.Auth.Parse(authID)
70+
if err != nil {
71+
return fmt.Errorf("parse auth key: %w", err)
72+
}
73+
74+
logF("Auth information:")
75+
stunStr := send.Auth.ReceiverStunAddr.String()
76+
if !send.Auth.ReceiverStunAddr.IsValid() {
77+
stunStr = "Disabled"
78+
}
79+
logF("\t> Server overlay STUN address: %s", cliui.Code(stunStr))
80+
derpStr := "Disabled"
81+
if send.Auth.ReceiverDERPRegionID > 0 {
82+
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
83+
}
84+
logF("\t> Server overlay DERP home: %s", cliui.Code(derpStr))
85+
logF("\t> Server overlay public key: %s", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
86+
logF("\t> Server overlay auth key: %s", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
87+
88+
s, err := tsserver.NewServer(ctx, logger, send)
89+
if err != nil {
90+
return err
91+
}
92+
93+
if send.Auth.ReceiverDERPRegionID != 0 {
94+
go send.ListenOverlayDERP(ctx)
95+
} else if send.Auth.ReceiverStunAddr.IsValid() {
96+
go send.ListenOverlaySTUN(ctx)
97+
} else {
98+
return errors.New("auth key provided neither DERP nor STUN")
99+
}
100+
101+
go s.ListenAndServe(ctx)
102+
netns.SetDialerOverride(s.Dialer())
103+
ts, err := newTSNet("send")
104+
if err != nil {
105+
return err
106+
}
107+
ts.Logf = func(string, ...any) {}
108+
ts.UserLogf = func(string, ...any) {}
109+
110+
logF("Bringing Wireguard up..")
111+
ts.Up(ctx)
112+
logF("Wireguard is ready!")
113+
114+
lc, err := ts.LocalClient()
115+
if err != nil {
116+
return err
117+
}
118+
119+
ip, err := waitUntilHasPeerHasIP(ctx, logF, lc)
120+
if err != nil {
121+
return err
122+
}
123+
124+
if waitP2P {
125+
err := waitUntilHasP2P(ctx, logF, lc)
126+
if err != nil {
127+
return err
128+
}
129+
}
130+
131+
fiPath := inv.Args[0]
132+
fiName := filepath.Base(inv.Args[0])
133+
134+
fi, err := os.Open(fiPath)
135+
if err != nil {
136+
return err
137+
}
138+
defer fi.Close()
139+
140+
fiStat, err := fi.Stat()
141+
if err != nil {
142+
return err
143+
}
144+
145+
bar := progressbar.DefaultBytes(
146+
fiStat.Size(),
147+
fmt.Sprintf("Uploading %q", fiPath),
148+
)
149+
barReader := progressbar.NewReader(fi, bar)
150+
151+
hc := ts.HTTPClient()
152+
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s:4444/%s", ip.String(), fiName), &barReader)
153+
if err != nil {
154+
return err
155+
}
156+
req.ContentLength = fiStat.Size()
157+
158+
res, err := hc.Do(req)
159+
if err != nil {
160+
return err
161+
}
162+
defer res.Body.Close()
163+
164+
out, err := httputil.DumpResponse(res, true)
165+
if err != nil {
166+
return err
167+
}
168+
bar.Close()
169+
fmt.Println(string(out))
170+
171+
return nil
172+
},
173+
Options: []serpent.Option{
174+
{
175+
Flag: "auth-id",
176+
Env: "WUSH_AUTH_ID",
177+
Description: "The auth id returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
178+
Default: "",
179+
Value: serpent.StringOf(&authID),
180+
},
181+
{
182+
Flag: "stun-ip-override",
183+
Default: "",
184+
Value: serpent.StringOf(&stunAddrOverride),
185+
},
186+
{
187+
Flag: "wait-p2p",
188+
Description: "Waits for the connection to be p2p.",
189+
Default: "false",
190+
Value: serpent.BoolOf(&waitP2P),
191+
},
192+
},
193+
}
194+
}

cmd/wush/main.go

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func main() {
4646
sshCmd(),
4747
receiveCmd(),
4848
rsyncCmd(),
49+
cpCmd(),
4950
},
5051
Options: []serpent.Option{
5152
{

cmd/wush/receive.go

+44-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ import (
44
"fmt"
55
"io"
66
"log/slog"
7+
"net/http"
78
"os"
9+
"strings"
810
"time"
911

1012
"github.com/prometheus/client_golang/prometheus"
13+
"github.com/schollz/progressbar/v3"
1114
"github.com/spf13/afero"
1215
"golang.org/x/xerrors"
1316
"tailscale.com/ipn/store"
@@ -31,6 +34,7 @@ func receiveCmd() *serpent.Command {
3134
return &serpent.Command{
3235
Use: "receive",
3336
Aliases: []string{"host"},
37+
Short: "Run the wush server.",
3438
Long: "Runs the wush server. Allows other wush CLIs to connect to this computer.",
3539
Handler: func(inv *serpent.Invocation) error {
3640
ctx := inv.Context()
@@ -95,19 +99,57 @@ func receiveCmd() *serpent.Command {
9599
return err
96100
}
97101

98-
ls, err := ts.Listen("tcp", ":3")
102+
sshListener, err := ts.Listen("tcp", ":3")
99103
if err != nil {
100104
return err
101105
}
102106

103107
go func() {
104108
fmt.Println(cliui.Timestamp(time.Now()), "SSH server listening")
105-
err := sshSrv.Serve(ls)
109+
err := sshSrv.Serve(sshListener)
106110
if err != nil {
107111
logger.Info("ssh server exited", "err", err)
108112
}
109113
}()
110114

115+
cpListener, err := ts.Listen("tcp", ":4444")
116+
if err != nil {
117+
return err
118+
}
119+
120+
go http.Serve(cpListener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
121+
if r.Method != "POST" {
122+
w.WriteHeader(http.StatusOK)
123+
w.Write([]byte("OK"))
124+
return
125+
}
126+
127+
fiName := strings.TrimPrefix(r.URL.Path, "/")
128+
defer r.Body.Close()
129+
130+
fi, err := os.OpenFile(fiName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
131+
if err != nil {
132+
http.Error(w, err.Error(), http.StatusInternalServerError)
133+
return
134+
}
135+
136+
bar := progressbar.DefaultBytes(
137+
r.ContentLength,
138+
fmt.Sprintf("Downloading %q", fiName),
139+
)
140+
_, err = io.Copy(io.MultiWriter(fi, bar), r.Body)
141+
if err != nil {
142+
http.Error(w, err.Error(), http.StatusInternalServerError)
143+
return
144+
}
145+
fi.Close()
146+
bar.Close()
147+
148+
w.WriteHeader(http.StatusOK)
149+
w.Write([]byte(fmt.Sprintf("File %q written", fiName)))
150+
fmt.Printf("Received file %s from %s\n", fiName, r.RemoteAddr)
151+
}))
152+
111153
ctx, ctxCancel := inv.SignalNotifyContext(ctx, os.Interrupt)
112154
defer ctxCancel()
113155

cmd/wush/rsync.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ func rsyncCmd() *serpent.Command {
2626
sshStdio bool
2727
)
2828
return &serpent.Command{
29-
Use: "rsync [flags] -- [rsync args]",
29+
Use: "rsync [flags] -- [rsync args]",
30+
Short: "Transfer files over rsync.",
3031
Long: "Runs rsync to transfer files to a " + cliui.Code("wush") + " peer. " +
3132
"Use " + cliui.Code("wush receive") + " on the computer you would like to connect to." +
3233
"\n\n" +
@@ -81,11 +82,13 @@ func rsyncCmd() *serpent.Command {
8182

8283
progPath := os.Args[0]
8384
args := []string{
84-
"-e", fmt.Sprintf("%s --auth-id %s --stdio --", progPath, send.Auth.AuthKey()),
85+
"-c",
86+
fmt.Sprintf(`rsync -e "%s ssh --auth-key %s --stdio --" %s`,
87+
progPath, send.Auth.AuthKey(), strings.Join(inv.Args, " "),
88+
),
8589
}
86-
args = append(args, inv.Args...)
8790
fmt.Println("Running: rsync", strings.Join(args, " "))
88-
cmd := exec.CommandContext(ctx, "rsync", args...)
91+
cmd := exec.CommandContext(ctx, "sh", args...)
8992
cmd.Stdin = inv.Stdin
9093
cmd.Stdout = inv.Stdout
9194
cmd.Stderr = inv.Stderr
@@ -94,9 +97,9 @@ func rsyncCmd() *serpent.Command {
9497
},
9598
Options: []serpent.Option{
9699
{
97-
Flag: "auth-id",
98-
Env: "WUSH_AUTH_ID",
99-
Description: "The auth id returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
100+
Flag: "auth-key",
101+
Env: "WUSH_AUTH_KEY",
102+
Description: "The auth key returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
100103
Default: "",
101104
Value: serpent.StringOf(&authID),
102105
},

cmd/wush/ssh.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ func sshCmd() *serpent.Command {
3232
return &serpent.Command{
3333
Use: "ssh",
3434
Aliases: []string{},
35+
Short: "Open a shell.",
3536
Long: "Opens an SSH connection to a " + cliui.Code("wush") + " peer. " +
3637
"Use " + cliui.Code("wush receive") + " on the computer you would like to connect to.",
3738
Handler: func(inv *serpent.Invocation) error {
38-
3939
ctx := inv.Context()
4040
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
4141
logF := func(str string, args ...any) {
@@ -135,8 +135,8 @@ func sshCmd() *serpent.Command {
135135
},
136136
Options: []serpent.Option{
137137
{
138-
Flag: "auth",
139-
Env: "WUSH_AUTH",
138+
Flag: "auth-key",
139+
Env: "WUSH_AUTH_KEY",
140140
Description: "The auth key returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
141141
Default: "",
142142
Value: serpent.StringOf(&authKey),

cmd/wush/version.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
func versionCmd() *serpent.Command {
1111
cmd := &serpent.Command{
1212
Use: "version",
13-
Short: "Show wush version",
13+
Short: "Show wush version.",
1414
Handler: func(inv *serpent.Invocation) error {
1515
bi := getBuildInfo()
1616
fmt.Printf("Wush %s-%s %s\n", bi.version, bi.commitHash[:7], bi.commitTime.Format(time.RFC1123))

go.mod

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ require (
2222
github.com/pion/stun/v3 v3.0.0
2323
github.com/prometheus/client_golang v1.19.1
2424
github.com/puzpuzpuz/xsync/v3 v3.2.0
25+
github.com/schollz/progressbar/v3 v3.14.6
2526
github.com/spf13/afero v1.11.0
2627
github.com/valyala/fasthttp v1.55.0
2728
go4.org/mem v0.0.0-20220726221520-4f986261bf13
@@ -139,6 +140,7 @@ require (
139140
github.com/mdlayher/sdnotify v1.0.0 // indirect
140141
github.com/mdlayher/socket v0.5.0 // indirect
141142
github.com/miekg/dns v1.1.58 // indirect
143+
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
142144
github.com/mitchellh/copystructure v1.2.0 // indirect
143145
github.com/mitchellh/go-ps v1.0.0 // indirect
144146
github.com/mitchellh/go-testing-interface v1.14.1 // indirect

0 commit comments

Comments
 (0)