diff --git a/src/remote/remote.test.ts b/src/remote/remote.test.ts
index b37d0bf..8584ec9 100644
--- a/src/remote/remote.test.ts
+++ b/src/remote/remote.test.ts
@@ -245,6 +245,100 @@ Deno.test({
   },
 });
 
+Deno.test({
+  name: "infers '10k' stats strategy when row count is below threshold",
+  sanitizeOps: false,
+  sanitizeResources: false,
+  fn: async () => {
+    // Create source with very few rows (below 5000 threshold)
+    const [sourceDb, targetDb] = await Promise.all([
+      new PostgreSqlContainer("postgres:17")
+        .withCopyContentToContainer([
+          {
+            content: `
+              create extension pg_stat_statements;
+              create table small_table(id int);
+              insert into small_table select generate_series(1, 100);
+              analyze small_table;
+            `,
+            target: "/docker-entrypoint-initdb.d/init.sql",
+          },
+        ])
+        .withCommand(["-c", "shared_preload_libraries=pg_stat_statements"])
+        .start(),
+      testSpawnTarget(),
+    ]);
+
+    try {
+      const target = Connectable.fromString(targetDb.getConnectionUri());
+      const source = Connectable.fromString(sourceDb.getConnectionUri());
+
+      const remote = new Remote(
+        target,
+        ConnectionManager.forLocalDatabase(),
+      );
+
+      const result = await remote.syncFrom(source);
+      await remote.optimizer.finish;
+
+      assertEquals(
+        result.meta.inferredStatsStrategy,
+        "10k",
+        "Should infer '10k' strategy for small databases",
+      );
+    } finally {
+      await Promise.all([sourceDb.stop(), targetDb.stop()]);
+    }
+  },
+});
+
+Deno.test({
+  name: "infers 'fromSource' stats strategy when row count is above threshold",
+  sanitizeOps: false,
+  sanitizeResources: false,
+  fn: async () => {
+    // Create source with many rows (above 5000 threshold)
+    const [sourceDb, targetDb] = await Promise.all([
+      new PostgreSqlContainer("postgres:17")
+        .withCopyContentToContainer([
+          {
+            content: `
+              create extension pg_stat_statements;
+              create table large_table(id int);
+              insert into large_table select generate_series(1, 10000);
+              analyze large_table;
+            `,
+            target: "/docker-entrypoint-initdb.d/init.sql",
+          },
+        ])
+        .withCommand(["-c", "shared_preload_libraries=pg_stat_statements"])
+        .start(),
+      testSpawnTarget(),
+    ]);
+
+    try {
+      const target = Connectable.fromString(targetDb.getConnectionUri());
+      const source = Connectable.fromString(sourceDb.getConnectionUri());
+
+      const remote = new Remote(
+        target,
+        ConnectionManager.forLocalDatabase(),
+      );
+
+      const result = await remote.syncFrom(source);
+      await remote.optimizer.finish;
+
+      assertEquals(
+        result.meta.inferredStatsStrategy,
+        "fromSource",
+        "Should infer 'fromSource' strategy for large databases",
+      );
+    } finally {
+      await Promise.all([sourceDb.stop(), targetDb.stop()]);
+    }
+  },
+});
+
 Deno.test({
   name: "timescaledb with continuous aggregates sync correctly",
   sanitizeOps: false,
diff --git a/src/remote/remote.ts b/src/remote/remote.ts
index cdb8724..398486e 100644
--- a/src/remote/remote.ts
+++ b/src/remote/remote.ts
@@ -34,6 +34,10 @@ export class Remote extends EventEmitter {
   static readonly optimizingDbName = PgIdentifier.fromString(
     "optimizing_db",
   );
+  /* Threshold that we determine is "too few rows" for Postgres to start using indexes
+   * and not defaulting to table scan.
+   */
+  private static readonly STATS_ROWS_THRESHOLD = 5_000;
 
   private readonly differ = new SchemaDiffer();
   readonly optimizer: QueryOptimizer;
@@ -69,14 +73,18 @@ export class Remote extends EventEmitter {
     source: Connectable,
     statsStrategy: StatisticsStrategy = { type: "pullFromSource" },
   ): Promise<
-    { meta: { version?: string }; schema: RemoteSyncFullSchemaResponse }
+    {
+      meta: { version?: string; inferredStatsStrategy?: InferredStatsStrategy };
+      schema: RemoteSyncFullSchemaResponse;
+    }
   > {
     await this.resetDatabase();
+
+    // First batch: get schema and other info in parallel (needed for stats decision)
     const [
       restoreResult,
       recentQueries,
       fullSchema,
-      pulledStats,
       databaseInfo,
     ] = await Promise
       .allSettled([
@@ -84,7 +92,6 @@ export class Remote extends EventEmitter {
       this.pipeSchema(this.optimizingDbUDRL, source),
       this.getRecentQueries(source),
       this.getFullSchema(source),
-      this.resolveStatistics(source, statsStrategy),
      this.getDatabaseInfo(source),
     ]);
 
@@ -92,6 +99,16 @@ export class Remote extends EventEmitter {
     if (fullSchema.status === "fulfilled") {
       this.differ.put(source, fullSchema.value);
     }
+
+    // Second: resolve stats strategy using table list from schema
+    const tables = fullSchema.status === "fulfilled"
+      ? fullSchema.value.tables
+      : [];
+    const statsResult = await this.resolveStatistics(
+      source,
+      statsStrategy,
+      tables,
+    );
     const pg = this.manager.getOrCreateConnection(
       this.optimizingDbUDRL,
     );
@@ -101,16 +118,11 @@ export class Remote extends EventEmitter {
       queries = recentQueries.value;
     }
 
-    let stats: StatisticsMode | undefined;
-    if (pulledStats.status === "fulfilled") {
-      stats = pulledStats.value;
-    }
-
     await this.onSuccessfulSync(
       pg,
       source,
       queries,
-      stats,
+      statsResult.mode,
     );
 
     return {
@@ -118,6 +130,7 @@ export class Remote extends EventEmitter {
         version: databaseInfo.status === "fulfilled"
           ? databaseInfo.value.serverVersion
           : undefined,
+        inferredStatsStrategy: statsResult.strategy,
       },
       schema: fullSchema.status === "fulfilled"
         ? { type: "ok", value: fullSchema.value }
@@ -176,16 +189,38 @@ export class Remote extends EventEmitter {
     }
   }
 
-  private resolveStatistics(
+  private async resolveStatistics(
     source: Connectable,
     strategy: StatisticsStrategy,
-  ): Promise<StatisticsMode> {
-    switch (strategy.type) {
-      case "static":
-        return Promise.resolve(strategy.stats);
-      case "pullFromSource":
-        return this.dumpSourceStats(source);
+    tables: { schemaName: PgIdentifier; tableName: PgIdentifier }[],
+  ): Promise<StatsResult> {
+    if (strategy.type === "static") {
+      // Static strategy doesn't go through inference
+      return { mode: strategy.stats, strategy: "fromSource" };
+    }
+    return this.decideStatsStrategy(source, tables);
+  }
+
+  private async decideStatsStrategy(
+    source: Connectable,
+    tables: { schemaName: PgIdentifier; tableName: PgIdentifier }[],
+  ): Promise<StatsResult> {
+    const connector = this.sourceManager.getConnectorFor(source);
+    const totalRows = await connector.getTotalRowCount(tables);
+
+    if (totalRows < Remote.STATS_ROWS_THRESHOLD) {
+      log.info(
+        `Total rows (${totalRows}) below threshold, using default 10k stats`,
+        "remote",
+      );
+      return { mode: Statistics.defaultStatsMode, strategy: "10k" };
     }
+
+    log.info(
+      `Total rows (${totalRows}) above threshold, pulling source stats`,
+      "remote",
+    );
+    return { mode: await this.dumpSourceStats(source), strategy: "fromSource" };
   }
 
   private async dumpSourceStats(source: Connectable): Promise<StatisticsMode> {
@@ -245,3 +280,10 @@ export type StatisticsStrategy = {
   type: "static";
   stats: StatisticsMode;
 };
+
+export type InferredStatsStrategy = "10k" | "fromSource";
+
+type StatsResult = {
+  mode: StatisticsMode;
+  strategy: InferredStatsStrategy;
+};
diff --git a/src/sync/pg-connector.ts b/src/sync/pg-connector.ts
index 7f894d4..0e33567 100644
--- a/src/sync/pg-connector.ts
+++ b/src/sync/pg-connector.ts
@@ -12,7 +12,7 @@ import type {
 import { log } from "../log.ts";
 import { shutdownController } from "../shutdown.ts";
 import { withSpan } from "../otel.ts";
-import { Postgres } from "@query-doctor/core";
+import { Postgres, PgIdentifier } from "@query-doctor/core";
 import { SegmentedQueryCache } from "./seen-cache.ts";
 import { FullSchema, FullSchemaColumn } from "./schema_differ.ts";
 import { ExtensionNotInstalledError, PostgresError } from "./errors.ts";
@@ -281,8 +281,8 @@ ORDER BY
     options: DependencyAnalyzerOptions,
   ): Promise {
     const schema = await this.getSchema();
-    const mkKey = (schema: string, table: string, column: string) =>
-      `${schema.toLowerCase()}:${table.toLowerCase()}:${column}`;
+    const mkKey = (schema: PgIdentifier, table: PgIdentifier, column: string) =>
+      `${schema.toString().toLowerCase()}:${table.toString().toLowerCase()}:${column}`;
     const schemaMap = new Map();
     if (schema.tables) {
       for (const table of schema.tables) {
@@ -420,6 +420,26 @@ ORDER BY
     return FullSchema.parse(results.result);
   }
 
+  public async getTotalRowCount(
+    tables: { schemaName: PgIdentifier; tableName: PgIdentifier }[],
+  ): Promise<number> {
+    if (tables.length === 0) return 0;
+
+    const schemaNames = tables.map((t) => t.schemaName.toString());
+    const tableNames = tables.map((t) => t.tableName.toString());
+
+    const results = await this.db.exec<{ total_rows: string }>(
+      `SELECT COALESCE(SUM(c.reltuples), 0)::bigint as total_rows
+       FROM pg_class c
+       JOIN pg_namespace n ON n.oid = c.relnamespace
+       JOIN unnest($1::text[], $2::text[]) AS t(schema_name, table_name)
+         ON n.nspname = t.schema_name AND c.relname = t.table_name
+       WHERE c.relkind IN ('r', 'm')`,
+      [schemaNames, tableNames],
+    );
+    return Number(results[0]?.total_rows ?? 0);
+  }
+
   public async getDatabaseInfo() {
     const results = await this.db.exec<{
       serverVersion: string;
diff --git a/src/sync/schema_differ.ts b/src/sync/schema_differ.ts
index 22f5337..e225c4b 100644
--- a/src/sync/schema_differ.ts
+++ b/src/sync/schema_differ.ts
@@ -1,3 +1,4 @@
+import { PgIdentifier } from "@query-doctor/core";
 import { create } from "jsondiffpatch";
 import { format, type Op } from "jsondiffpatch/formatters/jsonpatch";
 import { z } from "zod";
@@ -98,8 +99,8 @@ export type FullSchemaColumn = z.infer<typeof FullSchemaColumn>;
 export const FullSchemaTable = z.object({
   type: z.literal("table"),
   oid: z.number(),
-  schemaName: z.string(),
-  tableName: z.string(),
+  schemaName: z.string().transform((v) => PgIdentifier.fromString(v)),
+  tableName: z.string().transform((v) => PgIdentifier.fromString(v)),
   tablespace: z.string().optional(),
   partitionKeyDef: z.string().optional(),
   // tables without columns do exist