feat: parse templates in collection-type variables (#1526)

* refactor: replacer

* feat: move traverser to deepcopy package

* feat: nested map variable templating

* refactor: ReplaceVar function

* feat: test cases

* fix: TraverseStringsFunc copy value instead of pointer
This commit is contained in:
Pete Davison
2024-03-10 17:11:07 +00:00
committed by GitHub
parent 19a4d8f928
commit 08a888dc8a
13 changed files with 243 additions and 135 deletions

View File

@@ -59,20 +59,9 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool
getRangeFunc := func(dir string) func(k string, v ast.Var) error {
return func(k string, v ast.Var) error {
tr := templater.Templater{Vars: result}
cache := &templater.Cache{Vars: result}
// Replace values
newVar := ast.Var{}
switch value := v.Value.(type) {
case string:
newVar.Value = tr.Replace(value)
default:
newVar.Value = value
}
newVar.Sh = tr.Replace(v.Sh)
newVar.Ref = v.Ref
newVar.Json = tr.Replace(v.Json)
newVar.Yaml = tr.Replace(v.Yaml)
newVar.Dir = v.Dir
newVar := templater.ReplaceVar(v, cache)
// If the variable is a reference, we can resolve it
if newVar.Ref != "" {
newVar.Value = result.Get(newVar.Ref).Value
@@ -89,7 +78,7 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool
return nil
}
// Now we can check for errors since we've handled all the cases when we don't want to evaluate
if err := tr.Err(); err != nil {
if err := cache.Err(); err != nil {
return err
}
// Evaluate JSON
@@ -124,9 +113,9 @@ func (c *Compiler) getVariables(t *ast.Task, call *ast.Call, evaluateShVars bool
if t != nil {
// NOTE(@andreynering): We're manually joining these paths here because
// this is the raw task, not the compiled one.
tr := templater.Templater{Vars: result}
dir := tr.Replace(t.Dir)
if err := tr.Err(); err != nil {
cache := &templater.Cache{Vars: result}
dir := templater.Replace(t.Dir, cache)
if err := cache.Err(); err != nil {
return nil, err
}
dir = filepathext.SmartJoin(c.Dir, dir)

View File

@@ -1,5 +1,9 @@
package deepcopy
import (
"reflect"
)
type Copier[T any] interface {
DeepCopy() T
}
@@ -33,3 +37,105 @@ func Map[K comparable, V any](orig map[K]V) map[K]V {
}
return c
}
// TraverseStringsFunc runs the given function on every string in the given
// value by traversing it recursively. If the given value is a string, the
// function will run on a copy of the string and return it. If the value is a
// struct, map or a slice, the function will recursively call itself for each
// field or element of the struct, map or slice until all strings inside the
// struct or slice are replaced.
func TraverseStringsFunc[T any](v T, fn func(v string) (string, error)) (T, error) {
original := reflect.ValueOf(v)
if original.Kind() == reflect.Invalid || !original.IsValid() {
return v, nil
}
copy := reflect.New(original.Type()).Elem()
var traverseFunc func(copy, v reflect.Value) error
traverseFunc = func(copy, v reflect.Value) error {
switch v.Kind() {
case reflect.Ptr:
// Unwrap the pointer
originalValue := v.Elem()
// If the pointer is nil, do nothing
if !originalValue.IsValid() {
return nil
}
// Create an empty copy from the original value's type
copy.Set(reflect.New(originalValue.Type()))
// Unwrap the newly created pointer and call traverseFunc recursively
if err := traverseFunc(copy.Elem(), originalValue); err != nil {
return err
}
case reflect.Interface:
// Unwrap the interface
originalValue := v.Elem()
if !originalValue.IsValid() {
return nil
}
// Create an empty copy from the original value's type
copyValue := reflect.New(originalValue.Type()).Elem()
// Unwrap the newly created pointer and call traverseFunc recursively
if err := traverseFunc(copyValue, originalValue); err != nil {
return err
}
copy.Set(copyValue)
case reflect.Struct:
// Loop over each field and call traverseFunc recursively
for i := 0; i < v.NumField(); i += 1 {
if err := traverseFunc(copy.Field(i), v.Field(i)); err != nil {
return err
}
}
case reflect.Slice:
// Create an empty copy from the original value's type
copy.Set(reflect.MakeSlice(v.Type(), v.Len(), v.Cap()))
// Loop over each element and call traverseFunc recursively
for i := 0; i < v.Len(); i += 1 {
if err := traverseFunc(copy.Index(i), v.Index(i)); err != nil {
return err
}
}
case reflect.Map:
// Create an empty copy from the original value's type
copy.Set(reflect.MakeMap(v.Type()))
// Loop over each key
for _, key := range v.MapKeys() {
// Create a copy of each map index
originalValue := v.MapIndex(key)
if originalValue.IsNil() {
continue
}
copyValue := reflect.New(originalValue.Type()).Elem()
// Call traverseFunc recursively
if err := traverseFunc(copyValue, originalValue); err != nil {
return err
}
copy.SetMapIndex(key, copyValue)
}
case reflect.String:
rv, err := fn(v.String())
if err != nil {
return err
}
copy.Set(reflect.ValueOf(rv))
default:
copy.Set(v)
}
return nil
}
if err := traverseFunc(copy, original); err != nil {
return v, err
}
return copy.Interface().(T), nil
}

View File

@@ -3,6 +3,8 @@ package output
import (
"bytes"
"io"
"github.com/go-task/task/v3/internal/templater"
)
type Group struct {
@@ -10,13 +12,13 @@ type Group struct {
ErrorOnly bool
}
func (g Group) WrapWriter(stdOut, _ io.Writer, _ string, tmpl Templater) (io.Writer, io.Writer, CloseFunc) {
func (g Group) WrapWriter(stdOut, _ io.Writer, _ string, cache *templater.Cache) (io.Writer, io.Writer, CloseFunc) {
gw := &groupWriter{writer: stdOut}
if g.Begin != "" {
gw.begin = tmpl.Replace(g.Begin) + "\n"
gw.begin = templater.Replace(g.Begin, cache) + "\n"
}
if g.End != "" {
gw.end = tmpl.Replace(g.End) + "\n"
gw.end = templater.Replace(g.End, cache) + "\n"
}
return gw, gw, func(err error) error {
if g.ErrorOnly && err == nil {

View File

@@ -2,10 +2,12 @@ package output
import (
"io"
"github.com/go-task/task/v3/internal/templater"
)
type Interleaved struct{}
func (Interleaved) WrapWriter(stdOut, stdErr io.Writer, _ string, _ Templater) (io.Writer, io.Writer, CloseFunc) {
func (Interleaved) WrapWriter(stdOut, stdErr io.Writer, _ string, _ *templater.Cache) (io.Writer, io.Writer, CloseFunc) {
return stdOut, stdErr, func(error) error { return nil }
}

View File

@@ -4,18 +4,12 @@ import (
"fmt"
"io"
"github.com/go-task/task/v3/internal/templater"
"github.com/go-task/task/v3/taskfile/ast"
)
// Templater executes a template engine.
// It is provided by the templater.Templater package.
type Templater interface {
// Replace replaces the provided template string with a rendered string.
Replace(tmpl string) string
}
type Output interface {
WrapWriter(stdOut, stdErr io.Writer, prefix string, tmpl Templater) (io.Writer, io.Writer, CloseFunc)
WrapWriter(stdOut, stdErr io.Writer, prefix string, cache *templater.Cache) (io.Writer, io.Writer, CloseFunc)
}
type CloseFunc func(err error) error

View File

@@ -46,7 +46,7 @@ func TestGroup(t *testing.T) {
}
func TestGroupWithBeginEnd(t *testing.T) {
tmpl := templater.Templater{
tmpl := templater.Cache{
Vars: &ast.Vars{
OrderedMap: omap.FromMap(map[string]ast.Var{
"VAR1": {Value: "example-value"},

View File

@@ -5,11 +5,13 @@ import (
"fmt"
"io"
"strings"
"github.com/go-task/task/v3/internal/templater"
)
type Prefixed struct{}
func (Prefixed) WrapWriter(stdOut, _ io.Writer, prefix string, _ Templater) (io.Writer, io.Writer, CloseFunc) {
func (Prefixed) WrapWriter(stdOut, _ io.Writer, prefix string, _ *templater.Cache) (io.Writer, io.Writer, CloseFunc) {
pw := &prefixWriter{writer: stdOut, prefix: prefix}
return pw, pw, func(error) error { return pw.close() }
}

View File

@@ -6,122 +6,116 @@ import (
"strings"
"text/template"
"github.com/go-task/task/v3/internal/deepcopy"
"github.com/go-task/task/v3/taskfile/ast"
)
// Templater is a help struct that allow us to call "replaceX" funcs multiple
// Cache is a help struct that allow us to call "replaceX" funcs multiple
// times, without having to check for error each time. The first error that
// happen will be assigned to r.err, and consecutive calls to funcs will just
// return the zero value.
type Templater struct {
type Cache struct {
Vars *ast.Vars
cacheMap map[string]any
err error
}
func (r *Templater) ResetCache() {
func (r *Cache) ResetCache() {
r.cacheMap = r.Vars.ToCacheMap()
}
func (r *Templater) Replace(str string) string {
return r.replace(str, nil)
func (r *Cache) Err() error {
return r.err
}
func (r *Templater) ReplaceWithExtra(str string, extra map[string]any) string {
return r.replace(str, extra)
func Replace[T any](v T, cache *Cache) T {
return ReplaceWithExtra(v, cache, nil)
}
func (r *Templater) replace(str string, extra map[string]any) string {
if r.err != nil || str == "" {
return ""
func ReplaceWithExtra[T any](v T, cache *Cache, extra map[string]any) T {
// If there is already an error, do nothing
if cache.err != nil {
return v
}
templ, err := template.New("").Funcs(templateFuncs).Parse(str)
// Initialize the cache map if it's not already initialized
if cache.cacheMap == nil {
cache.cacheMap = cache.Vars.ToCacheMap()
}
// Create a copy of the cache map to avoid editing the original
// If there is extra data, merge it with the cache map
data := maps.Clone(cache.cacheMap)
if extra != nil {
maps.Copy(data, extra)
}
// Traverse the value and parse any template variables
copy, err := deepcopy.TraverseStringsFunc(v, func(v string) (string, error) {
tpl, err := template.New("").Funcs(templateFuncs).Parse(v)
if err != nil {
return v, err
}
var b bytes.Buffer
if err := tpl.Execute(&b, data); err != nil {
return v, err
}
return strings.ReplaceAll(b.String(), "<no value>", ""), nil
})
if err != nil {
r.err = err
return ""
cache.err = err
return v
}
if r.cacheMap == nil {
r.cacheMap = r.Vars.ToCacheMap()
}
var b bytes.Buffer
if extra == nil {
err = templ.Execute(&b, r.cacheMap)
} else {
// Copy the map to avoid modifying the cached map
m := maps.Clone(r.cacheMap)
maps.Copy(m, extra)
err = templ.Execute(&b, m)
}
if err != nil {
r.err = err
return ""
}
return strings.ReplaceAll(b.String(), "<no value>", "")
return copy
}
func (r *Templater) ReplaceSlice(strs []string) []string {
if r.err != nil || len(strs) == 0 {
return nil
}
new := make([]string, len(strs))
for i, str := range strs {
new[i] = r.Replace(str)
}
return new
}
func (r *Templater) ReplaceGlobs(globs []*ast.Glob) []*ast.Glob {
if r.err != nil || len(globs) == 0 {
func ReplaceGlobs(globs []*ast.Glob, cache *Cache) []*ast.Glob {
if cache.err != nil || len(globs) == 0 {
return nil
}
new := make([]*ast.Glob, len(globs))
for i, g := range globs {
new[i] = &ast.Glob{
Glob: r.Replace(g.Glob),
Glob: Replace(g.Glob, cache),
Negate: g.Negate,
}
}
return new
}
func (r *Templater) ReplaceVars(vars *ast.Vars) *ast.Vars {
return r.replaceVars(vars, nil)
func ReplaceVar(v ast.Var, cache *Cache) ast.Var {
return ReplaceVarWithExtra(v, cache, nil)
}
func (r *Templater) ReplaceVarsWithExtra(vars *ast.Vars, extra map[string]any) *ast.Vars {
return r.replaceVars(vars, extra)
func ReplaceVarWithExtra(v ast.Var, cache *Cache, extra map[string]any) ast.Var {
return ast.Var{
Value: ReplaceWithExtra(v.Value, cache, extra),
Sh: ReplaceWithExtra(v.Sh, cache, extra),
Live: v.Live,
Ref: v.Ref,
Dir: v.Dir,
Json: ReplaceWithExtra(v.Json, cache, extra),
Yaml: ReplaceWithExtra(v.Yaml, cache, extra),
}
}
func (r *Templater) replaceVars(vars *ast.Vars, extra map[string]any) *ast.Vars {
if r.err != nil || vars.Len() == 0 {
func ReplaceVars(vars *ast.Vars, cache *Cache) *ast.Vars {
return ReplaceVarsWithExtra(vars, cache, nil)
}
func ReplaceVarsWithExtra(vars *ast.Vars, cache *Cache, extra map[string]any) *ast.Vars {
if cache.err != nil || vars.Len() == 0 {
return nil
}
var newVars ast.Vars
_ = vars.Range(func(k string, v ast.Var) error {
var newVar ast.Var
switch value := v.Value.(type) {
case string:
newVar.Value = r.ReplaceWithExtra(value, extra)
}
newVar.Live = v.Live
newVar.Sh = r.ReplaceWithExtra(v.Sh, extra)
newVar.Ref = v.Ref
newVar.Json = r.ReplaceWithExtra(v.Json, extra)
newVar.Yaml = r.ReplaceWithExtra(v.Yaml, extra)
newVars.Set(k, newVar)
newVars.Set(k, ReplaceVarWithExtra(v, cache, extra))
return nil
})
return &newVars
}
func (r *Templater) Err() error {
return r.err
}