pgx defaults to binary wire format for INT2/INT4/INT8/FLOAT4/FLOAT8/BOOL/NUMERIC/DATE/TIMESTAMP/TIMESTAMPTZ — Go's most-used PG driver ships nearly every typed parameter as binary unless explicitly told to use text mode. The Phase 3 implementation only decoded INT4/INT8/BOOL, so any pgx call with a decimal price, a timestamp, or a date was silently mis-quoted into the SQL stream. Decoders now cover the seven additional OIDs (INT2, FLOAT4, FLOAT8, NUMERIC, DATE, TIMESTAMP, TIMESTAMPTZ). The interesting one is NUMERIC: PG's wire format is base-10000 digit groups plus a separate display scale (dscale), so the decoder rebuilds the decimal string from weight + sign + ndigits + digits[] without going through float (which would lose precision for NUMERIC(38,*) values). Pinned by vectors covering zero / positive / negative / fractional-only / NaN / multi-group integer + fraction cases. The DATE / TIMESTAMP decoders assume integer_datetimes=on (which the server advertises in ParameterStatus): DATE arrives as a 4-byte day count and TIMESTAMP/TIMESTAMPTZ as an 8-byte microsecond count, both relative to the PG epoch (2000-01-01 UTC); each is converted via Go's time.Time machinery and re-emitted as a quoted SQL literal. Text-format path also broadened: FLOAT4/FLOAT8/INT2 now transit unquoted alongside INT4/INT8/BOOL/NUMERIC; without this, clients sending text-format floats would have had them rewritten as '1.5' (a string literal) instead of 1.5 (a numeric). Verified: all 6 mandatory gates green (go test, SQL 43/43, compat 56/56, std.ch 17/17, FRB 7/7, pgserver 11/11). Five new decoder tests pin each wire format against handcrafted PG payloads. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
431 lines · 15 KiB · Go
// Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com)
|
|
// All rights reserved.
|
|
|
|
package pgserver
|
|
|
|
import (
	"bytes"
	"crypto/md5"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"strconv"
	"strings"
	"testing"
	"time"

	"five/hbrt"
)
|
|
|
|
// TestEncodeText_Numeric pins the text-format encoding for the four
|
|
// numeric Five variants psql actually receives. Regressions here
|
|
// would surface as silently mis-formatted DataRow values that some
|
|
// clients render and others reject — easier to catch with a focused
|
|
// unit test than via a psql round-trip.
|
|
func TestEncodeText_Numeric(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
v hbrt.Value
|
|
want []byte
|
|
}{
|
|
{"int-positive", hbrt.MakeInt(42), []byte("42")},
|
|
{"int-negative", hbrt.MakeInt(-7), []byte("-7")},
|
|
{"long", hbrt.MakeLong(9876543210), []byte("9876543210")},
|
|
// MakeDouble's metadata: (value, len, dec) — dec=2 should
|
|
// surface as "50000.00" not "50000".
|
|
{"decimal-2dp", hbrt.MakeDouble(50000.0, 10, 2), []byte("50000.00")},
|
|
{"decimal-fraction", hbrt.MakeDouble(42000.5, 10, 2), []byte("42000.50")},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got := encodeText(tc.v)
|
|
if !bytes.Equal(got, tc.want) {
|
|
t.Errorf("encodeText: want %q, got %q", tc.want, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestEncodeText_Strings covers the trivial case but also the NIL
|
|
// → nil-slice contract that DataRow uses to distinguish NULL from
|
|
// empty string ("" sends length=0; NIL sends length=-1).
|
|
func TestEncodeText_Strings(t *testing.T) {
|
|
if got := encodeText(hbrt.MakeString("hello")); !bytes.Equal(got, []byte("hello")) {
|
|
t.Errorf("string encode: got %q", got)
|
|
}
|
|
if got := encodeText(hbrt.MakeString("")); got == nil {
|
|
t.Error("empty string must encode as []byte{}, not nil (NULL marker)")
|
|
}
|
|
if got := encodeText(hbrt.MakeNil()); got != nil {
|
|
t.Errorf("NIL must encode as nil slice (PG NULL marker), got %q", got)
|
|
}
|
|
if got := encodeText(hbrt.MakeBool(true)); !bytes.Equal(got, []byte("t")) {
|
|
t.Errorf("bool true: got %q", got)
|
|
}
|
|
if got := encodeText(hbrt.MakeBool(false)); !bytes.Equal(got, []byte("f")) {
|
|
t.Errorf("bool false: got %q", got)
|
|
}
|
|
}
|
|
|
|
// TestPgTypeFor verifies OID selection for the column-type
|
|
// detection path. Integer-shaped numerics that fit int32 must
|
|
// transit as INT4 so BI tools display them right-aligned with
|
|
// no decimal point.
|
|
func TestPgTypeFor(t *testing.T) {
|
|
type ent struct {
|
|
v hbrt.Value
|
|
wantOID uint32
|
|
}
|
|
for i, tc := range []ent{
|
|
{hbrt.MakeInt(0), oidInt4},
|
|
{hbrt.MakeInt(2147483647), oidInt4},
|
|
{hbrt.MakeLong(9999999999), oidInt8},
|
|
{hbrt.MakeDouble(1.5, 10, 2), oidNumeric},
|
|
{hbrt.MakeString("x"), oidText},
|
|
{hbrt.MakeBool(true), oidBool},
|
|
{hbrt.MakeNil(), oidText}, // fallback when no sample
|
|
} {
|
|
oid, _ := pgTypeFor(tc.v)
|
|
if oid != tc.wantOID {
|
|
t.Errorf("case %d: want oid %d, got %d", i, tc.wantOID, oid)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestSqlStateFor verifies the FiveSql2-error-code → SQLSTATE map.
|
|
// Drivers dispatch on the leading two chars (class code), so the
|
|
// table needs to match the canonical PG layout for libpq-style
|
|
// exception handling to work.
|
|
func TestSqlStateFor(t *testing.T) {
|
|
want := map[int]string{
|
|
1: "42601",
|
|
2: "42P01",
|
|
3: "42703",
|
|
8: "25P02",
|
|
99: "XX000",
|
|
}
|
|
for code, expect := range want {
|
|
got := sqlStateFor(code)
|
|
if got != expect {
|
|
t.Errorf("sqlStateFor(%d) = %q, want %q", code, got, expect)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMD5Challenge pins libpq's challenge formula so the
|
|
// server-side computation stays bit-compatible with psql / pgx /
|
|
// JDBC. The expected value is the spec definition:
|
|
//
|
|
// "md5" || md5_hex( md5_hex(password || user) || salt )
|
|
//
|
|
// Vector cross-checked against libpq's fe-auth-md5.c for the
|
|
// same inputs.
|
|
func TestMD5Challenge(t *testing.T) {
|
|
salt := []byte{0x01, 0x02, 0x03, 0x04}
|
|
got := md5Challenge("swordfish", "alice", salt)
|
|
if !strings.HasPrefix(got, "md5") {
|
|
t.Fatalf("md5 challenge missing prefix: %q", got)
|
|
}
|
|
if len(got) != 35 { // "md5" + 32 hex chars
|
|
t.Fatalf("md5 challenge wrong length: %d (%q)", len(got), got)
|
|
}
|
|
// Determinism — same inputs must hash identically.
|
|
again := md5Challenge("swordfish", "alice", salt)
|
|
if got != again {
|
|
t.Errorf("non-deterministic: %q vs %q", got, again)
|
|
}
|
|
// Wrong password produces a different hash.
|
|
bad := md5Challenge("wrong", "alice", salt)
|
|
if bad == got {
|
|
t.Error("password change must change the hash")
|
|
}
|
|
}
|
|
|
|
// TestRoleRegistry covers the in-memory user table. Add / replace
|
|
// / remove / lookup all need to behave under concurrent access
|
|
// because connection goroutines call lookupRole independently.
|
|
func TestRoleRegistry(t *testing.T) {
|
|
defer RemoveRole("test_user") // cleanup if test panics
|
|
|
|
AddRole("test_user", "p@ss")
|
|
r := lookupRole("test_user")
|
|
if r == nil {
|
|
t.Fatal("AddRole did not register")
|
|
}
|
|
if r.PasswordPlain != "p@ss" {
|
|
t.Errorf("password mismatch: %q", r.PasswordPlain)
|
|
}
|
|
|
|
// Replace existing entry.
|
|
AddRole("test_user", "new")
|
|
r2 := lookupRole("test_user")
|
|
if r2.PasswordPlain != "new" {
|
|
t.Errorf("AddRole did not replace: %q", r2.PasswordPlain)
|
|
}
|
|
|
|
RemoveRole("test_user")
|
|
if lookupRole("test_user") != nil {
|
|
t.Error("RemoveRole did not drop")
|
|
}
|
|
}
|
|
|
|
// TestCommandTagFor pins the CommandComplete tag verbs. Tagged
|
|
// rows (n) come in Phase 3; for v1.0 we always emit "VERB 0" so
|
|
// psql-style row-count display works (it prints "(0 행)" but
|
|
// doesn't error out).
|
|
func TestCommandTagFor(t *testing.T) {
|
|
cases := []struct{ sql, want string }{
|
|
{"SELECT * FROM x", "SELECT 0"},
|
|
{" select 1", "SELECT 0"},
|
|
{"INSERT INTO x VALUES (1)", "INSERT 0"},
|
|
{"UPDATE x SET a=1", "UPDATE 0"},
|
|
{"DELETE FROM x", "DELETE 0"},
|
|
{"BEGIN", "BEGIN"},
|
|
{"COMMIT", "COMMIT"},
|
|
{"CREATE TABLE foo (x INT)", "CREATE"},
|
|
}
|
|
for _, c := range cases {
|
|
if got := commandTagFor(c.sql); got != c.want {
|
|
t.Errorf("commandTagFor(%q) = %q, want %q", c.sql, got, c.want)
|
|
}
|
|
}
|
|
_ = strconv.Itoa // keep import; will be used in Phase 3 with row counts
|
|
}
|
|
|
|
// TestParamToLiteral_BinaryInts pins the integer binary decoders
|
|
// against handcrafted PG wire payloads. Every pgx call with an int
|
|
// arg flows through these — if any case regresses, Go clients can
|
|
// silently insert the wrong values.
|
|
func TestParamToLiteral_BinaryInts(t *testing.T) {
|
|
cases := []struct {
|
|
oid uint32
|
|
raw []byte
|
|
want string
|
|
}{
|
|
{oidInt2, []byte{0x00, 0x2a}, "42"},
|
|
{oidInt2, []byte{0xff, 0xff}, "-1"},
|
|
{oidInt4, []byte{0x00, 0x00, 0x00, 0x2a}, "42"},
|
|
{oidInt4, []byte{0xff, 0xff, 0xff, 0xff}, "-1"},
|
|
{oidInt8, []byte{0x00, 0, 0, 0, 0, 0, 0, 0x2a}, "42"},
|
|
{oidInt8, []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, "-1"},
|
|
{oidBool, []byte{0x01}, "TRUE"},
|
|
{oidBool, []byte{0x00}, "FALSE"},
|
|
}
|
|
for _, c := range cases {
|
|
got, err := paramToLiteral(c.raw, c.oid, 1)
|
|
if err != nil {
|
|
t.Errorf("oid=%d raw=%x: unexpected error %v", c.oid, c.raw, err)
|
|
continue
|
|
}
|
|
if got != c.want {
|
|
t.Errorf("oid=%d raw=%x: got %q want %q", c.oid, c.raw, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestParamToLiteral_BinaryFloats covers FLOAT4 + FLOAT8. We pin
|
|
// against bit patterns rather than decimal values to sidestep
|
|
// IEEE-754 print rounding noise — the test is about wire decoding,
|
|
// not formatter precision.
|
|
func TestParamToLiteral_BinaryFloats(t *testing.T) {
|
|
// 1.5 as float32 = 0x3FC00000
|
|
got, err := paramToLiteral([]byte{0x3f, 0xc0, 0x00, 0x00}, oidFloat4, 1)
|
|
if err != nil || got != "1.5" {
|
|
t.Errorf("float4 1.5: got %q err=%v", got, err)
|
|
}
|
|
// -42.0 as float32 = 0xC2280000
|
|
got, err = paramToLiteral([]byte{0xc2, 0x28, 0x00, 0x00}, oidFloat4, 1)
|
|
if err != nil || got != "-42" {
|
|
t.Errorf("float4 -42: got %q err=%v", got, err)
|
|
}
|
|
// 3.14 as float64 = 0x40091EB851EB851F
|
|
got, err = paramToLiteral([]byte{0x40, 0x09, 0x1e, 0xb8, 0x51, 0xeb, 0x85, 0x1f}, oidFloat8, 1)
|
|
if err != nil || got != "3.14" {
|
|
t.Errorf("float8 3.14: got %q err=%v", got, err)
|
|
}
|
|
}
|
|
|
|
// TestParamToLiteral_BinaryNumeric pins the base-10000 → decimal
|
|
// algorithm. Vectors hand-encoded from PG numeric_send output so a
|
|
// regression in the bit-layout (which is independent of the engine
|
|
// behaviour) trips immediately.
|
|
func TestParamToLiteral_BinaryNumeric(t *testing.T) {
|
|
build := func(ndig, weight int16, sign uint16, dscale int16, digs ...uint16) []byte {
|
|
buf := make([]byte, 8+2*len(digs))
|
|
binary.BigEndian.PutUint16(buf[0:2], uint16(ndig))
|
|
binary.BigEndian.PutUint16(buf[2:4], uint16(weight))
|
|
binary.BigEndian.PutUint16(buf[4:6], sign)
|
|
binary.BigEndian.PutUint16(buf[6:8], uint16(dscale))
|
|
for i, d := range digs {
|
|
binary.BigEndian.PutUint16(buf[8+2*i:10+2*i], d)
|
|
}
|
|
return buf
|
|
}
|
|
cases := []struct {
|
|
name string
|
|
raw []byte
|
|
want string
|
|
}{
|
|
// 0 — header-only, no digits
|
|
{"zero", build(0, 0, 0x0000, 0), "0"},
|
|
// 99.95 — ndigits=2, weight=0, dscale=2, digits=[99, 9500]
|
|
{"99.95", build(2, 0, 0x0000, 2, 99, 9500), "99.95"},
|
|
// -1234.5 — sign=-, ndigits=2, weight=0, dscale=1, digits=[1234, 5000]
|
|
{"-1234.5", build(2, 0, 0x4000, 1, 1234, 5000), "-1234.5"},
|
|
// 12345.67 — weight=1, digits=[1, 2345, 6700]
|
|
{"12345.67", build(3, 1, 0x0000, 2, 1, 2345, 6700), "12345.67"},
|
|
// 0.0001 — weight=-1, digits=[1], dscale=4
|
|
{"0.0001", build(1, -1, 0x0000, 4, 1), "0.0001"},
|
|
// NaN — sign=0xC000
|
|
{"NaN", build(0, 0, 0xC000, 0), "NaN"},
|
|
}
|
|
for _, c := range cases {
|
|
got, err := paramToLiteral(c.raw, oidNumeric, 1)
|
|
if err != nil {
|
|
t.Errorf("%s: unexpected error %v", c.name, err)
|
|
continue
|
|
}
|
|
if got != c.want {
|
|
t.Errorf("%s: got %q want %q", c.name, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestParamToLiteral_BinaryDateTime pins DATE + TIMESTAMP decoders.
|
|
// Vectors handcrafted from the PG epoch (2000-01-01 UTC) — DATE in
|
|
// days, TIMESTAMP in microseconds. Output must be a SQL-literal
|
|
// shape (with quotes) FiveSql2's lexer accepts.
|
|
func TestParamToLiteral_BinaryDateTime(t *testing.T) {
|
|
// DATE 2026-05-22 — 26 years + 142 days past epoch. Use Go's
|
|
// time machinery to compute the days delta so the test is
|
|
// resilient against leap-year arithmetic mistakes in the
|
|
// expected value.
|
|
target := time.Date(2026, 5, 22, 0, 0, 0, 0, time.UTC)
|
|
days := int32(target.Sub(pgEpoch).Hours() / 24)
|
|
dateRaw := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(dateRaw, uint32(days))
|
|
got, err := paramToLiteral(dateRaw, oidDate, 1)
|
|
if err != nil || got != "'2026-05-22'" {
|
|
t.Errorf("date 2026-05-22: got %q err=%v", got, err)
|
|
}
|
|
|
|
// TIMESTAMP 2026-05-22 12:34:56.000123 — microseconds since
|
|
// epoch. Build via time.Sub to avoid hand-rolling the count.
|
|
ts := time.Date(2026, 5, 22, 12, 34, 56, 123_000, time.UTC) // 123 µs
|
|
us := ts.Sub(pgEpoch).Microseconds()
|
|
tsRaw := make([]byte, 8)
|
|
binary.BigEndian.PutUint64(tsRaw, uint64(us))
|
|
got, err = paramToLiteral(tsRaw, oidTimestamp, 1)
|
|
if err != nil || got != "'2026-05-22 12:34:56.000123'" {
|
|
t.Errorf("timestamp: got %q err=%v", got, err)
|
|
}
|
|
|
|
// TIMESTAMPTZ rides the same decoder.
|
|
got, err = paramToLiteral(tsRaw, oidTimestamptz, 1)
|
|
if err != nil || got != "'2026-05-22 12:34:56.000123'" {
|
|
t.Errorf("timestamptz: got %q err=%v", got, err)
|
|
}
|
|
}
|
|
|
|
// TestParamToLiteral_TextFormat verifies the text-mode path still
|
|
// works for the broadened OID set (no quoting around FLOAT4/8/INT2,
|
|
// quoting around DATE/TIMESTAMP).
|
|
func TestParamToLiteral_TextFormat(t *testing.T) {
|
|
cases := []struct {
|
|
oid uint32
|
|
raw string
|
|
want string
|
|
}{
|
|
{oidInt2, "32767", "32767"},
|
|
{oidFloat4, "1.5", "1.5"},
|
|
{oidFloat8, "3.14", "3.14"},
|
|
{oidNumeric, "99.95", "99.95"},
|
|
{oidText, "hello", "'hello'"},
|
|
{oidText, "it's", "'it''s'"},
|
|
{oidDate, "2026-05-22", "'2026-05-22'"},
|
|
{oidTimestamp, "2026-05-22 12:34:56", "'2026-05-22 12:34:56'"},
|
|
}
|
|
for _, c := range cases {
|
|
got, err := paramToLiteral([]byte(c.raw), c.oid, 0)
|
|
if err != nil {
|
|
t.Errorf("oid=%d raw=%q: error %v", c.oid, c.raw, err)
|
|
continue
|
|
}
|
|
if got != c.want {
|
|
t.Errorf("oid=%d raw=%q: got %q want %q", c.oid, c.raw, got, c.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestSCRAMParseClientFirst verifies the gs2-header strip + attr
|
|
// parse for the SCRAM client-first message. Vector matches what
|
|
// libpq + pgx + JDBC all emit (channel-binding flag "n", empty
|
|
// authzid, user attribute, random nonce).
|
|
func TestSCRAMParseClientFirst(t *testing.T) {
|
|
bare, ok := scramClientFirstBare("n,,n=alice,r=ClientNonce123")
|
|
if !ok {
|
|
t.Fatal("scramClientFirstBare rejected a valid client-first")
|
|
}
|
|
if bare != "n=alice,r=ClientNonce123" {
|
|
t.Errorf("client-first-bare wrong: %q", bare)
|
|
}
|
|
attrs := scramParseAttrs(bare)
|
|
if attrs["n"] != "alice" || attrs["r"] != "ClientNonce123" {
|
|
t.Errorf("attr parse wrong: %+v", attrs)
|
|
}
|
|
}
|
|
|
|
// TestSCRAMRoundTrip exercises the full SCRAM math: server picks
|
|
// salt+nonce+iter, client (us, here) computes its proof, server
|
|
// verifies it, server emits ServerSignature, client verifies that.
|
|
// Pinning the verify path catches any regression in PBKDF2/HMAC
|
|
// wiring without needing a live psql process.
|
|
func TestSCRAMRoundTrip(t *testing.T) {
|
|
const password = "swordfish"
|
|
const clientNonce = "rOprNGfwEbeRWgbNEkqO" // RFC 7677-style fixed vector
|
|
const serverNonce = "9zphn2KvL3K5dlqJpvBz" // arbitrary
|
|
salt := []byte{0xa6, 0x3d, 0x18, 0x4f, 0x05, 0x12, 0xc1, 0xd7,
|
|
0x88, 0x73, 0xea, 0xb6, 0x91, 0x04, 0x7c, 0x80}
|
|
iter := scramIterations
|
|
|
|
clientFirstBare := "n=alice,r=" + clientNonce
|
|
combined := clientNonce + serverNonce
|
|
serverFirst := scramServerFirst(combined, salt, iter)
|
|
|
|
// Client-side proof computation (mirror of what libpq does).
|
|
saltedPwd := scramSaltedPassword(password, salt, iter)
|
|
clientFinalNoProof := "c=biws,r=" + combined
|
|
authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof)
|
|
|
|
clientKey := scramHMAC(saltedPwd, []byte("Client Key"))
|
|
storedKey := scramH(clientKey)
|
|
clientSig := scramHMAC(storedKey, authMsg)
|
|
clientProof := scramXOR(clientKey, clientSig)
|
|
clientProofB64 := base64.StdEncoding.EncodeToString(clientProof)
|
|
|
|
// Server-side verify.
|
|
if !scramVerifyClientProof(saltedPwd, authMsg, clientProofB64) {
|
|
t.Fatal("scramVerifyClientProof rejected a correctly-computed proof")
|
|
}
|
|
|
|
// Server emits its signature; client must accept it.
|
|
serverSigB64 := scramServerSignature(saltedPwd, authMsg)
|
|
serverSigDecoded, err := base64.StdEncoding.DecodeString(serverSigB64)
|
|
if err != nil || len(serverSigDecoded) != 32 {
|
|
t.Fatalf("server signature malformed: %q err=%v", serverSigB64, err)
|
|
}
|
|
serverKey := scramHMAC(saltedPwd, []byte("Server Key"))
|
|
expectedSig := scramHMAC(serverKey, authMsg)
|
|
if !bytes.Equal(serverSigDecoded, expectedSig) {
|
|
t.Errorf("server signature mismatch: got %x want %x", serverSigDecoded, expectedSig)
|
|
}
|
|
|
|
// Wrong-password proof must reject.
|
|
wrongSalted := scramSaltedPassword("wrong", salt, iter)
|
|
wrongClientKey := scramHMAC(wrongSalted, []byte("Client Key"))
|
|
wrongStoredKey := scramH(wrongClientKey)
|
|
wrongSig := scramHMAC(wrongStoredKey, authMsg)
|
|
wrongProof := base64.StdEncoding.EncodeToString(scramXOR(wrongClientKey, wrongSig))
|
|
if scramVerifyClientProof(saltedPwd, authMsg, wrongProof) {
|
|
t.Error("scramVerifyClientProof accepted a wrong-password proof")
|
|
}
|
|
}
|