diff --git a/command.go b/command.go index 6904bfba3..60b9c4558 100644 --- a/command.go +++ b/command.go @@ -315,6 +315,10 @@ func (c *Command) SetUsageFunc(f func(*Command) error) { // SetUsageTemplate sets usage template. Can be defined by Application. func (c *Command) SetUsageTemplate(s string) { + if s == "" { + c.usageTemplate = nil + return + } c.usageTemplate = tmpl(s) } @@ -351,11 +355,19 @@ func (c *Command) SetCompletionCommandGroupID(groupID string) { // SetHelpTemplate sets help template to be used. Application can use it to set custom template. func (c *Command) SetHelpTemplate(s string) { + if s == "" { + c.helpTemplate = nil + return + } c.helpTemplate = tmpl(s) } // SetVersionTemplate sets version template to be used. Application can use it to set custom template. func (c *Command) SetVersionTemplate(s string) { + if s == "" { + c.versionTemplate = nil + return + } c.versionTemplate = tmpl(s) } diff --git a/command_test.go b/command_test.go index 0b0d6c662..a379460ae 100644 --- a/command_test.go +++ b/command_test.go @@ -1047,6 +1047,18 @@ func TestSetHelpTemplate(t *testing.T) { if got != expected { t.Errorf("Expected %q, got %q", expected, got) } + + // Reset the root command help template and make sure + // it falls back to the default + rootCmd.SetHelpTemplate("") + got, err = executeCommand(rootCmd, "--help") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if !strings.Contains(got, "Usage:") { + t.Errorf("Expected to contain %q, got %q", "Usage:", got) + } } func TestHelpFlagExecuted(t *testing.T) { @@ -1139,6 +1151,18 @@ func TestSetUsageTemplate(t *testing.T) { expected = "WORKS " + childCmd.UseLine() checkStringContains(t, got, expected) + + // Reset the root command usage template and make sure + // it falls back to the default + rootCmd.SetUsageTemplate("") + got, err = executeCommand(rootCmd, "--invalid") + if err == nil { + t.Errorf("Expected error but did not get one") + } + + if !strings.Contains(got, "Usage:") { + t.Errorf("Expected to contain %q, got %q", "Usage:", got) + } } func TestVersionFlagExecuted(t *testing.T) {