Files
five/hbrtl/pgserver/pgserver_test.go
CharlesKWON e83787750a feat(pgserver): SCRAM-SHA-256 authentication (Phase 5.1)
PG14+ clients (libpq, pgx, JDBC) prefer SCRAM over MD5 when offered;
this lands the five-message exchange (SASL / SASLInitialResponse /
SASLContinue / SASLResponse / SASLFinal) so they get their preferred
path. MD5 stays as the universal fallback.

Storage stays plaintext in the in-memory role registry — on each
authentication we generate a fresh salt, use the fixed iteration count,
and derive SaltedPassword on the fly. Same net security as the existing
MD5 path, while matching wire output to RFC 5802 byte for byte.

Critical detail: pgproto3's Backend multiplexes PasswordMessage,
SASLInitialResponse, and SASLResponse onto the same 'p' byte tag.
Without SetAuthType() the decoder picks PasswordMessage and the
handshake fails immediately. Switch state to AuthTypeSASL before
the client-first receive and AuthTypeSASLContinue before the
client-final receive.

Verified:
  * SCRAM math (PBKDF2 / HMAC / proof verify / server signature)
    via pinned unit test
  * Live psql round-trip — correct password accepted, wrong password
    rejected with proper SQLSTATE 28P01
  * All 6 mandatory gates green (go test, SQL 43/43, compat 56/56,
    std.ch 17/17, FRB 7/7, pgserver 11/11)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-22 09:24:34 +09:00

264 lines
8.9 KiB
Go

// Copyright (c) 2026 Charles KWON OhJun (charleskwonohjun@gmail.com)
// All rights reserved.
package pgserver
import (
"bytes"
"encoding/base64"
"strconv"
"strings"
"testing"
"five/hbrt"
)
// TestEncodeText_Numeric pins the text-format encoding of the numeric
// Five values (Int, Long, and Double carrying a decimal count) that end
// up in DataRow messages sent to psql. A regression here surfaces as a
// silently mis-formatted cell that some clients render and others
// reject, which is far easier to catch in a focused unit test than in a
// psql round-trip.
func TestEncodeText_Numeric(t *testing.T) {
	type numericCase struct {
		name string
		v    hbrt.Value
		want []byte
	}
	table := []numericCase{
		{name: "int-positive", v: hbrt.MakeInt(42), want: []byte("42")},
		{name: "int-negative", v: hbrt.MakeInt(-7), want: []byte("-7")},
		{name: "long", v: hbrt.MakeLong(9876543210), want: []byte("9876543210")},
		// MakeDouble takes (value, len, dec); dec=2 must render two
		// fractional digits, i.e. "50000.00" rather than "50000".
		{name: "decimal-2dp", v: hbrt.MakeDouble(50000.0, 10, 2), want: []byte("50000.00")},
		{name: "decimal-fraction", v: hbrt.MakeDouble(42000.5, 10, 2), want: []byte("42000.50")},
	}
	for i := range table {
		c := table[i]
		t.Run(c.name, func(t *testing.T) {
			if out := encodeText(c.v); !bytes.Equal(out, c.want) {
				t.Errorf("encodeText: want %q, got %q", c.want, out)
			}
		})
	}
}
// TestEncodeText_Strings covers plain strings and booleans, plus the
// NULL contract DataRow relies on to tell NULL apart from the empty
// string:
//   - "" must encode as a non-nil, zero-length slice (sent with length 0);
//   - NIL must encode as a nil slice (sent with length -1, the PG NULL marker).
//
// Booleans use PostgreSQL's text form, "t" / "f".
func TestEncodeText_Strings(t *testing.T) {
	if got := encodeText(hbrt.MakeString("hello")); !bytes.Equal(got, []byte("hello")) {
		t.Errorf("string encode: got %q", got)
	}
	// bytes.Equal treats nil and []byte{} as equal, so the nil check must
	// be explicit. The length check catches a non-empty result, which a
	// nil check alone would let through.
	if got := encodeText(hbrt.MakeString("")); got == nil || len(got) != 0 {
		t.Errorf("empty string must encode as []byte{}, not nil (NULL marker); got %q (nil=%t)", got, got == nil)
	}
	if got := encodeText(hbrt.MakeNil()); got != nil {
		t.Errorf("NIL must encode as nil slice (PG NULL marker), got %q", got)
	}
	if got := encodeText(hbrt.MakeBool(true)); !bytes.Equal(got, []byte("t")) {
		t.Errorf("bool true: got %q", got)
	}
	if got := encodeText(hbrt.MakeBool(false)); !bytes.Equal(got, []byte("f")) {
		t.Errorf("bool false: got %q", got)
	}
}
// TestPgTypeFor checks the column-type OID that pgTypeFor picks for a
// sample value:
//   - Int values, including the int32 maximum, map to INT4, so BI tools
//     show them right-aligned with no decimal point;
//   - a Long too large for int32 maps to INT8;
//   - a Double with a decimal count maps to NUMERIC;
//   - strings map to TEXT and booleans to BOOL;
//   - NIL falls back to TEXT when there is no usable sample.
func TestPgTypeFor(t *testing.T) {
	cases := [...]struct {
		v       hbrt.Value
		wantOID uint32
	}{
		{v: hbrt.MakeInt(0), wantOID: oidInt4},
		{v: hbrt.MakeInt(2147483647), wantOID: oidInt4},
		{v: hbrt.MakeLong(9999999999), wantOID: oidInt8},
		{v: hbrt.MakeDouble(1.5, 10, 2), wantOID: oidNumeric},
		{v: hbrt.MakeString("x"), wantOID: oidText},
		{v: hbrt.MakeBool(true), wantOID: oidBool},
		{v: hbrt.MakeNil(), wantOID: oidText}, // fallback when no sample
	}
	for i := range cases {
		// Only the OID is under test; the second result is ignored.
		gotOID, _ := pgTypeFor(cases[i].v)
		if want := cases[i].wantOID; gotOID != want {
			t.Errorf("case %d: want oid %d, got %d", i, want, gotOID)
		}
	}
}
// TestSqlStateFor pins the FiveSql2 error code → SQLSTATE mapping.
// Drivers dispatch on the first two characters (the class code), so the
// values must follow PostgreSQL's SQLSTATE table for libpq-style
// exception handling to work. The expected codes are, in PostgreSQL's
// naming: 42601 syntax_error, 42P01 undefined_table, 42703
// undefined_column, 25P02 in_failed_sql_transaction, and XX000
// internal_error (the catch-all for unmapped codes such as 99).
//
// The cases are kept in an ordered table so failures are reported in a
// stable order.
func TestSqlStateFor(t *testing.T) {
	for _, c := range []struct {
		code int
		want string
	}{
		{1, "42601"},
		{2, "42P01"},
		{3, "42703"},
		{8, "25P02"},
		{99, "XX000"},
	} {
		if got := sqlStateFor(c.code); got != c.want {
			t.Errorf("sqlStateFor(%d) = %q, want %q", c.code, got, c.want)
		}
	}
}
// TestMD5Challenge checks the MD5 password response that md5Challenge
// computes for comparison with what the client sends. PostgreSQL (and
// libpq on the client side) define it as:
//
//	"md5" || md5_hex( md5_hex(password || user) || salt )
//
// md5_hex yields 32 lowercase hex digits, so the result must be "md5"
// followed by exactly 32 characters from [0-9a-f].
//
// No known-answer vector is pinned here. The test asserts the format,
// determinism, and that each input (password, user, salt) actually
// feeds into the hash.
func TestMD5Challenge(t *testing.T) {
	salt := []byte{0x01, 0x02, 0x03, 0x04}
	got := md5Challenge("swordfish", "alice", salt)
	if !strings.HasPrefix(got, "md5") {
		t.Fatalf("md5 challenge missing prefix: %q", got)
	}
	if len(got) != 35 { // "md5" + 32 hex chars
		t.Fatalf("md5 challenge wrong length: %d (%q)", len(got), got)
	}
	// Clients such as libpq send lowercase hex. A byte-wise comparison
	// against their response fails if this side emits uppercase.
	if digest := strings.TrimPrefix(got, "md5"); strings.Trim(digest, "0123456789abcdef") != "" {
		t.Errorf("md5 challenge digest is not lowercase hex: %q", digest)
	}
	// Determinism: the same inputs must hash identically.
	if again := md5Challenge("swordfish", "alice", salt); again != got {
		t.Errorf("non-deterministic: %q vs %q", got, again)
	}
	// Each input must influence the result.
	if bad := md5Challenge("wrong", "alice", salt); bad == got {
		t.Error("password change must change the hash")
	}
	if other := md5Challenge("swordfish", "bob", salt); other == got {
		t.Error("user change must change the hash")
	}
	if resalted := md5Challenge("swordfish", "alice", []byte{0x04, 0x03, 0x02, 0x01}); resalted == got {
		t.Error("salt change must change the hash")
	}
}
// TestRoleRegistry covers the in-memory role table: AddRole registers a
// role, a second AddRole with the same name replaces it, lookupRole
// finds it, and RemoveRole drops it. Connection goroutines call
// lookupRole independently. This test runs the operations sequentially
// and does not by itself prove the registry is safe for concurrent use;
// run the package under -race for that.
func TestRoleRegistry(t *testing.T) {
	// The registry is package-level state. The deferred removal keeps a
	// failed assertion (t.Fatal) or a panic from leaking the role into
	// other tests.
	defer RemoveRole("test_user")
	AddRole("test_user", "p@ss")
	r := lookupRole("test_user")
	if r == nil {
		t.Fatal("AddRole did not register")
	}
	if r.PasswordPlain != "p@ss" {
		t.Errorf("password mismatch: %q", r.PasswordPlain)
	}
	// Replace the existing entry.
	AddRole("test_user", "new")
	r2 := lookupRole("test_user")
	if r2 == nil {
		// Fail cleanly instead of panicking on the field access below.
		t.Fatal("lookupRole lost the role after AddRole replaced it")
	}
	if r2.PasswordPlain != "new" {
		t.Errorf("AddRole did not replace: %q", r2.PasswordPlain)
	}
	RemoveRole("test_user")
	if lookupRole("test_user") != nil {
		t.Error("RemoveRole did not drop")
	}
}
// TestCommandTagFor pins the CommandComplete tag chosen for each SQL
// verb. SELECT/INSERT/UPDATE/DELETE carry a row count that is currently
// always 0 (real counts are planned for Phase 3), so psql prints
// "(0 rows)" instead of failing. Transaction and DDL statements carry
// the bare verb. Leading whitespace and lowercase keywords are accepted.
//
// NOTE(review): PostgreSQL's own INSERT tag has the form "INSERT oid
// rows" (e.g. "INSERT 0 1"), and libpq's PQcmdTuples expects that
// three-word form, so "INSERT 0" may not be read as a row count by some
// clients. Confirm against commandTagFor when row counts land.
func TestCommandTagFor(t *testing.T) {
	want := [][2]string{
		{"SELECT * FROM x", "SELECT 0"},
		{" select 1", "SELECT 0"},
		{"INSERT INTO x VALUES (1)", "INSERT 0"},
		{"UPDATE x SET a=1", "UPDATE 0"},
		{"DELETE FROM x", "DELETE 0"},
		{"BEGIN", "BEGIN"},
		{"COMMIT", "COMMIT"},
		{"CREATE TABLE foo (x INT)", "CREATE"},
	}
	for _, pair := range want {
		sql, tag := pair[0], pair[1]
		if got := commandTagFor(sql); got != tag {
			t.Errorf("commandTagFor(%q) = %q, want %q", sql, got, tag)
		}
	}
	// strconv has no other use in this file yet; keep the import alive
	// until Phase 3 row counts need it.
	_ = strconv.Itoa
}
// TestSCRAMParseClientFirst verifies the gs2-header strip and attribute
// parse for the SCRAM client-first message (RFC 5802 §7):
//
//	client-first-message      = gs2-header client-first-message-bare
//	gs2-header                = cbind-flag "," [authzid] ","   e.g. "n,,"
//	client-first-message-bare = "n=" user "," "r=" nonce
//
// Two shapes are pinned. The first is the textbook form with a user name.
// The second is what libpq actually sends over a connection without
// channel binding: the SCRAM user is left empty ("n=,") because
// PostgreSQL takes the role name from the startup packet.
func TestSCRAMParseClientFirst(t *testing.T) {
	cases := []struct {
		name      string
		msg       string
		wantBare  string
		wantUser  string
		wantNonce string
	}{
		{"named-user", "n,,n=alice,r=ClientNonce123", "n=alice,r=ClientNonce123", "alice", "ClientNonce123"},
		{"libpq-empty-user", "n,,n=,r=ClientNonce123", "n=,r=ClientNonce123", "", "ClientNonce123"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			bare, ok := scramClientFirstBare(tc.msg)
			if !ok {
				t.Fatal("scramClientFirstBare rejected a valid client-first")
			}
			if bare != tc.wantBare {
				t.Errorf("client-first-bare wrong: %q", bare)
			}
			attrs := scramParseAttrs(bare)
			if attrs["n"] != tc.wantUser || attrs["r"] != tc.wantNonce {
				t.Errorf("attr parse wrong: %+v", attrs)
			}
		})
	}
}
// TestSCRAMRoundTrip exercises the full SCRAM-SHA-256 math. The test
// fixes a salt, nonces and the iteration count, and builds server-first
// with scramServerFirst. It then plays the client, computing ClientProof
// the way libpq does; the server helper verifies that proof and emits
// ServerSignature, which the test recomputes independently. Pinning the
// verify path catches PBKDF2/HMAC wiring regressions without needing a
// live psql process.
//
// Both sides here use the same helpers, so a bug they share would cancel
// out. TestSCRAMRFC7677Vector guards against that with published values.
func TestSCRAMRoundTrip(t *testing.T) {
	const password = "swordfish"
	// The client nonce is borrowed from the RFC 7677 example and the
	// server nonce is arbitrary. Neither is checked against published
	// output in this test.
	const clientNonce = "rOprNGfwEbeRWgbNEkqO"
	const serverNonce = "9zphn2KvL3K5dlqJpvBz"
	salt := []byte{0xa6, 0x3d, 0x18, 0x4f, 0x05, 0x12, 0xc1, 0xd7,
		0x88, 0x73, 0xea, 0xb6, 0x91, 0x04, 0x7c, 0x80}
	iter := scramIterations
	clientFirstBare := "n=alice,r=" + clientNonce
	combined := clientNonce + serverNonce
	serverFirst := scramServerFirst(combined, salt, iter)

	// Client-side proof computation (mirror of what libpq does, RFC 5802 §3):
	//   ClientKey   = HMAC(SaltedPassword, "Client Key")
	//   StoredKey   = H(ClientKey)
	//   ClientProof = ClientKey XOR HMAC(StoredKey, AuthMessage)
	saltedPwd := scramSaltedPassword(password, salt, iter)
	clientFinalNoProof := "c=biws,r=" + combined
	authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof)
	clientKey := scramHMAC(saltedPwd, []byte("Client Key"))
	storedKey := scramH(clientKey)
	clientSig := scramHMAC(storedKey, authMsg)
	clientProof := scramXOR(clientKey, clientSig)
	clientProofB64 := base64.StdEncoding.EncodeToString(clientProof)

	// Server-side verify.
	if !scramVerifyClientProof(saltedPwd, authMsg, clientProofB64) {
		t.Fatal("scramVerifyClientProof rejected a correctly-computed proof")
	}

	// Server emits its signature; the client must accept it.
	serverSigB64 := scramServerSignature(saltedPwd, authMsg)
	serverSigDecoded, err := base64.StdEncoding.DecodeString(serverSigB64)
	if err != nil || len(serverSigDecoded) != 32 {
		t.Fatalf("server signature malformed: %q err=%v", serverSigB64, err)
	}
	serverKey := scramHMAC(saltedPwd, []byte("Server Key"))
	expectedSig := scramHMAC(serverKey, authMsg)
	if !bytes.Equal(serverSigDecoded, expectedSig) {
		t.Errorf("server signature mismatch: got %x want %x", serverSigDecoded, expectedSig)
	}

	// A proof computed from the wrong password must be rejected.
	wrongSalted := scramSaltedPassword("wrong", salt, iter)
	wrongClientKey := scramHMAC(wrongSalted, []byte("Client Key"))
	wrongStoredKey := scramH(wrongClientKey)
	wrongSig := scramHMAC(wrongStoredKey, authMsg)
	wrongProof := base64.StdEncoding.EncodeToString(scramXOR(wrongClientKey, wrongSig))
	if scramVerifyClientProof(saltedPwd, authMsg, wrongProof) {
		t.Error("scramVerifyClientProof accepted a wrong-password proof")
	}

	// Malformed proofs come straight from the client and must be rejected
	// cleanly, not accepted and not panicking. The cases are invalid
	// base64, an empty proof, and a validly encoded proof (15 bytes)
	// shorter than a SHA-256 digest.
	for _, bad := range []string{"!!not-base64!!", "", clientProofB64[:20]} {
		if scramVerifyClientProof(saltedPwd, authMsg, bad) {
			t.Errorf("scramVerifyClientProof accepted malformed proof %q", bad)
		}
	}
}

// TestSCRAMRFC7677Vector pins the SCRAM-SHA-256 helpers against the
// published example exchange in RFC 7677 §3 (user "user", password
// "pencil", 4096 iterations). Unlike TestSCRAMRoundTrip, the expected
// server-first message, client proof and server signature come from the
// RFC, not from these helpers, so a bug shared by the server and client
// computations cannot cancel out.
func TestSCRAMRFC7677Vector(t *testing.T) {
	const (
		password           = "pencil"
		clientFirstBare    = "n=user,r=rOprNGfwEbeRWgbNEkqO"
		combinedNonce      = "rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
		saltB64            = "W22ZaJ0SNY7soEsUEjb6gQ=="
		serverFirst        = "r=" + combinedNonce + ",s=" + saltB64 + ",i=4096"
		clientFinalNoProof = "c=biws,r=" + combinedNonce
		wantProofB64       = "dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
		wantServerSigB64   = "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
	)
	salt, err := base64.StdEncoding.DecodeString(saltB64)
	if err != nil {
		t.Fatalf("decoding RFC salt: %v", err)
	}

	// Wire format of server-first must match the RFC byte for byte.
	if got := scramServerFirst(combinedNonce, salt, 4096); got != serverFirst {
		t.Errorf("server-first mismatch:\n got %q\nwant %q", got, serverFirst)
	}

	saltedPwd := scramSaltedPassword(password, salt, 4096)
	authMsg := scramAuthMessage(clientFirstBare, serverFirst, clientFinalNoProof)

	// Client-side derivation must reproduce the RFC's proof.
	clientKey := scramHMAC(saltedPwd, []byte("Client Key"))
	clientSig := scramHMAC(scramH(clientKey), authMsg)
	if got := base64.StdEncoding.EncodeToString(scramXOR(clientKey, clientSig)); got != wantProofB64 {
		t.Errorf("client proof mismatch: got %q want %q", got, wantProofB64)
	}

	// The server must accept the RFC's proof and emit the RFC's signature.
	if !scramVerifyClientProof(saltedPwd, authMsg, wantProofB64) {
		t.Error("scramVerifyClientProof rejected the RFC 7677 client proof")
	}
	if got := scramServerSignature(saltedPwd, authMsg); got != wantServerSigB64 {
		t.Errorf("server signature mismatch: got %q want %q", got, wantServerSigB64)
	}
}