Skip to content

Commit e194d09

Browse files
committed
test: add general helper functions
Signed-off-by: Marvin Drees <[email protected]>
1 parent 1980d8a commit e194d09

File tree

2 files changed

+1029
-0
lines changed

2 files changed

+1029
-0
lines changed
Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
// Copyright 2024 The Update Framework Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License
14+
//
15+
// SPDX-License-Identifier: Apache-2.0
16+
//
17+
18+
package helpers
19+
20+
import (
21+
"crypto/ed25519"
22+
"crypto/rand"
23+
"encoding/json"
24+
"fmt"
25+
"os"
26+
"path/filepath"
27+
"strings"
28+
"testing"
29+
"time"
30+
)
31+
32+
// TestCase represents a generic test case structure for table-driven tests
33+
type TestCase[T any] struct {
34+
Name string
35+
Setup func(t *testing.T) T
36+
Input T
37+
Want T
38+
WantErr bool
39+
ErrorMsg string
40+
Cleanup func(t *testing.T)
41+
}
42+
43+
// TempDirManager manages temporary directories for tests
44+
type TempDirManager struct {
45+
baseTempDir string
46+
tempDirs []string
47+
}
48+
49+
// NewTempDirManager creates a new temporary directory manager
50+
func NewTempDirManager() *TempDirManager {
51+
return &TempDirManager{
52+
baseTempDir: os.TempDir(),
53+
tempDirs: make([]string, 0),
54+
}
55+
}
56+
57+
// CreateTempDir creates a temporary directory and tracks it for cleanup
58+
func (tdm *TempDirManager) CreateTempDir(t *testing.T, pattern string) string {
59+
t.Helper()
60+
tempDir, err := os.MkdirTemp(tdm.baseTempDir, pattern)
61+
if err != nil {
62+
t.Fatalf("failed to create temp dir: %v", err)
63+
}
64+
tdm.tempDirs = append(tdm.tempDirs, tempDir)
65+
return tempDir
66+
}
67+
68+
// Cleanup removes all tracked temporary directories
69+
func (tdm *TempDirManager) Cleanup(t *testing.T) {
70+
t.Helper()
71+
for _, dir := range tdm.tempDirs {
72+
if err := os.RemoveAll(dir); err != nil {
73+
t.Errorf("failed to remove temp dir %s: %v", dir, err)
74+
}
75+
}
76+
tdm.tempDirs = tdm.tempDirs[:0]
77+
}
78+
79+
// WriteTestFile writes content to a file in the given directory
80+
func WriteTestFile(t *testing.T, dir, filename string, content []byte) string {
81+
t.Helper()
82+
filePath := filepath.Join(dir, filename)
83+
if err := os.WriteFile(filePath, content, 0644); err != nil {
84+
t.Fatalf("failed to write test file %s: %v", filePath, err)
85+
}
86+
return filePath
87+
}
88+
89+
// ReadTestFile reads content from a file
90+
func ReadTestFile(t *testing.T, filePath string) []byte {
91+
t.Helper()
92+
content, err := os.ReadFile(filePath)
93+
if err != nil {
94+
t.Fatalf("failed to read test file %s: %v", filePath, err)
95+
}
96+
return content
97+
}
98+
99+
// StripWhitespaces removes all whitespace characters from a byte slice
100+
func StripWhitespaces(data []byte) []byte {
101+
result := make([]byte, 0, len(data))
102+
for _, b := range data {
103+
if b != ' ' && b != '\t' && b != '\n' && b != '\r' {
104+
result = append(result, b)
105+
}
106+
}
107+
return result
108+
}
109+
110+
// CompareJSON compares two JSON byte slices ignoring whitespace differences
111+
func CompareJSON(t *testing.T, got, want []byte) {
112+
t.Helper()
113+
114+
var gotJSON, wantJSON interface{}
115+
116+
if err := json.Unmarshal(got, &gotJSON); err != nil {
117+
t.Fatalf("failed to unmarshal got JSON: %v", err)
118+
}
119+
120+
if err := json.Unmarshal(want, &wantJSON); err != nil {
121+
t.Fatalf("failed to unmarshal want JSON: %v", err)
122+
}
123+
124+
gotBytes, err := json.Marshal(gotJSON)
125+
if err != nil {
126+
t.Fatalf("failed to marshal got JSON: %v", err)
127+
}
128+
129+
wantBytes, err := json.Marshal(wantJSON)
130+
if err != nil {
131+
t.Fatalf("failed to marshal want JSON: %v", err)
132+
}
133+
134+
if string(gotBytes) != string(wantBytes) {
135+
t.Errorf("JSON mismatch:\ngot: %s\nwant: %s", string(gotBytes), string(wantBytes))
136+
}
137+
}
138+
139+
// GenerateTestKeyPair generates a test Ed25519 key pair
140+
func GenerateTestKeyPair(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
141+
t.Helper()
142+
pub, priv, err := ed25519.GenerateKey(rand.Reader)
143+
if err != nil {
144+
t.Fatalf("failed to generate test key pair: %v", err)
145+
}
146+
return pub, priv
147+
}
148+
149+
// ErrorContains checks if error contains expected message
150+
func ErrorContains(t *testing.T, err error, expectedMsg string) {
151+
t.Helper()
152+
if err == nil {
153+
t.Fatalf("expected error containing %q, got nil", expectedMsg)
154+
}
155+
if !strings.Contains(err.Error(), expectedMsg) {
156+
t.Fatalf("expected error containing %q, got %q", expectedMsg, err.Error())
157+
}
158+
}
159+
160+
// NoError fails the test if err is not nil
161+
func NoError(t *testing.T, err error) {
162+
t.Helper()
163+
if err != nil {
164+
t.Fatalf("unexpected error: %v", err)
165+
}
166+
}
167+
168+
// MustMarshal marshals data to JSON or fails the test
169+
func MustMarshal(t *testing.T, v interface{}) []byte {
170+
t.Helper()
171+
data, err := json.Marshal(v)
172+
if err != nil {
173+
t.Fatalf("failed to marshal: %v", err)
174+
}
175+
return data
176+
}
177+
178+
// MustUnmarshal unmarshals JSON data or fails the test
179+
func MustUnmarshal[T any](t *testing.T, data []byte) T {
180+
t.Helper()
181+
var result T
182+
if err := json.Unmarshal(data, &result); err != nil {
183+
t.Fatalf("failed to unmarshal: %v", err)
184+
}
185+
return result
186+
}
187+
188+
// AssertEqual compares two values for equality
189+
func AssertEqual[T comparable](t *testing.T, got, want T, msgAndArgs ...interface{}) {
190+
t.Helper()
191+
if got != want {
192+
msg := fmt.Sprintf("values not equal:\ngot: %v\nwant: %v", got, want)
193+
if len(msgAndArgs) > 0 {
194+
if format, ok := msgAndArgs[0].(string); ok {
195+
msg = fmt.Sprintf(format, msgAndArgs[1:]...) + "\n" + msg
196+
}
197+
}
198+
t.Error(msg)
199+
}
200+
}
201+
202+
// AssertNotEqual compares two values for inequality
203+
func AssertNotEqual[T comparable](t *testing.T, got, want T, msgAndArgs ...interface{}) {
204+
t.Helper()
205+
if got == want {
206+
msg := fmt.Sprintf("values should not be equal: %v", got)
207+
if len(msgAndArgs) > 0 {
208+
if format, ok := msgAndArgs[0].(string); ok {
209+
msg = fmt.Sprintf(format, msgAndArgs[1:]...) + "\n" + msg
210+
}
211+
}
212+
t.Error(msg)
213+
}
214+
}
215+
216+
// RunTableTest runs a table-driven test
217+
func RunTableTest[T any](t *testing.T, tests []TestCase[T], testFunc func(t *testing.T, tc TestCase[T])) {
218+
t.Helper()
219+
for _, tt := range tests {
220+
t.Run(tt.Name, func(t *testing.T) {
221+
defer func() {
222+
if tt.Cleanup != nil {
223+
tt.Cleanup(t)
224+
}
225+
}()
226+
testFunc(t, tt)
227+
})
228+
}
229+
}
230+
231+
// CreateInvalidJSON creates various types of invalid JSON for testing
232+
func CreateInvalidJSON() map[string][]byte {
233+
return map[string][]byte{
234+
"empty": []byte(""),
235+
"invalid_json": []byte("{invalid json}"),
236+
"missing_signed": []byte(`{"signatures": []}`),
237+
"wrong_type": []byte(`{"signed": {"_type": "wrong"}, "signatures": []}`),
238+
"missing_version": []byte(`{"signed": {"_type": "root"}, "signatures": []}`),
239+
"negative_version": []byte(`{"signed": {"_type": "root", "version": -1}, "signatures": []}`),
240+
}
241+
}
242+
243+
// Benchmark helper function
244+
func BenchmarkOperation(b *testing.B, operation func()) {
245+
b.Helper()
246+
b.ResetTimer()
247+
for i := 0; i < b.N; i++ {
248+
operation()
249+
}
250+
}
251+
252+
// CreateTestJSON creates test JSON for different metadata types
253+
func CreateTestRootJSON(t *testing.T) []byte {
254+
t.Helper()
255+
256+
expiry := time.Now().UTC().Add(24 * time.Hour)
257+
258+
root := map[string]interface{}{
259+
"signed": map[string]interface{}{
260+
"_type": "root",
261+
"spec_version": "1.0.31",
262+
"version": 1,
263+
"expires": expiry.Format(time.RFC3339),
264+
"consistent_snapshot": true,
265+
"keys": map[string]interface{}{},
266+
"roles": map[string]interface{}{
267+
"root": map[string]interface{}{
268+
"keyids": []string{},
269+
"threshold": 1,
270+
},
271+
"targets": map[string]interface{}{
272+
"keyids": []string{},
273+
"threshold": 1,
274+
},
275+
"snapshot": map[string]interface{}{
276+
"keyids": []string{},
277+
"threshold": 1,
278+
},
279+
"timestamp": map[string]interface{}{
280+
"keyids": []string{},
281+
"threshold": 1,
282+
},
283+
},
284+
},
285+
"signatures": []interface{}{},
286+
}
287+
288+
data, err := json.Marshal(root)
289+
if err != nil {
290+
t.Fatalf("failed to create test root JSON: %v", err)
291+
}
292+
return data
293+
}
294+
295+
func CreateTestTargetsJSON(t *testing.T) []byte {
296+
t.Helper()
297+
298+
expiry := time.Now().UTC().Add(24 * time.Hour)
299+
300+
targets := map[string]interface{}{
301+
"signed": map[string]interface{}{
302+
"_type": "targets",
303+
"spec_version": "1.0.31",
304+
"version": 1,
305+
"expires": expiry.Format(time.RFC3339),
306+
"targets": map[string]interface{}{},
307+
},
308+
"signatures": []interface{}{},
309+
}
310+
311+
data, err := json.Marshal(targets)
312+
if err != nil {
313+
t.Fatalf("failed to create test targets JSON: %v", err)
314+
}
315+
return data
316+
}
317+
318+
func CreateTestSnapshotJSON(t *testing.T) []byte {
319+
t.Helper()
320+
321+
expiry := time.Now().UTC().Add(24 * time.Hour)
322+
323+
snapshot := map[string]interface{}{
324+
"signed": map[string]interface{}{
325+
"_type": "snapshot",
326+
"spec_version": "1.0.31",
327+
"version": 1,
328+
"expires": expiry.Format(time.RFC3339),
329+
"meta": map[string]interface{}{
330+
"targets.json": map[string]interface{}{
331+
"version": 1,
332+
},
333+
},
334+
},
335+
"signatures": []interface{}{},
336+
}
337+
338+
data, err := json.Marshal(snapshot)
339+
if err != nil {
340+
t.Fatalf("failed to create test snapshot JSON: %v", err)
341+
}
342+
return data
343+
}
344+
345+
func CreateTestTimestampJSON(t *testing.T) []byte {
346+
t.Helper()
347+
348+
expiry := time.Now().UTC().Add(24 * time.Hour)
349+
350+
timestamp := map[string]interface{}{
351+
"signed": map[string]interface{}{
352+
"_type": "timestamp",
353+
"spec_version": "1.0.31",
354+
"version": 1,
355+
"expires": expiry.Format(time.RFC3339),
356+
"meta": map[string]interface{}{
357+
"snapshot.json": map[string]interface{}{
358+
"version": 1,
359+
},
360+
},
361+
},
362+
"signatures": []interface{}{},
363+
}
364+
365+
data, err := json.Marshal(timestamp)
366+
if err != nil {
367+
t.Fatalf("failed to create test timestamp JSON: %v", err)
368+
}
369+
return data
370+
}
371+
372+
// HexBytes is a simple type for testing - avoiding import cycles
373+
type HexBytes []byte
374+
375+
func (h HexBytes) String() string {
376+
return fmt.Sprintf("%x", []byte(h))
377+
}

0 commit comments

Comments
 (0)