Skip to content

feat(config): add network_restrictions to db config #3759

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion internal/link/link.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,19 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(
func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) {
// Ignore non-fatal errors linking services
var wg sync.WaitGroup
wg.Add(7)
wg.Add(8)
go func() {
defer wg.Done()
if err := linkDatabaseSettings(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkNetworkRestrictions(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
fmt.Fprintln(os.Stderr, err)
}
}()
go func() {
defer wg.Done()
if err := linkPostgrest(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
Expand Down Expand Up @@ -193,6 +199,42 @@ func linkDatabaseSettings(ctx context.Context, projectRef string) error {
return nil
}

func linkNetworkRestrictions(ctx context.Context, projectRef string) error {
resp, err := utils.GetSupabase().V1GetNetworkRestrictionsWithResponse(ctx, projectRef)
if err != nil {
return errors.Errorf("failed to read network restrictions config: %w", err)
} else if resp.JSON200 == nil {
return errors.Errorf("unexpected network restrictions config status %d: %s", resp.StatusCode(), string(resp.Body))
}

// Check if remote has actual restrictions (not just "allow all")
hasRestrictions := false
if resp.JSON200.Config.DbAllowedCidrs != nil && len(*resp.JSON200.Config.DbAllowedCidrs) > 0 {
// Check if it's not just "allow all"
if len(*resp.JSON200.Config.DbAllowedCidrs) != 1 || (*resp.JSON200.Config.DbAllowedCidrs)[0] != "0.0.0.0/0" {
hasRestrictions = true
}
}
if resp.JSON200.Config.DbAllowedCidrsV6 != nil && len(*resp.JSON200.Config.DbAllowedCidrsV6) > 0 {
// Check if it's not just "allow all"
if len(*resp.JSON200.Config.DbAllowedCidrsV6) != 1 || (*resp.JSON200.Config.DbAllowedCidrsV6)[0] != "::/0" {
hasRestrictions = true
}
}

// Only create NetworkRestrictions if there are actual restrictions
if hasRestrictions {
if utils.Config.Db.NetworkRestrictions == nil {
utils.Config.Db.NetworkRestrictions = &cliConfig.NetworkRestrictions{}
}
utils.Config.Db.NetworkRestrictions.FromRemoteNetworkRestrictions(*resp.JSON200)
} else {
// No restrictions, set to nil so the section doesn't appear in TOML
utils.Config.Db.NetworkRestrictions = nil
}
return nil
}

func linkDatabase(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
conn, err := utils.ConnectByConfig(ctx, config, options...)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions internal/link/link_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ func TestLinkCommand(t *testing.T) {
Get("/v1/projects/" + project + "/config/database/pooler").
Reply(200).
JSON(api.V1PgbouncerConfigResponse{})
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/network-restrictions").
Reply(200).
JSON(api.NetworkRestrictionsResponse{})
// Link versions
auth := tenant.HealthResponse{Version: "v2.74.2"}
gock.New("https://" + utils.GetSupabaseHost(project)).
Expand Down Expand Up @@ -151,6 +155,10 @@ func TestLinkCommand(t *testing.T) {
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/config/database/pooler").
ReplyError(errors.New("network error"))
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/network-restrictions").
Reply(200).
JSON(api.NetworkRestrictionsResponse{})
// Link versions
gock.New("https://" + utils.GetSupabaseHost(project)).
Get("/auth/v1/health").
Expand Down Expand Up @@ -201,6 +209,10 @@ func TestLinkCommand(t *testing.T) {
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/config/database/pooler").
ReplyError(errors.New("network error"))
gock.New(utils.DefaultApiHost).
Get("/v1/projects/" + project + "/network-restrictions").
Reply(200).
JSON(api.NetworkRestrictionsResponse{})
// Link versions
gock.New("https://" + utils.GetSupabaseHost(project)).
Get("/auth/v1/health").
Expand Down
105 changes: 94 additions & 11 deletions pkg/config/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,25 @@ type (
WorkMem *string `toml:"work_mem"`
}

NetworkRestrictions struct {
Enabled bool `toml:"enabled"`
DbAllowedCidrs []string `toml:"db_allowed_cidrs"`
DbAllowedCidrsV6 []string `toml:"db_allowed_cidrs_v6"`
}

db struct {
Image string `toml:"-"`
Port uint16 `toml:"port"`
ShadowPort uint16 `toml:"shadow_port"`
MajorVersion uint `toml:"major_version"`
Password string `toml:"-"`
RootKey Secret `toml:"root_key"`
Pooler pooler `toml:"pooler"`
Migrations migrations `toml:"migrations"`
Seed seed `toml:"seed"`
Settings settings `toml:"settings"`
Vault map[string]Secret `toml:"vault"`
Image string `toml:"-"`
Port uint16 `toml:"port"`
ShadowPort uint16 `toml:"shadow_port"`
MajorVersion uint `toml:"major_version"`
Password string `toml:"-"`
RootKey Secret `toml:"root_key"`
Pooler pooler `toml:"pooler"`
Migrations migrations `toml:"migrations"`
Seed seed `toml:"seed"`
Settings settings `toml:"settings"`
NetworkRestrictions *NetworkRestrictions `toml:"network_restrictions,omitempty"`
Vault map[string]Secret `toml:"vault"`
}

migrations struct {
Expand Down Expand Up @@ -188,3 +195,79 @@ func (a *settings) DiffWithRemote(remoteConfig v1API.PostgresConfigResponse) ([]
}
return diff.Diff("remote[db.settings]", remoteCompare, "local[db.settings]", currentValue), nil
}

func (n *NetworkRestrictions) ToUpdateNetworkRestrictionsBody() v1API.V1UpdateNetworkRestrictionsJSONRequestBody {
body := v1API.V1UpdateNetworkRestrictionsJSONRequestBody{}

// If network_restrictions explicitely disabled we allow-all
if !n.Enabled {
body.DbAllowedCidrs = &[]string{"0.0.0.0/0"}
body.DbAllowedCidrsV6 = &[]string{"::/0"}
return body
}

// If enabled, send the actual CIDR values (empty arrays will reject all ips)
body.DbAllowedCidrs = &n.DbAllowedCidrs
body.DbAllowedCidrsV6 = &n.DbAllowedCidrsV6
return body
}

func (n *NetworkRestrictions) FromRemoteNetworkRestrictions(remoteConfig v1API.NetworkRestrictionsResponse) {
// Check if remote has restrictions (non-empty arrays that aren't "allow all")
hasRestrictions := false

if len(*remoteConfig.Config.DbAllowedCidrs) > 0 {
// Check if it's not just "allow all"
if len(*remoteConfig.Config.DbAllowedCidrs) != 1 || (*remoteConfig.Config.DbAllowedCidrs)[0] != "0.0.0.0/0" {
hasRestrictions = true
}
}

if len(*remoteConfig.Config.DbAllowedCidrsV6) > 0 {
// Check if it's not just "allow all"
if len(*remoteConfig.Config.DbAllowedCidrsV6) != 1 || (*remoteConfig.Config.DbAllowedCidrsV6)[0] != "::/0" {
hasRestrictions = true
}
}

// Set enabled based on whether there are actual restrictions
n.Enabled = hasRestrictions

// Set the CIDR values
if remoteConfig.Config.DbAllowedCidrs != nil {
n.DbAllowedCidrs = *remoteConfig.Config.DbAllowedCidrs
} else {
n.DbAllowedCidrs = []string{}
}

if remoteConfig.Config.DbAllowedCidrsV6 != nil {
n.DbAllowedCidrsV6 = *remoteConfig.Config.DbAllowedCidrsV6
} else {
n.DbAllowedCidrsV6 = []string{}
}
}

func (n *NetworkRestrictions) DiffWithRemote(remoteConfig v1API.NetworkRestrictionsResponse) ([]byte, error) {
if n == nil {
return nil, nil
}

// If enabled is explicitely false, we set the default allow_all values
if n.Enabled == false {
n.DbAllowedCidrs = []string{"0.0.0.0/0"}
n.DbAllowedCidrsV6 = []string{"::/0"}
}

copy := *n
// Convert the config values into easily comparable remoteConfig values
currentValue, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
copy.FromRemoteNetworkRestrictions(remoteConfig)
remoteCompare, err := ToTomlBytes(copy)
if err != nil {
return nil, err
}
return diff.Diff("remote[db.network_restrictions]", remoteCompare, "local[db.network_restrictions]", currentValue), nil
}
Loading