Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 62 additions & 22 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,25 @@ type
conn*: PgConnection
err*: ref CatchableError

TransportCloseStage* = enum
## Which transport resource raised during connection teardown.
## Each member corresponds to one ``closeWait()`` call in ``closeTransport``.
tcsTlsReader ## ``conn.tlsStream.reader.closeWait()`` raised
tcsTlsWriter ## ``conn.tlsStream.writer.closeWait()`` raised
tcsBaseReader ## ``conn.baseReader.closeWait()`` raised
tcsBaseWriter ## ``conn.baseWriter.closeWait()`` raised
tcsTransport ## ``conn.transport.closeWait()`` raised

TraceTransportCloseErrorData* = object
## Data passed to the transport close-error hook. Fired when a chronos
## ``closeWait()`` call raises while ``closeTransport`` is releasing
## connection resources. These errors are otherwise swallowed because
## teardown must release every transport resource regardless of
## individual failures, leaving operators with no signal for half-closed
## TLS sessions, BearSSL ``close_notify`` mismatches, or peer RSTs.
conn*: PgConnection ## Connection whose transport was being closed
stage*: TransportCloseStage ## Which ``closeWait()`` call raised
err*: ref CatchableError ## The exception caught from that call

TraceInsecureAuthData* = object
## Advisory notification that a server-requested auth method is
## considered insecure in the current transport context. Currently fires
Expand Down Expand Up @@ -371,6 +390,12 @@ type
onPoolReleaseEnd*:
proc(ctx: TraceContext, data: TracePoolReleaseEndData) {.gcsafe, raises: [].}
onPoolCloseError*: proc(data: TracePoolCloseErrorData) {.gcsafe, raises: [].}
onTransportCloseError*:
proc(data: TraceTransportCloseErrorData) {.gcsafe, raises: [].}
## Fires when a transport ``closeWait()`` raises during teardown.
## Advisory only — ``closeTransport`` continues releasing the remaining
## resources regardless. Use this to surface half-closed TLS sessions
## or peer RSTs that would otherwise be invisible.
onInsecureAuth*: proc(data: TraceInsecureAuthData) {.gcsafe, raises: [].}
## Fires when an auth method is used over an insecure transport
## (currently: cleartext password without SSL). Advisory only; does
Expand All @@ -381,6 +406,18 @@ type
## Advisory only; does not abort the connection. Use
## `ConnConfig.requireAuth` to enforce.

# The two branches differ only in their pragmas: the chronos variant declares
# an explicit ``raises: [CatchableError]`` list, the fallback variant declares
# only ``gcsafe``.
when hasChronos:
type RowCallback* = proc(row: Row) {.raises: [CatchableError], gcsafe.}
## Callback invoked once per row during `queryEach`. The `Row` is only valid
## inside the callback — its backing buffer is reused for the next row.

else:
type RowCallback* = proc(row: Row) {.gcsafe.}
## Callback invoked once per row during `queryEach`. The `Row` is only valid
## inside the callback — its backing buffer is reused for the next row.

const RecvBufSize = 128 * 1024
  ## Size of the temporary read buffer for recv operations (131072).

# Public API: read-only getters

func pid*(conn: PgConnection): int32 {.inline.} =
Expand Down Expand Up @@ -731,18 +768,6 @@ template makeCopyInCallback*(body: untyped): CopyInCallback =
return fut
r

when hasChronos:
type RowCallback* = proc(row: Row) {.raises: [CatchableError], gcsafe.}
## Callback invoked once per row during `queryEach`. The `Row` is only valid
## inside the callback — its backing buffer is reused for the next row.

else:
type RowCallback* = proc(row: Row) {.gcsafe.}
## Callback invoked once per row during `queryEach`. The `Row` is only valid
## inside the callback — its backing buffer is reused for the next row.

const RecvBufSize = 131072 ## Size of the temporary read buffer for recv operations

proc dispatchNotification*(conn: PgConnection, msg: BackendMessage) {.raises: [].} =
let notif = Notification(
pid: msg.notifPid, channel: msg.notifChannel, payload: msg.notifPayload
Expand Down Expand Up @@ -975,35 +1000,50 @@ proc sendBufMsg*(conn: PgConnection): Future[void] {.async.} =
if conn.sendBuf.len > 0:
await conn.socket.sendRawBytes(conn.sendBuf)

when hasChronos:
  proc fireTransportCloseError(
      conn: PgConnection, stage: TransportCloseStage, err: ref CatchableError
  ) =
    ## Report a close error that ``closeTransport`` caught and suppressed.
    ##
    ## ``closeTransport`` keeps releasing the remaining resources after a
    ## failure, so the error is never raised to a caller; the tracer hook is
    ## the only place it becomes visible.
    ##
    ## The tracer is taken from ``conn.config.tracer`` so that the event is
    ## still delivered when teardown runs before the runtime tracer alias has
    ## been assigned. Does nothing when no tracer is configured or the
    ## ``onTransportCloseError`` hook is unset.
    let tracer = conn.config.tracer
    if tracer == nil or tracer.onTransportCloseError == nil:
      return
    let data = TraceTransportCloseErrorData(conn: conn, stage: stage, err: err)
    tracer.onTransportCloseError(data)

proc closeTransport(conn: PgConnection) {.async.} =
## Close transport resources without sending Terminate.
when hasChronos:
if conn.tlsStream != nil:
try:
await conn.tlsStream.reader.closeWait()
except CatchableError:
discard
except CatchableError as e:
conn.fireTransportCloseError(tcsTlsReader, e)
try:
await conn.tlsStream.writer.closeWait()
except CatchableError:
discard
except CatchableError as e:
conn.fireTransportCloseError(tcsTlsWriter, e)
conn.tlsStream = nil
if conn.baseReader != nil:
try:
await conn.baseReader.closeWait()
except CatchableError:
discard
except CatchableError as e:
conn.fireTransportCloseError(tcsBaseReader, e)
try:
await conn.baseWriter.closeWait()
except CatchableError:
discard
except CatchableError as e:
conn.fireTransportCloseError(tcsBaseWriter, e)
conn.baseReader = nil
conn.baseWriter = nil
if conn.transport != nil:
try:
await conn.transport.closeWait()
except CatchableError:
discard
except CatchableError as e:
conn.fireTransportCloseError(tcsTransport, e)
conn.transport = nil
# Drop the cached reader/writer aliases so isConnected() reports false.
conn.reader = nil
Expand Down
112 changes: 112 additions & 0 deletions tests/test_tracing.nim
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ type
hasConn: bool
errMsg: string

# Snapshot of one ``onTransportCloseError`` event as recorded by the test
# tracer.
TransportCloseErrorRec = object
hasConn: bool # true when the event carried a non-nil connection
stage: TransportCloseStage # stage value reported by the event
errMsg: string # message of the reported error, or "" when the error was nil

InsecureAuthRec = object
hasConn: bool
authMethod: AuthMethod
Expand All @@ -139,6 +144,7 @@ type
poolReleaseStarts: seq[PoolReleaseStartRec]
poolReleaseEnds: seq[PoolReleaseEndRec]
poolCloseErrors: seq[PoolCloseErrorRec]
transportCloseErrors: seq[TransportCloseErrorRec]
insecureAuths: seq[InsecureAuthRec]
deprecatedAuths: seq[DeprecatedAuthRec]

Expand Down Expand Up @@ -272,6 +278,17 @@ proc buildTracer(log: TraceLog): PgTracer =
)
)

tracer.onTransportCloseError = proc(
data: TraceTransportCloseErrorData
) {.gcsafe, raises: [].} =
log.transportCloseErrors.add(
TransportCloseErrorRec(
hasConn: data.conn != nil,
stage: data.stage,
errMsg: (if data.err != nil: data.err.msg else: ""),
)
)

tracer.onInsecureAuth = proc(data: TraceInsecureAuthData) {.gcsafe, raises: [].} =
log.insecureAuths.add(
InsecureAuthRec(
Expand Down Expand Up @@ -758,6 +775,101 @@ suite "Tracing: pool close errors":

waitFor t()

suite "Tracing: transport close errors":
when hasChronos:
test "onTransportCloseError reports swallowed close failures":
  proc t() {.async.} =
    let log = newTraceLog()
    let conn = await connect(tracedConfig(buildTracer(log)))

    # A genuine closeWait() failure is impractical to provoke from a test
    # (the chronos streams swallow most peer-side faults internally), so
    # call the helper that closeTransport routes every suppressed close
    # error through.
    conn.fireTransportCloseError(
      tcsTlsReader, newException(PgError, "simulated tls close failure")
    )

    doAssert log.transportCloseErrors.len == 1
    let entry = log.transportCloseErrors[0]
    doAssert entry.hasConn
    doAssert entry.stage == tcsTlsReader
    doAssert entry.errMsg == "simulated tls close failure"

    await conn.close()

  waitFor t()

test "every stage value round-trips through the hook":
  proc t() {.async.} =
    let log = newTraceLog()
    let tracer = buildTracer(log)
    let conn = await connect(tracedConfig(tracer))

    # Iterate the enum itself instead of a hand-maintained list, so a stage
    # added to TransportCloseStage later is exercised automatically and the
    # test name stays true.
    var fired = 0
    for stage in TransportCloseStage:
      conn.fireTransportCloseError(stage, newException(PgError, "x"))
      inc fired

    doAssert log.transportCloseErrors.len == fired
    # The test tracer appends one record per event, in firing order, and the
    # enum has no explicit ordinals, so record `ord(stage)` belongs to `stage`.
    for stage in TransportCloseStage:
      let entry = log.transportCloseErrors[ord(stage)]
      doAssert entry.stage == stage
      doAssert entry.errMsg == "x"

    await conn.close()

  waitFor t()

test "nil onTransportCloseError hook is a no-op":
  proc t() {.async.} =
    # Case 1: connection built from plainConfig() with no tracer assigned —
    # firing must do nothing.
    let bare = await connect(plainConfig())
    bare.fireTransportCloseError(tcsTransport, newException(PgError, "ignored"))

    # Case 2: a tracer is attached but its onTransportCloseError hook was
    # never set — firing must also do nothing.
    var cfg = plainConfig()
    cfg.tracer = PgTracer()
    let hookless = await connect(cfg)
    hookless.fireTransportCloseError(tcsBaseReader, newException(PgError, "ignored2"))

    await bare.close()
    await hookless.close()

  waitFor t()

test "healthy close() does not fire onTransportCloseError":
  proc t() {.async.} =
    # Open a traced connection and close it straight away; a clean teardown
    # must leave the transport close-error log empty.
    let log = newTraceLog()
    let conn = await connect(tracedConfig(buildTracer(log)))
    await conn.close()

    doAssert log.transportCloseErrors.len == 0

  waitFor t()

test "closeTransport wires every TransportCloseStage to a fire call":
  # Structural guard. Inducing real closeWait() failures is impractical, so
  # the behavioural tests drive fireTransportCloseError directly and cannot
  # catch regressions inside closeTransport itself (forgotten except clause,
  # new stage without a fire call, new closeWait without wiring). This test
  # reads the source at compile time and asserts the invariants mechanically.
  const
    src = staticRead("../async_postgres/pg_connection.nim")
    procHeader = "proc closeTransport("

  # Fail with an explicit message, not a bare index error, if the proc was
  # renamed or moved.
  let parts = src.split(procHeader)
  doAssert parts.len >= 2,
    "could not find `" & procHeader & "` in pg_connection.nim — " &
      "update this structural test to match the new name"
  # Everything from the proc header up to the next column-0 `proc`.
  let body = parts[1].split("\nproc ")[0]

  # Walk the enum itself so a newly added stage is checked automatically;
  # `$` on an enum member yields its identifier (e.g. "tcsTlsReader").
  for stage in TransportCloseStage:
    doAssert ($stage) in body,
      "closeTransport missing reference to " & $stage &
        " — fire call wiring is incomplete"

  let closeWaits = body.count("closeWait()")
  let fires = body.count("fireTransportCloseError(")
  doAssert closeWaits == fires,
    "closeTransport has " & $closeWaits & " closeWait() calls but " & $fires &
      " fireTransportCloseError() calls — each closeWait must be paired with a fire"

suite "Tracing: queryEach":
test "onQueryStart and onQueryEnd are called with rowCount":
proc t() {.async.} =
Expand Down