GitRoot

craft your forge, build your project, grow your community freely
  1// SPDX-FileCopyrightText: 2025 Romain Maneschi <romain@gitroot.dev>
  2//
  3// SPDX-License-Identifier: EUPL-1.2
  4
  5package lib
  6
  7import (
  8	"bufio"
  9	"bytes"
 10	"fmt"
 11	"io"
 12	"os"
 13	"os/exec"
 14	"path/filepath"
 15	"runtime"
 16	"strconv"
 17	"strings"
 18	"sync"
 19	"time"
 20
 21	"github.com/shirou/gopsutil/v4/process"
 22	pluginLib "gitroot.dev/libs/golang/plugin/model"
 23)
 24
 25const (
 26	defaultLogLimitBytes  = 10 * 1024 * 1024 // 10Mo
 27	defaultLogDir         = "logs"
 28	defaultArtifactsDir   = "artifacts"
 29	defaultJobContextPath = "../jobContext"
 30)
 31
 32func Start(dir string, input pluginLib.Exec, interpretEnvVar bool) (*pluginLib.ExecStatus, error) {
 33	status := &pluginLib.ExecStatus{
 34		CmdsExec:   make([]string, 0),
 35		CmdsStatus: make([]int, 0),
 36		CmdsLogs:   make([]string, 0),
 37		CmdsStats:  make([]pluginLib.CmdStats, 0),
 38		Artifacts:  make([]string, 0),
 39	}
 40
 41	jobStartMs := time.Now().UnixMilli()
 42	if envStart := os.Getenv("START_TIME_MS"); envStart != "" {
 43		if val, err := strconv.ParseInt(envStart, 10, 64); err == nil {
 44			jobStartMs = val
 45		}
 46	}
 47
 48	limit := int64(defaultLogLimitBytes)
 49	if envLimit := os.Getenv("MAX_LOG_SIZE_MB"); envLimit != "" {
 50		if val, err := strconv.ParseInt(envLimit, 10, 64); err == nil {
 51			limit = val * 1024 * 1024
 52		}
 53	}
 54
 55	logPath := filepath.Join(dir, defaultJobContextPath, defaultLogDir)
 56	os.MkdirAll(logPath, os.ModePerm)
 57	artifactsPath := filepath.Join(dir, defaultJobContextPath, defaultArtifactsDir)
 58	os.MkdirAll(artifactsPath, os.ModePerm)
 59
 60	files, _ := os.ReadDir(logPath)
 61	idx := len(files)
 62	for _, cmdDef := range input.Cmds {
 63		idx++
 64
 65		exitCode, realCmd, logFileName, stats := runStep(dir, idx, cmdDef, input.Env, input.ReportStats, logPath, jobStartMs, limit, interpretEnvVar)
 66
 67		status.CmdsExec = append(status.CmdsExec, realCmd)
 68		status.CmdsStatus = append(status.CmdsStatus, exitCode)
 69		status.CmdsLogs = append(status.CmdsLogs, filepath.Join(defaultLogDir, logFileName))
 70		status.CmdsStats = append(status.CmdsStats, stats)
 71	}
 72
 73	for _, artifact := range input.Artifacts {
 74		if err := checkPath(dir, artifact); err != nil {
 75			continue
 76		}
 77
 78		src, _ := os.Open(filepath.Join(dir, artifact))
 79		defer src.Close()
 80
 81		artifactFilePath := filepath.Join(artifactsPath, artifact)
 82		dir, _ := filepath.Split(artifactFilePath)
 83		os.MkdirAll(dir, os.ModePerm)
 84		dst, _ := os.Create(artifactFilePath)
 85		defer dst.Close()
 86
 87		io.Copy(dst, src)
 88		status.Artifacts = append(status.Artifacts, filepath.Join(defaultArtifactsDir, artifact))
 89	}
 90
 91	return status, nil
 92}
 93
 94func runStep(dir string, idx int, step pluginLib.Cmd, globalEnv []string, reportStats bool, logDir string, jobStartMs, limit int64, interpretEnvVar bool) (int, string, string, pluginLib.CmdStats) {
 95	logFileName := fmt.Sprintf("cmd_%d.log", idx)
 96	logPath := filepath.Join(logDir, logFileName)
 97
 98	stats := pluginLib.CmdStats{}
 99
100	f, err := os.Create(logPath)
101	if err != nil {
102		return -1, "", logFileName, stats
103	}
104
105	logger := &LogWriter{
106		file:       f,
107		limit:      limit,
108		jobStartMs: jobStartMs,
109	}
110
111	cmdStr := step.Cmd
112	args := step.Args
113
114	fullCmdLine := cmdStr + " " + strings.Join(args, " ")
115	useShell := strings.ContainsAny(fullCmdLine, "$>|&")
116
117	var cmd *exec.Cmd
118	if useShell && runtime.GOOS == "linux" && interpretEnvVar {
119		shellCmd := fmt.Sprintf("%s %s", cmdStr, strings.Join(args, " "))
120		cmd = exec.Command("sh", "-c", shellCmd)
121	} else {
122		cmd = exec.Command(cmdStr, args...)
123	}
124
125	cmd.Dir = dir
126
127	cmd.Env = append(cmd.Env, globalEnv...)
128
129	stdoutPipe, _ := cmd.StdoutPipe()
130	stderrPipe, _ := cmd.StderrPipe()
131
132	var wg sync.WaitGroup
133	wg.Add(2)
134
135	go func() {
136		defer wg.Done()
137		reader := bufio.NewReader(stdoutPipe)
138		for {
139			line, err := reader.ReadBytes('\n')
140			if len(line) > 0 {
141				cleanLine := bytes.TrimSuffix(line, []byte("\n"))
142				logger.WriteLine("[o]", cleanLine)
143			}
144			if err != nil {
145				break
146			}
147		}
148		logger.Flush()
149	}()
150
151	go func() {
152		defer wg.Done()
153		reader := bufio.NewReader(stderrPipe)
154		for {
155			line, err := reader.ReadBytes('\n')
156			if len(line) > 0 {
157				cleanLine := bytes.TrimSuffix(line, []byte("\n"))
158				logger.WriteLine("[e]", cleanLine)
159			}
160			if err != nil {
161				break
162			}
163		}
164		logger.Flush()
165	}()
166
167	if err := cmd.Start(); err != nil {
168		fmt.Fprintf(f, "[e][+0] Failed to start: %v\n", err)
169		return -1, cmd.String(), logFileName, stats
170	}
171
172	proc, _ := process.NewProcess(int32(cmd.Process.Pid))
173
174	done := make(chan error, 1)
175	go func() {
176		done <- cmd.Wait()
177	}()
178
179	if reportStats && proc != nil {
180		ticker := time.NewTicker(10 * time.Millisecond)
181		defer ticker.Stop()
182
183	Loop:
184		for {
185			select {
186			case err = <-done:
187				captureStats(proc, &stats)
188				break Loop
189			case <-ticker.C:
190				captureStats(proc, &stats)
191			}
192		}
193	} else {
194		err = <-done
195	}
196
197	wg.Wait()
198
199	logger.Flush()
200	f.Sync()
201	f.Close()
202
203	exitCode := 0
204	if err != nil {
205		if exitErr, ok := err.(*exec.ExitError); ok {
206			exitCode = exitErr.ExitCode()
207		} else {
208			exitCode = 1
209		}
210	}
211
212	return exitCode, cmd.String(), logFileName, stats
213}
214
215func captureStats(p *process.Process, s *pluginLib.CmdStats) {
216	if mem, err := p.MemoryInfo(); err == nil && mem.RSS > s.MaxMemoryBytes {
217		s.MaxMemoryBytes = mem.RSS
218	}
219
220	if numThread, err := p.NumThreads(); err != nil && numThread > s.MaxThreads {
221		s.MaxThreads = numThread
222	}
223
224	if io, err := p.IOCounters(); err == nil {
225		s.ReadBytesTotal = io.ReadBytes
226		s.WriteBytesTotal = io.WriteBytes
227	}
228
229	if cpu, err := p.Times(); err == nil {
230		totalSec := cpu.User + cpu.System
231		s.TotalCPUTimeMs = uint64(totalSec * 1000)
232	}
233}
234
// checkPath returns an error when userPath, interpreted relative to baseDir,
// would resolve outside baseDir (for example "..", "../x" or "a/../../x").
//
// The check is purely lexical: symbolic links are not resolved. An absolute
// userPath is joined under baseDir rather than rejected. Names that merely
// start with two dots, such as "..foo", stay inside baseDir and are accepted.
func checkPath(baseDir, userPath string) error {
	absBase, err := filepath.Abs(baseDir)
	if err != nil {
		return err
	}

	rel, err := filepath.Rel(absBase, filepath.Join(absBase, userPath))
	if err != nil {
		return err
	}

	// Compare whole path elements: only ".." itself or a leading "../"
	// escapes baseDir.
	escapes := rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator))
	if escapes || filepath.IsAbs(rel) {
		return fmt.Errorf("not authorised to access %s", rel)
	}

	return nil
}