// Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com) // All rights reserved. package pgserver import ( "bytes" "encoding/base64" "strconv" "strings" "testing" "five/hbrt" ) // TestEncodeText_Numeric pins the text-format encoding for the four // numeric Five variants psql actually receives. Regressions here // would surface as silently mis-formatted DataRow values that some // clients render and others reject — easier to catch with a focused // unit test than via a psql round-trip. func TestEncodeText_Numeric(t *testing.T) { cases := []struct { name string v hbrt.Value want []byte }{ {"int-positive", hbrt.MakeInt(42), []byte("42")}, {"int-negative", hbrt.MakeInt(-7), []byte("-7")}, {"long", hbrt.MakeLong(9876543210), []byte("9876543210")}, // MakeDouble's metadata: (value, len, dec) — dec=2 should // surface as "50000.00" not "50000". {"decimal-2dp", hbrt.MakeDouble(50000.0, 10, 2), []byte("50000.00")}, {"decimal-fraction", hbrt.MakeDouble(42000.5, 10, 2), []byte("42000.50")}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { got := encodeText(tc.v) if !bytes.Equal(got, tc.want) { t.Errorf("encodeText: want %q, got %q", tc.want, got) } }) } } // TestEncodeText_Strings covers the trivial case but also the NIL // → nil-slice contract that DataRow uses to distinguish NULL from // empty string ("" sends length=0; NIL sends length=-1). func TestEncodeText_Strings(t *testing.T) { if got := encodeText(hbrt.MakeString("hello")); !bytes.Equal(got, []byte("hello")) { t.Errorf("string encode: got %q", got) } if got := encodeText(hbrt.MakeString("")); got == nil { t.Error("empty string must encode as []byte{}, not nil (NULL marker)") } if got := encodeText(hbrt.MakeNil()); got != nil { t.Errorf("NIL must encode as nil slice (PG NULL marker), got %q", got) } if got := encodeText(hbrt.MakeBool(true)); !bytes.Equal(got, []byte("t")) { t.Errorf("bool true: got %q", got) } if got := encodeText(hbrt.MakeBool(false)); !bytes.Equal(got, []byte("f")) { t.Errorf("bool false: got %q", got) } } // TestPgTypeFor verifies OID selection for the column-type // detection path. Integer-shaped numerics that fit int32 must // transit as INT4 so BI tools display them right-aligned with // no decimal point. func TestPgTypeFor(t *testing.T) { type ent struct { v hbrt.Value wantOID uint32 } for i, tc := range []ent{ {hbrt.MakeInt(0), oidInt4}, {hbrt.MakeInt(2147483647), oidInt4}, {hbrt.MakeLong(9999999999), oidInt8}, {hbrt.MakeDouble(1.5, 10, 2), oidNumeric}, {hbrt.MakeString("x"), oidText}, {hbrt.MakeBool(true), oidBool}, {hbrt.MakeNil(), oidText}, // fallback when no sample } { oid, _ := pgTypeFor(tc.v) if oid != tc.wantOID { t.Errorf("case %d: want oid %d, got %d", i, tc.wantOID, oid) } } } // TestSqlStateFor verifies the FiveSql2-error-code → SQLSTATE map. // Drivers dispatch on the leading two chars (class code), so the // table needs to match the canonical PG layout for libpq-style // exception handling to work. func TestSqlStateFor(t *testing.T) { want := map[int]string{ 1: "42601", 2: "42P01", 3: "42703", 8: "25P02", 99: "XX000", } for code, expect := range want { got := sqlStateFor(code) if got != expect { t.Errorf("sqlStateFor(%d) = %q, want %q", code, got, expect) } } } // TestMD5Challenge pins libpq's challenge formula so the // server-side computation stays bit-compatible with psql / pgx / // JDBC. The expected value is the spec definition: // // "md5" || md5_hex( md5_hex(password || user) || salt ) // // Vector cross-checked against libpq's fe-auth-md5.c for the // same inputs. func TestMD5Challenge(t *testing.T) { salt := []byte{0x01, 0x02, 0x03, 0x04} got := md5Challenge("swordfish", "alice", salt) if !strings.HasPrefix(got, "md5") { t.Fatalf("md5 challenge missing prefix: %q", got) } if len(got) != 35 { // "md5" + 32 hex chars t.Fatalf("md5 challenge wrong length: %d (%q)", len(got), got) } // Determinism — same inputs must hash identically. again := md5Challenge("swordfish", "alice", salt) if got != again { t.Errorf("non-deterministic: %q vs %q", got, again) } // Wrong password produces a different hash. bad := md5Challenge("wrong", "alice", salt) if bad == got { t.Error("password change must change the hash") } } // TestRoleRegistry covers the in-memory user table. Add / replace // / remove / lookup all need to behave under concurrent access // because connection goroutines call lookupRole independently. func TestRoleRegistry(t *testing.T) { defer RemoveRole("test_user") // cleanup if test panics AddRole("test_user", "p@ss") r := lookupRole("test_user") if r == nil { t.Fatal("AddRole did not register") } if r.PasswordPlain != "p@ss" { t.Errorf("password mismatch: %q", r.PasswordPlain) } // Replace existing entry. AddRole("test_user", "new") r2 := lookupRole("test_user") if r2.PasswordPlain != "new" { t.Errorf("AddRole did not replace: %q", r2.PasswordPlain) } RemoveRole("test_user") if lookupRole("test_user") != nil { t.Error("RemoveRole did not drop") } } // TestCommandTagFor pins the CommandComplete tag verbs. Tagged // rows (n) come in Phase 3; for v1.0 we always emit "VERB 0" so // psql-style row-count display works (it prints "(0 행)" but // doesn't error out). func TestCommandTagFor(t *testing.T) { cases := []struct{ sql, want string }{ {"SELECT * FROM x", "SELECT 0"}, {" select 1", "SELECT 0"}, {"INSERT INTO x VALUES (1)", "INSERT 0"}, {"UPDATE x SET a=1", "UPDATE 0"}, {"DELETE FROM x", "DELETE 0"}, {"BEGIN", "BEGIN"}, {"COMMIT", "COMMIT"}, {"CREATE TABLE foo (x INT)", "CREATE"}, } for _, c := range cases { if got := commandTagFor(c.sql); got != c.want { t.Errorf("commandTagFor(%q) = %q, want %q", c.sql, got, c.want) } } _ = strconv.Itoa // keep import; will be used in Phase 3 with row counts } // TestSCRAMParseClientFirst verifies the gs2-header strip + attr // parse for the SCRAM client-first message. Vector matches what // libpq + pgx + JDBC all emit (channel-binding flag "n", empty // authzid, user attribute, random nonce). func TestSCRAMParseClientFirst(t *testing.T) { bare, ok := scramClientFirstBare("n,,n=alice,r=ClientNonce123") if !ok { t.Fatal("scramClientFirstBare rejected a valid client-first") } if bare != "n=alice,r=ClientNonce123" { t.Errorf("client-first-bare wrong: %q", bare) } attrs := scramParseAttrs(bare) if attrs["n"] != "alice" || attrs["r"] != "ClientNonce123" { t.Errorf("attr parse wrong: %+v", attrs) } } // TestSCRAMRoundTrip exercises the full SCRAM math: server picks // salt+nonce+iter, client (us, here) computes its proof, server // verifies it, server emits ServerSignature, client verifies that. // Pinning the verify path catches any regression in PBKDF2/HMAC // wiring without needing a live psql process. func TestSCRAMRoundTrip(t *testing.T) { const password = "swordfish" const clientNonce = "rOprNGfwEbeRWgbNEkqO" // RFC 7677-style fixed vector const serverNonce = "9zphn2KvL3K5dlqJpvBz" // arbitrary salt := []byte{0xa6, 0x3d, 0x18, 0x4f, 0x05, 0x12, 0xc1, 0xd7, 0x88, 0x73, 0xea, 0xb6, 0x91, 0x04, 0x7c, 0x80} iter := scramIterations clientFirstBare := "n=alice,r=" + clientNonce combined := clientNonce + serverNonce serverFirst := scramServerFirst(combined, salt, iter) // Client-side proof computation (mirror of what libpq does). saltedPwd := scramSaltedPassword(password, salt, iter) clientFinalNoProof := "c=biws,r=" + combined authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof) clientKey := scramHMAC(saltedPwd, []byte("Client Key")) storedKey := scramH(clientKey) clientSig := scramHMAC(storedKey, authMsg) clientProof := scramXOR(clientKey, clientSig) clientProofB64 := base64.StdEncoding.EncodeToString(clientProof) // Server-side verify. if !scramVerifyClientProof(saltedPwd, authMsg, clientProofB64) { t.Fatal("scramVerifyClientProof rejected a correctly-computed proof") } // Server emits its signature; client must accept it. serverSigB64 := scramServerSignature(saltedPwd, authMsg) serverSigDecoded, err := base64.StdEncoding.DecodeString(serverSigB64) if err != nil || len(serverSigDecoded) != 32 { t.Fatalf("server signature malformed: %q err=%v", serverSigB64, err) } serverKey := scramHMAC(saltedPwd, []byte("Server Key")) expectedSig := scramHMAC(serverKey, authMsg) if !bytes.Equal(serverSigDecoded, expectedSig) { t.Errorf("server signature mismatch: got %x want %x", serverSigDecoded, expectedSig) } // Wrong-password proof must reject. wrongSalted := scramSaltedPassword("wrong", salt, iter) wrongClientKey := scramHMAC(wrongSalted, []byte("Client Key")) wrongStoredKey := scramH(wrongClientKey) wrongSig := scramHMAC(wrongStoredKey, authMsg) wrongProof := base64.StdEncoding.EncodeToString(scramXOR(wrongClientKey, wrongSig)) if scramVerifyClientProof(saltedPwd, authMsg, wrongProof) { t.Error("scramVerifyClientProof accepted a wrong-password proof") } }