Skip to content
This repository has been archived by the owner on Jul 11, 2022. It is now read-only.

Commit

Permalink
refactor to io.Reader for better memory use
Browse files Browse the repository at this point in the history
Signed-off-by: Jason McCallister <jason@craftcms.com>
  • Loading branch information
jasonmccallister committed Oct 2, 2020
1 parent 071e978 commit 6430746
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 53 deletions.
71 changes: 39 additions & 32 deletions internal/cmd/db_import.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/craftcms/nitro/internal/client"
"github.com/craftcms/nitro/internal/config"
"github.com/craftcms/nitro/internal/database"
"github.com/craftcms/nitro/internal/helpers"
"github.com/craftcms/nitro/internal/nitro"
"github.com/craftcms/nitro/internal/nitrod"
Expand All @@ -33,6 +34,10 @@ var dbImportCommand = &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
machine := flagMachineName
runner := nitro.NewMultipassRunner("multipass")
var configFile config.Config
if err := viper.Unmarshal(&configFile); err != nil {
return err
}
ip := nitro.IP(machine, runner)
c, err := client.NewClient(ip, "50051")
if err != nil {
Expand Down Expand Up @@ -64,20 +69,28 @@ var dbImportCommand = &cobra.Command{
return nil
}

engines, err := getDatabaseEngines()
// create a new prompt
p := prompt.NewPrompt()

// open the file
file, err := os.Open(filename)
if err != nil {
fmt.Println("Unable to get a list of the database engines, error:", err.Error())
return nil
return err
}
reader := bufio.NewReader(file)

// open the file so we can stream it to the server
f, err := os.Open(filename)
// try to determine the database engine
detected, err := database.DetermineEngine(reader)
if err != nil {
return err
fmt.Println("Unable to determine the database engine from the file", filename)
}

// create a new prompt
p := prompt.NewPrompt()
// get the databases as a list, limiting to the detected engine
engines := configFile.DatabaseEnginesAsList(detected)
if len(engines) == 0 {
fmt.Println("Unable to get a list of the database engines")
return nil
}

// if there is only on database engine
var container string
Expand All @@ -94,12 +107,25 @@ var dbImportCommand = &cobra.Command{
}
req.Container = container

// prompt for the database name to create
database, err := p.Ask("Enter the database name to create for the import:", &prompt.InputOptions{Validator: nil})
if err != nil {
return err
// if the detect engine is mysql
showCreatePrompt := true
if detected == "mysql" {
// check if there is a create database statement
showCreatePrompt, err = database.HasCreateStatement(reader)
if err != nil {
fmt.Println(err.Error())
}
fmt.Printf("The file %q will create a database during import...\n", filename)
}

// prompt for the database name to create if needed
if showCreatePrompt {
db, err := p.Ask("Enter the database name to create for the import:", &prompt.InputOptions{Validator: nil})
if err != nil {
return err
}
req.Database = db
}
req.Database = database

// create the stream
stream, err := c.ImportDatabase(cmd.Context())
Expand All @@ -110,7 +136,6 @@ var dbImportCommand = &cobra.Command{

fmt.Printf("Uploading %q into %q (large files may take a while)...\n", filename, machine)

reader := bufio.NewReader(f)
start := time.Now()
buffer := make([]byte, 1024*20)

Expand Down Expand Up @@ -160,21 +185,3 @@ func isCompressed(file string, req *nitrod.ImportDatabaseRequest) error {

return nil
}

func getDatabaseEngines() ([]string, error) {
var dbs []string
var databases []config.Database
if err := viper.UnmarshalKey("databases", &databases); err != nil {
return dbs, err
}

for _, db := range databases {
dbs = append(dbs, db.Name())
}

if len(dbs) == 0 {
return dbs, errors.New("there are no databases that we can import the file into")
}

return dbs, nil
}
24 changes: 7 additions & 17 deletions internal/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package database
import (
"bufio"
"errors"
"os"
"io"
"strings"
)

Expand All @@ -12,16 +12,12 @@ import (
// imports. It will return the engine "mysql" or
// "postgres" if it can determine the engine.
// If it cannot, it will return an error.
func DetermineEngine(file string) (string, error) {
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()

func DetermineEngine(file io.Reader) (string, error) {
engine := ""
s := bufio.NewScanner(f)
line := 1

s := bufio.NewScanner(file)

for s.Scan() {
// check if its postgres
if strings.Contains(s.Text(), "PostgreSQL") || strings.Contains(s.Text(), "pg_dump") {
Expand Down Expand Up @@ -54,16 +50,10 @@ func DetermineEngine(file string) (string, error) {
// if the file will create a database during import.
// If it creates a database, it will return true
// otherwise it will return false.
func HasCreateStatement(file string) (bool, error) {
f, err := os.Open(file)
if err != nil {
return false, err
}
defer f.Close()

func HasCreateStatement(file io.Reader) (bool, error) {
creates := false

s := bufio.NewScanner(f)
s := bufio.NewScanner(file)
line := 1
for s.Scan() {

Expand Down
25 changes: 21 additions & 4 deletions internal/database/database_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package database

import "testing"
import (
"bufio"
"os"
"testing"
)

func TestDetermineEngine(t *testing.T) {
type args struct {
Expand Down Expand Up @@ -33,7 +37,14 @@ func TestDetermineEngine(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := DetermineEngine(tt.args.file)
f, err := os.Open(tt.args.file)
if err != nil {
t.Fatal(err)
}

rdr := bufio.NewReader(f)

got, err := DetermineEngine(rdr)
if (err != nil) != tt.wantErr {
t.Errorf("DetermineEngine() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -64,7 +75,13 @@ func TestHasCreateStatement(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := HasCreateStatement(tt.args.file)
f, err := os.Open(tt.args.file)
if err != nil {
t.Fatal(err)
}
rdr := bufio.NewReader(f)

got, err := HasCreateStatement(rdr)
if (err != nil) != tt.wantErr {
t.Errorf("HasCreateStatement() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -74,4 +91,4 @@ func TestHasCreateStatement(t *testing.T) {
}
})
}
}
}

0 comments on commit 6430746

Please # to comment.