|
| 1 | +package db_client |
| 2 | + |
| 3 | +import ( |
| 4 | + "os" |
| 5 | + "strings" |
| 6 | + "testing" |
| 7 | + |
| 8 | + "github.com/stretchr/testify/assert" |
| 9 | + "github.com/stretchr/testify/require" |
| 10 | +) |
| 11 | + |
| 12 | +// TestTimestamptzTextFormatImplemented verifies that the timestamptz wire protocol fix is in place. |
| 13 | +// Reference: https://github.com/turbot/steampipe/issues/4450 |
| 14 | +// |
| 15 | +// This test verifies that startQuery uses QueryResultFormatsByOID to request text format |
| 16 | +// for timestamptz columns, ensuring PostgreSQL formats values using the session timezone. |
| 17 | +// |
| 18 | +// Without this fix, pgx uses binary protocol which loses session timezone info, causing |
| 19 | +// timestamptz values to display in the local machine timezone instead of the session timezone. |
| 20 | +func TestTimestamptzTextFormatImplemented(t *testing.T) { |
| 21 | + // Read the db_client_execute.go file to verify the fix is present |
| 22 | + content, err := os.ReadFile("db_client_execute.go") |
| 23 | + require.NoError(t, err, "should be able to read db_client_execute.go") |
| 24 | + |
| 25 | + sourceCode := string(content) |
| 26 | + |
| 27 | + // Verify QueryResultFormatsByOID is used |
| 28 | + assert.Contains(t, sourceCode, "pgx.QueryResultFormatsByOID", |
| 29 | + "QueryResultFormatsByOID must be used to specify format for specific column types") |
| 30 | + |
| 31 | + // Verify TimestamptzOID is referenced |
| 32 | + assert.Contains(t, sourceCode, "pgtype.TimestamptzOID", |
| 33 | + "TimestamptzOID must be specified to request text format for timestamptz columns") |
| 34 | + |
| 35 | + // Verify TextFormatCode is used |
| 36 | + assert.Contains(t, sourceCode, "pgx.TextFormatCode", |
| 37 | + "TextFormatCode must be used to request text format") |
| 38 | + |
| 39 | + // Verify the fix is in startQuery function |
| 40 | + funcStart := strings.Index(sourceCode, "func (c *DbClient) startQuery") |
| 41 | + assert.NotEqual(t, -1, funcStart, "startQuery function must exist") |
| 42 | + |
| 43 | + // Extract just the startQuery function for more precise checking |
| 44 | + funcEnd := strings.Index(sourceCode[funcStart:], "\nfunc ") |
| 45 | + if funcEnd == -1 { |
| 46 | + funcEnd = len(sourceCode) |
| 47 | + } else { |
| 48 | + funcEnd += funcStart |
| 49 | + } |
| 50 | + startQueryFunc := sourceCode[funcStart:funcEnd] |
| 51 | + |
| 52 | + // Verify all three components are in startQuery |
| 53 | + assert.Contains(t, startQueryFunc, "QueryResultFormatsByOID", |
| 54 | + "QueryResultFormatsByOID must be in startQuery function") |
| 55 | + assert.Contains(t, startQueryFunc, "TimestamptzOID", |
| 56 | + "TimestamptzOID must be in startQuery function") |
| 57 | + assert.Contains(t, startQueryFunc, "TextFormatCode", |
| 58 | + "TextFormatCode must be in startQuery function") |
| 59 | + |
| 60 | + // Verify there's a comment explaining the fix |
| 61 | + hasComment := strings.Contains(startQueryFunc, "session timezone") || |
| 62 | + strings.Contains(startQueryFunc, "text format for timestamptz") || |
| 63 | + strings.Contains(startQueryFunc, "Request text format") |
| 64 | + assert.True(t, hasComment, |
| 65 | + "Comment should explain why text format is needed for timestamptz") |
| 66 | + |
| 67 | + // Verify queryArgs are constructed and used |
| 68 | + assert.Contains(t, startQueryFunc, "queryArgs", |
| 69 | + "queryArgs variable must be used to prepend format specification") |
| 70 | + assert.Contains(t, startQueryFunc, "conn.Query(ctx, query, queryArgs...)", |
| 71 | + "conn.Query must use queryArgs instead of args directly") |
| 72 | +} |
| 73 | + |
| 74 | +// TestTimestamptzFormatCorrectness verifies the format specification structure |
| 75 | +func TestTimestamptzFormatCorrectness(t *testing.T) { |
| 76 | + content, err := os.ReadFile("db_client_execute.go") |
| 77 | + require.NoError(t, err, "should be able to read db_client_execute.go") |
| 78 | + |
| 79 | + sourceCode := string(content) |
| 80 | + |
| 81 | + // Verify the QueryResultFormatsByOID is constructed as the first element |
| 82 | + // This is critical - it must be the first argument before actual query parameters |
| 83 | + assert.Contains(t, sourceCode, "queryArgs := make([]any, 0, len(args)+1)", |
| 84 | + "queryArgs must be allocated with capacity for format spec + args") |
| 85 | + |
| 86 | + // Verify format spec is appended first |
| 87 | + lines := strings.Split(sourceCode, "\n") |
| 88 | + var foundMake, foundAppendFormat, foundAppendArgs bool |
| 89 | + var makeIdx, appendFormatIdx, appendArgsIdx int |
| 90 | + |
| 91 | + for i, line := range lines { |
| 92 | + if strings.Contains(line, "queryArgs := make([]any, 0, len(args)+1)") { |
| 93 | + foundMake = true |
| 94 | + makeIdx = i |
| 95 | + } |
| 96 | + if strings.Contains(line, "queryArgs = append(queryArgs, pgx.QueryResultFormatsByOID{") { |
| 97 | + foundAppendFormat = true |
| 98 | + appendFormatIdx = i |
| 99 | + } |
| 100 | + if strings.Contains(line, "queryArgs = append(queryArgs, args...)") { |
| 101 | + foundAppendArgs = true |
| 102 | + appendArgsIdx = i |
| 103 | + } |
| 104 | + } |
| 105 | + |
| 106 | + assert.True(t, foundMake, "queryArgs must be allocated") |
| 107 | + assert.True(t, foundAppendFormat, "format spec must be appended to queryArgs") |
| 108 | + assert.True(t, foundAppendArgs, "original args must be appended to queryArgs") |
| 109 | + |
| 110 | + // Verify correct order: make -> append format spec -> append args |
| 111 | + if foundMake && foundAppendFormat && foundAppendArgs { |
| 112 | + assert.Less(t, makeIdx, appendFormatIdx, |
| 113 | + "queryArgs must be allocated before appending format spec") |
| 114 | + assert.Less(t, appendFormatIdx, appendArgsIdx, |
| 115 | + "format spec must be appended before original args") |
| 116 | + } |
| 117 | +} |
| 118 | + |
| 119 | +// TestTimestamptzFormatDoesNotAffectOtherTypes verifies only timestamptz format is changed |
| 120 | +func TestTimestamptzFormatDoesNotAffectOtherTypes(t *testing.T) { |
| 121 | + content, err := os.ReadFile("db_client_execute.go") |
| 122 | + require.NoError(t, err, "should be able to read db_client_execute.go") |
| 123 | + |
| 124 | + sourceCode := string(content) |
| 125 | + |
| 126 | + // Find the QueryResultFormatsByOID map construction |
| 127 | + funcStart := strings.Index(sourceCode, "func (c *DbClient) startQuery") |
| 128 | + require.NotEqual(t, -1, funcStart, "startQuery function must exist") |
| 129 | + |
| 130 | + funcEnd := strings.Index(sourceCode[funcStart:], "\nfunc ") |
| 131 | + if funcEnd == -1 { |
| 132 | + funcEnd = len(sourceCode) |
| 133 | + } else { |
| 134 | + funcEnd += funcStart |
| 135 | + } |
| 136 | + startQueryFunc := sourceCode[funcStart:funcEnd] |
| 137 | + |
| 138 | + // Verify ONLY TimestamptzOID is in the map (no other OIDs) |
| 139 | + // This ensures we don't accidentally change format for other types |
| 140 | + otherOIDs := []string{ |
| 141 | + "DateOID", |
| 142 | + "TimestampOID", |
| 143 | + "TimeOID", |
| 144 | + "IntervalOID", |
| 145 | + "JSONOID", |
| 146 | + "JSONBOID", |
| 147 | + } |
| 148 | + |
| 149 | + for _, oid := range otherOIDs { |
| 150 | + assert.NotContains(t, startQueryFunc, "pgtype."+oid, |
| 151 | + "Should not change format for "+oid+" - only timestamptz needs text format") |
| 152 | + } |
| 153 | + |
| 154 | + // Verify there's only one entry in QueryResultFormatsByOID |
| 155 | + // Count how many times we see "OID:" in the map definition |
| 156 | + oidCount := strings.Count(startQueryFunc, "OID:") |
| 157 | + assert.Equal(t, 1, oidCount, |
| 158 | + "QueryResultFormatsByOID should have exactly one entry (TimestamptzOID)") |
| 159 | +} |
0 commit comments