diff --git a/setup.go b/setup.go index 9f27b3a9..6611b976 100644 --- a/setup.go +++ b/setup.go @@ -24,13 +24,14 @@ import ( func (e *Executor) Setup() error { e.setupLogger() - if err := e.setCurrentDir(); err != nil { + node, err := e.getRootNode() + if err != nil { return err } if err := e.setupTempDir(); err != nil { return err } - if err := e.readTaskfile(); err != nil { + if err := e.readTaskfile(node); err != nil { return err } e.setupFuzzyModel() @@ -44,45 +45,25 @@ func (e *Executor) Setup() error { if err := e.readDotEnvFiles(); err != nil { return err } - if err := e.doVersionChecks(); err != nil { return err } e.setupDefaults() e.setupConcurrencyState() - return nil } -func (e *Executor) setCurrentDir() error { - // If the entrypoint is already set, we don't need to do anything - if e.Entrypoint != "" { - return nil - } - - // Default the directory to the current working directory - if e.Dir == "" { - wd, err := os.Getwd() - if err != nil { - return err - } - e.Dir = wd - } else { - var err error - e.Dir, err = filepath.Abs(e.Dir) - if err != nil { - return err - } - } - - return nil -} - -func (e *Executor) readTaskfile() error { +func (e *Executor) getRootNode() (taskfile.Node, error) { node, err := taskfile.NewRootNode(e.Dir, e.Entrypoint, e.Insecure) if err != nil { - return err + return nil, err } + e.Dir = node.BaseDir() + return node, err +} + +func (e *Executor) readTaskfile(node taskfile.Node) error { + var err error e.Taskfile, err = taskfile.Read( node, e.Insecure, diff --git a/task_test.go b/task_test.go index 8cfa832b..4c02565c 100644 --- a/task_test.go +++ b/task_test.go @@ -109,6 +109,7 @@ func TestVars(t *testing.T) { func TestSpecialVars(t *testing.T) { const dir = "testdata/special_vars" + const subdir = "testdata/special_vars/subdir" toAbs := func(rel string) string { abs, err := filepath.Abs(rel) assert.NoError(t, err) @@ -122,28 +123,32 @@ func TestSpecialVars(t *testing.T) { // Root {target: "print-task", expected: "print-task"}, {target: "print-root-dir", expected: toAbs(dir)}, + {target: "print-taskfile", expected: toAbs(dir) + "/Taskfile.yml"}, {target: "print-taskfile-dir", expected: toAbs(dir)}, {target: "print-task-version", expected: "unknown"}, // Included {target: "included:print-task", expected: "included:print-task"}, {target: "included:print-root-dir", expected: toAbs(dir)}, + {target: "included:print-taskfile", expected: toAbs(dir) + "/included/Taskfile.yml"}, {target: "included:print-taskfile-dir", expected: toAbs(dir) + "/included"}, {target: "included:print-task-version", expected: "unknown"}, } - for _, test := range tests { - t.Run(test.target, func(t *testing.T) { - var buff bytes.Buffer - e := &task.Executor{ - Dir: dir, - Stdout: &buff, - Stderr: &buff, - Silent: true, - } - require.NoError(t, e.Setup()) - require.NoError(t, e.Run(context.Background(), &ast.Call{Task: test.target})) - assert.Equal(t, test.expected+"\n", buff.String()) - }) + for _, dir := range []string{dir, subdir} { + for _, test := range tests { + t.Run(test.target, func(t *testing.T) { + var buff bytes.Buffer + e := &task.Executor{ + Dir: dir, + Stdout: &buff, + Stderr: &buff, + Silent: true, + } + require.NoError(t, e.Setup()) + require.NoError(t, e.Run(context.Background(), &ast.Call{Task: test.target})) + assert.Equal(t, test.expected+"\n", buff.String()) + }) + } } } diff --git a/taskfile/node.go b/taskfile/node.go index 8b9dc287..27f9ab58 100644 --- a/taskfile/node.go +++ b/taskfile/node.go @@ -24,6 +24,7 @@ func NewRootNode( entrypoint string, insecure bool, ) (Node, error) { + dir = getDefaultDir(entrypoint, dir) // Check if there is something to read on STDIN stat, _ := os.Stdin.Stat() if (stat.Mode()&os.ModeCharDevice) == 0 && stat.Size() > 0 { @@ -68,3 +69,26 @@ func getScheme(uri string) string { } return "" } + +func getDefaultDir(entrypoint, dir string) string { + // If the entrypoint and dir are empty, we default the directory to the current working directory + if dir == "" { + if entrypoint == "" { + wd, err := os.Getwd() + if err != nil { + return "" + } + dir = wd + } + return dir + } + + // If the directory is set, ensure it is an absolute path + var err error + dir, err = filepath.Abs(dir) + if err != nil { + return "" + } + + return dir +} diff --git a/testdata/special_vars/Taskfile.yml b/testdata/special_vars/Taskfile.yml index 271356cf..23415818 100644 --- a/testdata/special_vars/Taskfile.yml +++ b/testdata/special_vars/Taskfile.yml @@ -8,5 +8,6 @@ includes: tasks: print-task: echo {{.TASK}} print-root-dir: echo {{.ROOT_DIR}} + print-taskfile: echo {{.TASKFILE}} print-taskfile-dir: echo {{.TASKFILE_DIR}} print-task-version: echo {{.TASK_VERSION}} diff --git a/testdata/special_vars/included/Taskfile.yml b/testdata/special_vars/included/Taskfile.yml index 5eb75b15..63562f97 100644 --- a/testdata/special_vars/included/Taskfile.yml +++ b/testdata/special_vars/included/Taskfile.yml @@ -3,5 +3,6 @@ version: '3' tasks: print-task: echo {{.TASK}} print-root-dir: echo {{.ROOT_DIR}} + print-taskfile: echo {{.TASKFILE}} print-taskfile-dir: echo {{.TASKFILE_DIR}} print-task-version: echo {{.TASK_VERSION}} diff --git a/testdata/special_vars/subdir/.gitkeep b/testdata/special_vars/subdir/.gitkeep new file mode 100644 index 00000000..e69de29b