Skip to content

Commit 9873d7c

Browse files
Merge pull request #2 from qnib/user
allow for --user to be overwritten, not yet w/ unix-socket credentials
2 parents 64d28be + aabed67 commit 9873d7c

File tree

4 files changed

+55
-13
lines changed

4 files changed

+55
-13
lines changed

Diff for: main.go

+9-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ var (
5959
Usage: "File holding line-separated regex-patterns to be allowed (comments allowed, use #)",
6060
EnvVar: "DOXY_PATTERN_FILE",
6161
}
62+
pinUserFlag = cli.StringFlag{
63+
Name: "pin-user",
64+
Usage: "Overwrite `--user` with given value",
65+
EnvVar: "DOXY_PIN_USER",
66+
}
6267
deviceFileFlag = cli.StringFlag{
6368
Name: "device-file",
6469
Value: proxy.DEVICE_FILE,
@@ -74,10 +79,12 @@ func EvalOptions(cfg *config.Config) (po []proxy.ProxyOption) {
7479
po = append(po, proxy.WithDockerSocket(dockerSock))
7580
debug, _ := cfg.Bool("debug")
7681
po = append(po, proxy.WithDebugValue(debug))
77-
devMaps, _ := cfg.String("device-mappings")
7882
gpu, _ := cfg.Bool("gpu")
7983
po = append(po, proxy.WithGpuValue(gpu))
84+
devMaps, _ := cfg.String("device-mappings")
8085
po = append(po, proxy.WithDevMappings(strings.Split(devMaps,",")))
86+
pinUser, _ := cfg.String("pin-user")
87+
po = append(po, proxy.WithPinUserValue(pinUser))
8188
return
8289
}
8390

@@ -148,6 +155,7 @@ func main() {
148155
patternFileFlag,
149156
proxyPatternKey,
150157
bindAddFlag,
158+
pinUserFlag,
151159
}
152160
app.Action = RunApp
153161
app.Run(os.Args)

Diff for: proxy/main.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ var (
4848
)
4949

5050
type Proxy struct {
51-
dockerSocket, newSocket string
52-
debug, gpu bool
53-
patterns []string
54-
bindMounts,devMappings []string
51+
dockerSocket, newSocket, pinUser string
52+
debug, gpu bool
53+
patterns []string
54+
bindMounts,devMappings []string
5555
}
5656

5757
func NewProxy(opts ...ProxyOption) Proxy {
@@ -64,6 +64,7 @@ func NewProxy(opts ...ProxyOption) Proxy {
6464
newSocket: options.ProxySocket,
6565
debug: options.Debug,
6666
gpu: options.Gpu,
67+
pinUser: options.PinUser,
6768
patterns: options.Patterns,
6869
bindMounts: options.BindMounts,
6970
devMappings: options.DevMappings,
@@ -81,7 +82,7 @@ func (p *Proxy) GetOptions() map[string]interface{} {
8182
}
8283

8384
func (p *Proxy) Run() {
84-
upstream := NewUpstream(p.dockerSocket, p.patterns, p.bindMounts, p.devMappings, p.gpu)
85+
upstream := NewUpstream(p.dockerSocket, p.patterns, p.bindMounts, p.devMappings, p.gpu, p.pinUser)
8586
sigc := make(chan os.Signal, 1)
8687
signal.Notify(sigc, os.Interrupt, os.Kill, syscall.SIGTERM)
8788
l, err := ListenToNewSock(p.newSocket, sigc)
@@ -94,6 +95,7 @@ func (p *Proxy) Run() {
9495
}
9596
n.UseHandler(upstream)
9697
log.Printf("Serving proxy on '%s'", p.newSocket)
98+
9799
if err = http.Serve(l, n); err != nil {
98100
panic(err)
99101
}

Diff for: proxy/options.go

+7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package proxy
33
type ProxyOptions struct {
44
DockerSocket string
55
ProxySocket string
6+
PinUser string
67
Debug,Gpu bool
78
Patterns []string
89
BindMounts []string
@@ -12,6 +13,7 @@ type ProxyOptions struct {
1213
var defaultProxyOptions = ProxyOptions{
1314
DockerSocket: DOCKER_SOCKET,
1415
ProxySocket: PROXY_SOCKET,
16+
PinUser: "",
1517
Debug: false,
1618
Gpu: false,
1719
Patterns: []string{},
@@ -21,6 +23,11 @@ var defaultProxyOptions = ProxyOptions{
2123

2224
type ProxyOption func(*ProxyOptions)
2325

26+
func WithPinUserValue(pu string) ProxyOption {
27+
return func(o *ProxyOptions) {
28+
o.PinUser = pu
29+
}
30+
}
2431
func WithDockerSocket(s string) ProxyOption {
2532
return func(o *ProxyOptions) {
2633
o.DockerSocket = s

Diff for: proxy/proxy.go

+32-7
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ import (
2121

2222
// UpStream creates upstream handler struct
2323
type UpStream struct {
24-
Name string
25-
proxy http.Handler
24+
Name string
25+
proxy http.Handler
2626
// TODO: Kick out separat config options and use more generic one
27-
allowed []*regexp.Regexp
28-
bindMounts []string
29-
devMappings []string
27+
allowed []*regexp.Regexp
28+
bindMounts []string
29+
devMappings []string
3030
gpu bool
31+
pinUser string
3132
}
3233

3334
// UnixSocket just provides the path, so that I can test it
@@ -64,21 +65,23 @@ func newReverseProxy(dial func(network, addr string) (net.Conn, error)) *httputi
6465
}
6566

6667
// NewUpstream returns a new socket (magic)
67-
func NewUpstream(socket string, regs []string, binds []string, devs []string, gpu bool) *UpStream {
68+
func NewUpstream(socket string, regs []string, binds []string, devs []string, gpu bool, pinUser string) *UpStream {
6869
us := NewUnixSocket(socket)
6970
a := []*regexp.Regexp{}
7071
for _, r := range regs {
7172
p, _ := regexp.Compile(r)
7273
a = append(a, p)
7374
}
74-
return &UpStream{
75+
upstream := &UpStream{
7576
Name: socket,
7677
proxy: newReverseProxy(us.connectSocket),
7778
allowed: a,
7879
bindMounts: binds,
7980
devMappings: devs,
8081
gpu: gpu,
82+
pinUser: pinUser,
8183
}
84+
return upstream
8285
}
8386

8487

@@ -97,11 +100,24 @@ func (u *UpStream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
97100
http.Error(w, fmt.Sprintf("Only GET requests are allowed, req.Method: %s", req.Method), 400)
98101
return
99102
}*/
103+
/*
104+
// Hijack the connection to inspect who called it
105+
hj, ok := w.(http.Hijacker)
106+
if !ok {
107+
http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
108+
return
109+
}
110+
conn, _, err := hj.Hijack()
111+
if err != nil {
112+
http.Error(w, err.Error(), http.StatusInternalServerError)
113+
return
114+
}*/
100115
// Read the body
101116
body, err := ioutil.ReadAll(req.Body)
102117
if err != nil {
103118
fmt.Println(err.Error())
104119
}
120+
//syscall.GetsockoptUcred(int(fd), syscall.SOL_SOCKET, syscall.SO_PEERCRED)
105121
//fmt.Printf("%v\n", hostConfig.Mounts)
106122
// And now set a new body, which will simulate the same data we read:
107123
req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
@@ -125,6 +141,15 @@ func (u *UpStream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
125141
config.Env = append(config.Env, "PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin")
126142
config.Env = append(config.Env, "LD_LIBRARY_PATH=/usr/local/nvidia/")
127143
}
144+
if u.pinUser != "" {
145+
// TODO: Should depend on calling user from syscall.GetsockoptUcred()
146+
if config.User != "" {
147+
fmt.Printf("Overwrite User with '%s', was '%s'\n", u.pinUser, config.User)
148+
} else {
149+
fmt.Printf("Overwrite User with '%s'\n", u.pinUser)
150+
}
151+
config.User = u.pinUser
152+
}
128153
for _, bMount := range u.bindMounts {
129154
if bMount == "" {
130155
continue

0 commit comments

Comments
 (0)