@@ -21,13 +21,14 @@ import (
21
21
22
22
// UpStream creates upstream handler struct
23
23
type UpStream struct {
24
- Name string
25
- proxy http.Handler
24
+ Name string
25
+ proxy http.Handler
26
26
// TODO: Kick out separat config options and use more generic one
27
- allowed []* regexp.Regexp
28
- bindMounts []string
29
- devMappings []string
27
+ allowed []* regexp.Regexp
28
+ bindMounts []string
29
+ devMappings []string
30
30
gpu bool
31
+ pinUser string
31
32
}
32
33
33
34
// UnixSocket just provides the path, so that I can test it
@@ -64,21 +65,23 @@ func newReverseProxy(dial func(network, addr string) (net.Conn, error)) *httputi
64
65
}
65
66
66
67
// NewUpstream returns a new socket (magic)
67
- func NewUpstream (socket string , regs []string , binds []string , devs []string , gpu bool ) * UpStream {
68
+ func NewUpstream (socket string , regs []string , binds []string , devs []string , gpu bool , pinUser string ) * UpStream {
68
69
us := NewUnixSocket (socket )
69
70
a := []* regexp.Regexp {}
70
71
for _ , r := range regs {
71
72
p , _ := regexp .Compile (r )
72
73
a = append (a , p )
73
74
}
74
- return & UpStream {
75
+ upstream := & UpStream {
75
76
Name : socket ,
76
77
proxy : newReverseProxy (us .connectSocket ),
77
78
allowed : a ,
78
79
bindMounts : binds ,
79
80
devMappings : devs ,
80
81
gpu : gpu ,
82
+ pinUser : pinUser ,
81
83
}
84
+ return upstream
82
85
}
83
86
84
87
@@ -97,11 +100,24 @@ func (u *UpStream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
97
100
http.Error(w, fmt.Sprintf("Only GET requests are allowed, req.Method: %s", req.Method), 400)
98
101
return
99
102
}*/
103
+ /*
104
+ // Hijack the connection to inspect who called it
105
+ hj, ok := w.(http.Hijacker)
106
+ if !ok {
107
+ http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError)
108
+ return
109
+ }
110
+ conn, _, err := hj.Hijack()
111
+ if err != nil {
112
+ http.Error(w, err.Error(), http.StatusInternalServerError)
113
+ return
114
+ }*/
100
115
// Read the body
101
116
body , err := ioutil .ReadAll (req .Body )
102
117
if err != nil {
103
118
fmt .Println (err .Error ())
104
119
}
120
+ //syscall.GetsockoptUcred(int(fd), syscall.SOL_SOCKET, syscall.SO_PEERCRED)
105
121
//fmt.Printf("%v\n", hostConfig.Mounts)
106
122
// And now set a new body, which will simulate the same data we read:
107
123
req .Body = ioutil .NopCloser (bytes .NewBuffer (body ))
@@ -125,6 +141,15 @@ func (u *UpStream) ServeHTTP(w http.ResponseWriter, req *http.Request) {
125
141
config .Env = append (config .Env , "PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" )
126
142
config .Env = append (config .Env , "LD_LIBRARY_PATH=/usr/local/nvidia/" )
127
143
}
144
+ if u .pinUser != "" {
145
+ // TODO: Should depend on calling user from syscall.GetsockoptUcred()
146
+ if config .User != "" {
147
+ fmt .Printf ("Overwrite User with '%s', was '%s'\n " , u .pinUser , config .User )
148
+ } else {
149
+ fmt .Printf ("Overwrite User with '%s'\n " , u .pinUser )
150
+ }
151
+ config .User = u .pinUser
152
+ }
128
153
for _ , bMount := range u .bindMounts {
129
154
if bMount == "" {
130
155
continue
0 commit comments