feat: support self-signed certificates for remote taskfiles (#2537)

This commit is contained in:
Valentin Maerten
2026-01-25 18:51:30 +01:00
committed by GitHub
parent f6720760b4
commit 026c899d90
16 changed files with 520 additions and 8 deletions

View File

@@ -34,13 +34,14 @@ func NewRootNode(
dir string,
insecure bool,
timeout time.Duration,
opts ...NodeOption,
) (Node, error) {
dir = fsext.DefaultDir(entrypoint, dir)
// If the entrypoint is "-", we read from stdin
if entrypoint == "-" {
return NewStdinNode(dir)
}
return NewNode(entrypoint, dir, insecure)
return NewNode(entrypoint, dir, insecure, opts...)
}
func NewNode(

View File

@@ -10,6 +10,9 @@ type (
parent Node
dir string
checksum string
caCert string
cert string
certKey string
}
)
@@ -54,3 +57,21 @@ func (node *baseNode) Checksum() string {
func (node *baseNode) Verify(checksum string) bool {
return node.checksum == "" || node.checksum == checksum
}
func WithCACert(caCert string) NodeOption {
return func(node *baseNode) {
node.caCert = caCert
}
}
func WithCert(cert string) NodeOption {
return func(node *baseNode) {
node.cert = cert
}
}
func WithCertKey(certKey string) NodeOption {
return func(node *baseNode) {
node.certKey = certKey
}
}

View File

@@ -2,10 +2,13 @@ package taskfile
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
@@ -17,7 +20,54 @@ import (
// An HTTPNode is a node that reads a Taskfile from a remote location via HTTP.
type HTTPNode struct {
*baseNode
url *url.URL // stores url pointing actual remote file. (e.g. with Taskfile.yml)
url *url.URL // stores url pointing actual remote file. (e.g. with Taskfile.yml)
client *http.Client // HTTP client with optional TLS configuration
}
// buildHTTPClient creates an HTTP client with optional TLS configuration.
// If no certificate options are provided, it returns http.DefaultClient.
func buildHTTPClient(insecure bool, caCert, cert, certKey string) (*http.Client, error) {
// Validate that cert and certKey are provided together
if (cert != "" && certKey == "") || (cert == "" && certKey != "") {
return nil, fmt.Errorf("both --cert and --cert-key must be provided together")
}
// If no TLS customization is needed, return the default client
if !insecure && caCert == "" && cert == "" {
return http.DefaultClient, nil
}
tlsConfig := &tls.Config{
InsecureSkipVerify: insecure,
}
// Load custom CA certificate if provided
if caCert != "" {
caCertData, err := os.ReadFile(caCert)
if err != nil {
return nil, fmt.Errorf("failed to read CA certificate: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCertData) {
return nil, fmt.Errorf("failed to parse CA certificate")
}
tlsConfig.RootCAs = caCertPool
}
// Load client certificate and key if provided
if cert != "" && certKey != "" {
clientCert, err := tls.LoadX509KeyPair(cert, certKey)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{clientCert}
}
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
},
}, nil
}
func NewHTTPNode(
@@ -34,9 +84,16 @@ func NewHTTPNode(
if url.Scheme == "http" && !insecure {
return nil, &errors.TaskfileNotSecureError{URI: url.Redacted()}
}
client, err := buildHTTPClient(insecure, base.caCert, base.cert, base.certKey)
if err != nil {
return nil, err
}
return &HTTPNode{
baseNode: base,
url: url,
client: client,
}, nil
}
@@ -49,7 +106,7 @@ func (node *HTTPNode) Read() ([]byte, error) {
}
func (node *HTTPNode) ReadContext(ctx context.Context) ([]byte, error) {
url, err := RemoteExists(ctx, *node.url)
url, err := RemoteExists(ctx, *node.url, node.client)
if err != nil {
return nil, err
}
@@ -58,7 +115,7 @@ func (node *HTTPNode) ReadContext(ctx context.Context) ([]byte, error) {
return nil, errors.TaskfileFetchFailedError{URI: node.Location()}
}
resp, err := http.DefaultClient.Do(req.WithContext(ctx))
resp, err := node.client.Do(req.WithContext(ctx))
if err != nil {
if ctx.Err() != nil {
return nil, err

View File

@@ -1,7 +1,18 @@
package taskfile
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -47,3 +58,227 @@ func TestHTTPNode_CacheKey(t *testing.T) {
assert.Equal(t, tt.expectedKey, key)
}
}
func TestBuildHTTPClient_Default(t *testing.T) {
t.Parallel()
// When no TLS customization is needed, should return http.DefaultClient
client, err := buildHTTPClient(false, "", "", "")
require.NoError(t, err)
assert.Equal(t, http.DefaultClient, client)
}
func TestBuildHTTPClient_Insecure(t *testing.T) {
t.Parallel()
client, err := buildHTTPClient(true, "", "", "")
require.NoError(t, err)
require.NotNil(t, client)
assert.NotEqual(t, http.DefaultClient, client)
// Check that InsecureSkipVerify is set
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, transport.TLSClientConfig)
assert.True(t, transport.TLSClientConfig.InsecureSkipVerify)
}
func TestBuildHTTPClient_CACert(t *testing.T) {
t.Parallel()
// Create a temporary CA cert file
tempDir := t.TempDir()
caCertPath := filepath.Join(tempDir, "ca.crt")
// Generate a valid CA certificate
caCertPEM := generateTestCACert(t)
err := os.WriteFile(caCertPath, caCertPEM, 0o600)
require.NoError(t, err)
client, err := buildHTTPClient(false, caCertPath, "", "")
require.NoError(t, err)
require.NotNil(t, client)
assert.NotEqual(t, http.DefaultClient, client)
// Check that custom RootCAs is set
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, transport.TLSClientConfig)
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
}
func TestBuildHTTPClient_CACertNotFound(t *testing.T) {
t.Parallel()
client, err := buildHTTPClient(false, "/nonexistent/ca.crt", "", "")
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "failed to read CA certificate")
}
func TestBuildHTTPClient_CACertInvalid(t *testing.T) {
t.Parallel()
// Create a temporary file with invalid content
tempDir := t.TempDir()
caCertPath := filepath.Join(tempDir, "invalid.crt")
err := os.WriteFile(caCertPath, []byte("not a valid certificate"), 0o600)
require.NoError(t, err)
client, err := buildHTTPClient(false, caCertPath, "", "")
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "failed to parse CA certificate")
}
func TestBuildHTTPClient_CertWithoutKey(t *testing.T) {
t.Parallel()
client, err := buildHTTPClient(false, "", "/path/to/cert.crt", "")
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "both --cert and --cert-key must be provided together")
}
func TestBuildHTTPClient_KeyWithoutCert(t *testing.T) {
t.Parallel()
client, err := buildHTTPClient(false, "", "", "/path/to/key.pem")
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "both --cert and --cert-key must be provided together")
}
func TestBuildHTTPClient_CertAndKey(t *testing.T) {
t.Parallel()
// Create temporary cert and key files
tempDir := t.TempDir()
certPath := filepath.Join(tempDir, "client.crt")
keyPath := filepath.Join(tempDir, "client.key")
// Generate a self-signed certificate and key for testing
cert, key := generateTestCertAndKey(t)
err := os.WriteFile(certPath, cert, 0o600)
require.NoError(t, err)
err = os.WriteFile(keyPath, key, 0o600)
require.NoError(t, err)
client, err := buildHTTPClient(false, "", certPath, keyPath)
require.NoError(t, err)
require.NotNil(t, client)
assert.NotEqual(t, http.DefaultClient, client)
// Check that client certificate is set
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, transport.TLSClientConfig)
assert.Len(t, transport.TLSClientConfig.Certificates, 1)
}
func TestBuildHTTPClient_CertNotFound(t *testing.T) {
t.Parallel()
client, err := buildHTTPClient(false, "", "/nonexistent/cert.crt", "/nonexistent/key.pem")
assert.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "failed to load client certificate")
}
func TestBuildHTTPClient_InsecureWithCACert(t *testing.T) {
t.Parallel()
// Create a temporary CA cert file
tempDir := t.TempDir()
caCertPath := filepath.Join(tempDir, "ca.crt")
// Generate a valid CA certificate
caCertPEM := generateTestCACert(t)
err := os.WriteFile(caCertPath, caCertPEM, 0o600)
require.NoError(t, err)
// Both insecure and CA cert can be set together
client, err := buildHTTPClient(true, caCertPath, "", "")
require.NoError(t, err)
require.NotNil(t, client)
transport, ok := client.Transport.(*http.Transport)
require.True(t, ok)
require.NotNil(t, transport.TLSClientConfig)
assert.True(t, transport.TLSClientConfig.InsecureSkipVerify)
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
}
// generateTestCertAndKey generates a self-signed certificate and key for testing
func generateTestCertAndKey(t *testing.T) (certPEM, keyPEM []byte) {
t.Helper()
// Generate a new ECDSA private key
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
// Create a certificate template
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Task Org"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}
// Create the certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
// Encode certificate to PEM
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
// Encode private key to PEM
keyDER, err := x509.MarshalECPrivateKey(privateKey)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "EC PRIVATE KEY",
Bytes: keyDER,
})
return certPEM, keyPEM
}
// generateTestCACert generates a self-signed CA certificate for testing
func generateTestCACert(t *testing.T) []byte {
t.Helper()
// Generate a new ECDSA private key
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
// Create a CA certificate template
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test CA"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
IsCA: true,
BasicConstraintsValid: true,
}
// Create the certificate
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
// Encode certificate to PEM
return pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
}

View File

@@ -47,6 +47,9 @@ type (
trustedHosts []string
tempDir string
cacheExpiryDuration time.Duration
caCert string
cert string
certKey string
debugFunc DebugFunc
promptFunc PromptFunc
promptMutex sync.Mutex
@@ -199,6 +202,45 @@ func (o *promptFuncOption) ApplyToReader(r *Reader) {
r.promptFunc = o.promptFunc
}
// WithReaderCACert sets the path to a custom CA certificate for TLS connections.
func WithReaderCACert(caCert string) ReaderOption {
return &readerCACertOption{caCert: caCert}
}
type readerCACertOption struct {
caCert string
}
func (o *readerCACertOption) ApplyToReader(r *Reader) {
r.caCert = o.caCert
}
// WithReaderCert sets the path to a client certificate for TLS connections.
func WithReaderCert(cert string) ReaderOption {
return &readerCertOption{cert: cert}
}
type readerCertOption struct {
cert string
}
func (o *readerCertOption) ApplyToReader(r *Reader) {
r.cert = o.cert
}
// WithReaderCertKey sets the path to a client certificate key for TLS connections.
func WithReaderCertKey(certKey string) ReaderOption {
return &readerCertKeyOption{certKey: certKey}
}
type readerCertKeyOption struct {
certKey string
}
func (o *readerCertKeyOption) ApplyToReader(r *Reader) {
r.certKey = o.certKey
}
// Read will read the Taskfile defined by the [Reader]'s [Node] and recurse
// through any [ast.Includes] it finds, reading each included Taskfile and
// building an [ast.TaskfileGraph] as it goes. If any errors occur, they will be
@@ -314,6 +356,9 @@ func (r *Reader) include(ctx context.Context, node Node) error {
includeNode, err := NewNode(entrypoint, include.Dir, r.insecure,
WithParent(node),
WithChecksum(include.Checksum),
WithCACert(r.caCert),
WithCert(r.cert),
WithCertKey(r.certKey),
)
if err != nil {
if include.Optional {

View File

@@ -38,7 +38,7 @@ var (
// at the given URL with any of the default Taskfile files names. If any of
// these match a file, the first matching path will be returned. If no files are
// found, an error will be returned.
func RemoteExists(ctx context.Context, u url.URL) (*url.URL, error) {
func RemoteExists(ctx context.Context, u url.URL, client *http.Client) (*url.URL, error) {
// Create a new HEAD request for the given URL to check if the resource exists
req, err := http.NewRequestWithContext(ctx, "HEAD", u.String(), nil)
if err != nil {
@@ -46,7 +46,7 @@ func RemoteExists(ctx context.Context, u url.URL) (*url.URL, error) {
}
// Request the given URL
resp, err := http.DefaultClient.Do(req)
resp, err := client.Do(req)
if err != nil {
if ctx.Err() != nil {
return nil, fmt.Errorf("checking remote file: %w", ctx.Err())
@@ -78,7 +78,7 @@ func RemoteExists(ctx context.Context, u url.URL) (*url.URL, error) {
req.URL = alt
// Try the alternative URL
resp, err = http.DefaultClient.Do(req)
resp, err = client.Do(req)
if err != nil {
return nil, errors.TaskfileFetchFailedError{URI: u.Redacted()}
}