diff --git a/src/app/discover/route.ts b/src/app/discover/route.ts index 64af7fb..d513cbe 100644 --- a/src/app/discover/route.ts +++ b/src/app/discover/route.ts @@ -1,6 +1,7 @@ import db from "@/lib/db"; import { serverRegistry } from "@/lib/db/schema"; import { decryptPayload, fingerprintKey } from "@/lib/federation/keytools"; +import { upsertServer } from "@/lib/federation/registry"; import { assertSafeUrl, UrlGuardError } from "@/lib/federation/url-guard"; import createDebug from "debug"; import { desc, eq } from "drizzle-orm"; @@ -86,19 +87,6 @@ export async function GET() { }); } -async function upsertServer(url: string, publicKey: string, encryptionPublicKey: string) { - return await db.insert(serverRegistry).values({ - id: crypto.randomUUID(), - url, - publicKey, - encryptionPublicKey, - lastSeen: new Date(), - createdAt: new Date(), - updatedAt: new Date(), - isHealthy: true, - }).onConflictDoNothing(); -} - async function discoverServer(validated: z.infer) { debug("DISCOVER – looking up server by public key"); const server = await db.select().from(serverRegistry).where(eq(serverRegistry.publicKey, validated.publicKey)); @@ -108,9 +96,7 @@ async function discoverServer(validated: z.infer) { } try { - if (process.env.NODE_ENV !== "development") { - assertSafeUrl(server[0].url); - } + assertSafeUrl(server[0].url); } catch (err) { debug("DISCOVER – stored URL failed SSRF check: %s", server[0].url); if (err instanceof UrlGuardError) { @@ -140,9 +126,7 @@ async function discoverServer(validated: z.infer) { async function registerServer(validated: z.infer) { try { - if (process.env.NODE_ENV !== "development") { - assertSafeUrl(validated.url); - } + await assertSafeUrl(validated.url); } catch (err) { debug("REGISTER – URL failed SSRF check: %s", validated.url); if (err instanceof UrlGuardError) { diff --git a/src/instrumentation.ts b/src/instrumentation.ts new file mode 100644 index 0000000..05a09e2 --- /dev/null +++ b/src/instrumentation.ts @@ -0,0 +1,6 @@ +export async function register() { + if (process.env.NEXT_RUNTIME === 'nodejs') { + const { startFederationWorker } = await import('./lib/bull'); + startFederationWorker(); + } +} diff --git a/src/lib/auth.ts b/src/lib/auth.ts index 33fcf34..2fbe8ed 100644 --- a/src/lib/auth.ts +++ b/src/lib/auth.ts @@ -8,6 +8,7 @@ import db from "./db"; import * as schema from "./db/schema"; import EmailService from "./mail"; import minioClient from "./plugins/server/storage/minio.client"; +import getRedisClient from "./redis"; const isTest = process.env.NODE_ENV === "test"; const emailService: EmailService | undefined = isTest ? undefined : new EmailService(); @@ -52,6 +53,18 @@ const bAuth = betterAuth({ provider: "pg", schema }), + secondaryStorage: { + get: async (key) => { + const value = await getRedisClient().get(key); + return value ? JSON.parse(value) : null; + }, + set: async (key, value, ttl) => { + await getRedisClient().setex(key, ttl ?? 3600 * 24 * 7, JSON.stringify(value)); + }, + delete: async (key) => { + await getRedisClient().del(key); + } + }, hooks: { after: createAuthMiddleware(async (context) => { if (!context.path) return; @@ -74,7 +87,8 @@ const bAuth = betterAuth({ sipherSocial(), federation(), openAPI(), - testUtils() // TODO: Add a conditional plugin for test utils in development + testUtils(), // TODO: Add a conditional plugin for test utils in development + bearer() ], // This is disabled by default, but I'll keep this here for ease of mind. // You never know when companies will change their minds and decide to start tracking you. diff --git a/src/lib/bull/index.ts b/src/lib/bull/index.ts new file mode 100644 index 0000000..8fe6caa --- /dev/null +++ b/src/lib/bull/index.ts @@ -0,0 +1,174 @@ +import db from '@/lib/db'; +import { blacklistedServers, deliveryJobs, follows, serverRegistry } from '@/lib/db/schema'; +import { encryptPayload, getOwnSigningSecretKey, signMessage } from '@/lib/federation/keytools'; +import { discoverAndRegister, DiscoveryError } from '@/lib/federation/registry'; +import { Queue, UnrecoverableError, Worker, type Job } from 'bullmq'; +import createDebug from 'debug'; +import { eq } from 'drizzle-orm'; +import Redis from 'ioredis'; + +const debug = createDebug('app:federation:worker'); + +export interface FederationDeliveryJob { + deliveryJobId: string; + targetUrl: string; + serverUrl: string; + payload: string; +} + +const QUEUE_NAME = 'federation-delivery'; + +function createRedisConnection() { + return new Redis(process.env.REDIS_URL!, { maxRetriesPerRequest: null }); +} + +let _queue: Queue | null = null; + +export function getFederationQueue(): Queue { + if (!_queue) { + _queue = new Queue(QUEUE_NAME, { + connection: createRedisConnection() as never, + defaultJobOptions: { + attempts: 5, + backoff: { + type: 'exponential', + delay: 5_000, + }, + removeOnComplete: { age: 60 * 60 * 24 }, + removeOnFail: { age: 60 * 60 * 24 * 7 }, + }, + }); + } + return _queue; +} + +async function processFederationDelivery(job: Job) { + const { deliveryJobId, targetUrl, serverUrl, payload } = job.data; + debug('processing job %s (%s) → %s (attempt %d)', job.id, job.name, targetUrl, job.attemptsMade + 1); + + const [blacklisted] = await db + .select({ id: blacklistedServers.id }) + .from(blacklistedServers) + .where(eq(blacklistedServers.serverUrl, serverUrl)) + .limit(1); + + if (blacklisted) { + debug('server %s is blacklisted, dropping job %s', serverUrl, job.id); + await db.delete(deliveryJobs).where(eq(deliveryJobs.id, deliveryJobId)); + throw new UnrecoverableError(`Server ${serverUrl} is blacklisted, skipping delivery`); + } + + let encryptionPublicKey: string; + + const [server] = await db + .select({ encryptionPublicKey: serverRegistry.encryptionPublicKey }) + .from(serverRegistry) + .where(eq(serverRegistry.url, serverUrl)) + .limit(1); + + if (server) { + encryptionPublicKey = server.encryptionPublicKey; + } else { + debug('server %s not in registry, attempting auto-discovery', serverUrl); + try { + encryptionPublicKey = await discoverAndRegister(serverUrl); + } catch (err) { + if (err instanceof DiscoveryError) { + debug('auto-discovery of %s failed: %s', serverUrl, err.message); + throw new Error(`Auto-discovery of ${serverUrl} failed: ${err.message}`); + } + throw err; + } + } + + debug('encrypting payload for %s (key: %s…)', serverUrl, encryptionPublicKey.slice(0, 8)); + const recipientKey = new Uint8Array(Buffer.from(encryptionPublicKey, 'base64')); + const encrypted = encryptPayload(payload, recipientKey); + + await db.update(deliveryJobs).set({ + lastAttemptedAt: new Date(), + attempts: job.attemptsMade + 1, + }).where(eq(deliveryJobs.id, deliveryJobId)); + + debug('sending encrypted payload to %s', targetUrl); + + const method = JSON.parse(payload).method; + if (!method || !["FEDERATE", "INSERT", "UNFOLLOW"].includes(method)) { + debug('invalid method: %s, dropping job %s', method, job.id); + await db.delete(deliveryJobs).where(eq(deliveryJobs.id, deliveryJobId)); + debug('job %s dropped because of invalid method', job.id); + throw new UnrecoverableError(`Invalid method: ${method}, dropping job ${job.id}`); + } + + const signature = signMessage(payload, getOwnSigningSecretKey()); + + const response = await fetch(targetUrl, { + method: 'POST', + headers: { 'Content-Type': 'application/json', 'Origin': process.env.BETTER_AUTH_URL! }, + body: JSON.stringify({ method, payload: encrypted, signature }), + signal: AbortSignal.timeout(15_000), + }); + + if (!response.ok) { + debug('delivery to %s failed with status %d', targetUrl, response.status); + throw new Error(`Federation delivery to ${targetUrl} failed: ${response.status}`); + } + + const responseBody = await response.json(); + + if (responseBody.status !== "acknowledged") { + debug('delivery to %s not acknowledged', targetUrl); + throw new UnrecoverableError(`Federation delivery to ${targetUrl} failed: ${response.status} - ${JSON.stringify(responseBody)}`); + } + + if (job.name === 'deliver-follow') { + const followId = JSON.parse(payload).following?.id; + if (followId && typeof responseBody.accepted === "boolean") { + await db.update(follows).set({ accepted: responseBody.accepted }) + .where(eq(follows.id, followId)); + debug('updated follow %s accepted=%s', followId, responseBody.accepted); + } + } + + debug('job %s delivered successfully to %s', job.id, targetUrl); +} + +export function startFederationWorker() { + createDebug.enable(process.env.DEBUG || ''); + console.log('[federation] Starting worker...'); + + const worker = new Worker( + QUEUE_NAME, + processFederationDelivery, + { + connection: createRedisConnection() as never, + concurrency: 10, + }, + ); + + worker.on('ready', () => { + console.log('[federation] Worker connected to Redis and ready'); + }); + + worker.on('failed', (job, err) => { + const retriesLeft = (job?.opts.attempts ?? 0) - (job?.attemptsMade ?? 0); + debug('job %s (%s) to %s failed (attempt %d, %d retries left): %s', job?.id, job?.name, job?.data.targetUrl, job?.attemptsMade, retriesLeft, err.message); + if (err.cause) debug('cause: %O', err.cause); + }); + + worker.on('completed', async (job) => { + debug('job %s (%s) completed, cleaning up delivery record %s', job.id, job.name, job.data.deliveryJobId); + try { + await db.delete(deliveryJobs).where(eq(deliveryJobs.id, job.data.deliveryJobId)); + } catch (err) { + debug('failed to clean up delivery job %s: %O', job.data.deliveryJobId, err); + } + }); + + worker.on('error', (err) => { + console.error('[federation] Worker error:', err); + }); + + debug('worker started'); + return worker; +} diff --git a/src/lib/db/schema/index.ts b/src/lib/db/schema/index.ts index 19de971..14e4084 100644 --- a/src/lib/db/schema/index.ts +++ b/src/lib/db/schema/index.ts @@ -1,289 +1,274 @@ import { relations } from "drizzle-orm"; import { - boolean, - index, - integer, - jsonb, - pgTable, - text, - timestamp, - uniqueIndex, + pgTable, + text, + timestamp, + boolean, + integer, + jsonb, + index, + uniqueIndex, } from "drizzle-orm/pg-core"; export const user = pgTable("user", { - id: text("id").primaryKey(), - name: text("name").notNull(), - email: text("email").notNull().unique(), - emailVerified: boolean("email_verified").default(false).notNull(), - image: text("image"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") - .defaultNow() - .$onUpdate(() => /* @__PURE__ */ new Date()) - .notNull(), - username: text("username").unique(), - displayUsername: text("display_username"), - twoFactorEnabled: boolean("two_factor_enabled").default(false), - isPrivate: boolean("is_private").default(false), + id: text("id").primaryKey(), + name: text("name").notNull(), + email: text("email").notNull().unique(), + emailVerified: boolean("email_verified").default(false).notNull(), + image: text("image"), + createdAt: timestamp("created_at").defaultNow().notNull(), + updatedAt: timestamp("updated_at") + .defaultNow() + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + username: text("username").unique(), + displayUsername: text("display_username"), + twoFactorEnabled: boolean("two_factor_enabled").default(false), + isPrivate: boolean("is_private").default(false), }); -export const session = pgTable( - "session", - { - id: text("id").primaryKey(), - expiresAt: timestamp("expires_at").notNull(), - token: text("token").notNull().unique(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") - .$onUpdate(() => /* @__PURE__ */ new Date()) - .notNull(), - ipAddress: text("ip_address"), - userAgent: text("user_agent"), - userId: text("user_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - }, - (table) => [index("session_userId_idx").on(table.userId)], -); - export const account = pgTable( - "account", - { - id: text("id").primaryKey(), - accountId: text("account_id").notNull(), - providerId: text("provider_id").notNull(), - userId: text("user_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - accessToken: text("access_token"), - refreshToken: text("refresh_token"), - idToken: text("id_token"), - accessTokenExpiresAt: timestamp("access_token_expires_at"), - refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), - scope: text("scope"), - password: text("password"), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") - .$onUpdate(() => /* @__PURE__ */ new Date()) - .notNull(), - }, - (table) => [index("account_userId_idx").on(table.userId)], -); - -export const verification = pgTable( - "verification", - { - id: text("id").primaryKey(), - identifier: text("identifier").notNull(), - value: text("value").notNull(), - expiresAt: timestamp("expires_at").notNull(), - createdAt: timestamp("created_at").defaultNow().notNull(), - updatedAt: timestamp("updated_at") - .defaultNow() - .$onUpdate(() => /* @__PURE__ */ new Date()) - .notNull(), - }, - (table) => [index("verification_identifier_idx").on(table.identifier)], + "account", + { + id: text("id").primaryKey(), + accountId: text("account_id").notNull(), + providerId: text("provider_id").notNull(), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + accessToken: text("access_token"), + refreshToken: text("refresh_token"), + idToken: text("id_token"), + accessTokenExpiresAt: timestamp("access_token_expires_at"), + refreshTokenExpiresAt: timestamp("refresh_token_expires_at"), + scope: text("scope"), + password: text("password"), + createdAt: timestamp("created_at").defaultNow().notNull(), + updatedAt: timestamp("updated_at") + .$onUpdate(() => /* @__PURE__ */ new Date()) + .notNull(), + }, + (table) => [index("account_userId_idx").on(table.userId)], ); export const twoFactor = pgTable( - "two_factor", - { - id: text("id").primaryKey(), - secret: text("secret").notNull(), - backupCodes: text("backup_codes").notNull(), - userId: text("user_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - }, - (table) => [ - index("twoFactor_secret_idx").on(table.secret), - index("twoFactor_userId_idx").on(table.userId), - ], + "two_factor", + { + id: text("id").primaryKey(), + secret: text("secret").notNull(), + backupCodes: text("backup_codes").notNull(), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + }, + (table) => [ + index("twoFactor_secret_idx").on(table.secret), + index("twoFactor_userId_idx").on(table.userId), + ], ); -export const posts = pgTable("posts", { - id: text("id").primaryKey(), - content: jsonb("content").notNull(), - authorId: text("author_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - published: timestamp("published").notNull(), - isLocal: boolean("is_local").default(false).notNull(), - isPrivate: boolean("is_private").default(false), - createdAt: timestamp("created_at").notNull(), -}); +export const posts = pgTable( + "posts", + { + id: text("id").primaryKey(), + content: jsonb("content").notNull(), + authorId: text("author_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + published: timestamp("published").notNull(), + isLocal: boolean("is_local").default(false).notNull(), + isPrivate: boolean("is_private").default(false), + createdAt: timestamp("created_at").notNull(), + federationUrl: text("federation_url"), + }, + (table) => [index("posts_federationUrl_idx").on(table.federationUrl)], +); -export const follows = pgTable("follows", { - id: text("id").primaryKey(), - followerId: text("follower_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - followingId: text("following_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - accepted: boolean("accepted").default(false).notNull(), - createdAt: timestamp("created_at").notNull(), -}); +export const follows = pgTable( + "follows", + { + id: text("id").primaryKey(), + followerId: text("follower_id").notNull(), + followingId: text("following_id").notNull(), + accepted: boolean("accepted").default(false).notNull(), + createdAt: timestamp("created_at").notNull(), + followerServerUrl: text("follower_server_url").references( + () => serverRegistry.url, + { onDelete: "cascade" }, + ), + followingServerUrl: text("following_server_url").references( + () => serverRegistry.url, + { onDelete: "cascade" }, + ), + }, + (table) => [ + index("follows_followerServerUrl_idx").on(table.followerServerUrl), + index("follows_followingServerUrl_idx").on(table.followingServerUrl), + ], +); export const deliveryJobs = pgTable("delivery_jobs", { - id: text("id").primaryKey(), - targetUrl: text("target_url").notNull(), - payload: text("payload").notNull(), - attempts: integer("attempts").default(0).notNull(), - lastAttemptedAt: timestamp("last_attempted_at"), - nextAttemptAt: timestamp("next_attempt_at"), - createdAt: timestamp("created_at").notNull(), + id: text("id").primaryKey(), + targetUrl: text("target_url").notNull(), + payload: text("payload").notNull(), + attempts: integer("attempts").default(0).notNull(), + lastAttemptedAt: timestamp("last_attempted_at"), + nextAttemptAt: timestamp("next_attempt_at"), + createdAt: timestamp("created_at").notNull(), }); export const mutes = pgTable("mutes", { - id: text("id").primaryKey(), - userId: text("user_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - mutedUserId: text("muted_user_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - createdAt: timestamp("created_at").notNull(), + id: text("id").primaryKey(), + userId: text("user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + mutedUserId: text("muted_user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + createdAt: timestamp("created_at").notNull(), }); export const blocks = pgTable("blocks", { - id: text("id").primaryKey(), - blockerId: text("blocker_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - blockedUserId: text("blocked_user_id") - .notNull() - .references(() => user.id, { onDelete: "cascade" }), - createdAt: timestamp("created_at").notNull(), + id: text("id").primaryKey(), + blockerId: text("blocker_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + blockedUserId: text("blocked_user_id") + .notNull() + .references(() => user.id, { onDelete: "cascade" }), + createdAt: timestamp("created_at").notNull(), }); export const serverRegistry = pgTable( - "server_registry", - { - id: text("id").primaryKey(), - url: text("url").notNull().unique(), - publicKey: text("public_key").notNull().unique(), - encryptionPublicKey: text("encryption_public_key").notNull().unique(), - lastSeen: timestamp("last_seen").notNull(), - createdAt: timestamp("created_at").notNull(), - updatedAt: timestamp("updated_at").notNull(), - isHealthy: boolean("is_healthy").notNull(), - }, - (table) => [ - uniqueIndex("serverRegistry_publicKey_uidx").on(table.publicKey), - uniqueIndex("serverRegistry_encryptionPublicKey_uidx").on( - table.encryptionPublicKey, - ), - index("serverRegistry_lastSeen_idx").on(table.lastSeen), - ], + "server_registry", + { + id: text("id").primaryKey(), + url: text("url").notNull().unique(), + publicKey: text("public_key").notNull().unique(), + encryptionPublicKey: text("encryption_public_key").notNull().unique(), + lastSeen: timestamp("last_seen").notNull(), + createdAt: timestamp("created_at").notNull(), + updatedAt: timestamp("updated_at").notNull(), + isHealthy: boolean("is_healthy").notNull(), + }, + (table) => [ + uniqueIndex("serverRegistry_publicKey_uidx").on(table.publicKey), + uniqueIndex("serverRegistry_encryptionPublicKey_uidx").on( + table.encryptionPublicKey, + ), + index("serverRegistry_lastSeen_idx").on(table.lastSeen), + ], ); export const rotateChallengeTokens = pgTable( - "rotate_challenge_tokens", - { - id: text("id").primaryKey(), - signingOldToken: text("signing_old_token").notNull(), - signingNewToken: text("signing_new_token").notNull(), - encryptionOldToken: text("encryption_old_token").notNull(), - encryptionNewToken: text("encryption_new_token").notNull(), - newSigningPublicKey: text("new_signing_public_key").notNull(), - newEncryptionPublicKey: text("new_encryption_public_key").notNull(), - serverUrl: text("server_url").notNull(), - createdAt: timestamp("created_at").notNull(), - attemptsLeft: integer("attempts_left").default(3).notNull(), - expiresAt: timestamp("expires_at").notNull(), - }, - (table) => [index("rotateChallengeTokens_serverUrl_idx").on(table.serverUrl)], + "rotate_challenge_tokens", + { + id: text("id").primaryKey(), + signingOldToken: text("signing_old_token").notNull(), + signingNewToken: text("signing_new_token").notNull(), + encryptionOldToken: text("encryption_old_token").notNull(), + encryptionNewToken: text("encryption_new_token").notNull(), + newSigningPublicKey: text("new_signing_public_key").notNull(), + newEncryptionPublicKey: text("new_encryption_public_key").notNull(), + serverUrl: text("server_url").notNull(), + createdAt: timestamp("created_at").notNull(), + attemptsLeft: integer("attempts_left").default(3).notNull(), + expiresAt: timestamp("expires_at").notNull(), + }, + (table) => [index("rotateChallengeTokens_serverUrl_idx").on(table.serverUrl)], ); export const blacklistedServers = pgTable( - "blacklisted_servers", - { - id: text("id").primaryKey(), - serverUrl: text("server_url").notNull(), - createdAt: timestamp("created_at").notNull(), - reason: text("reason").notNull(), - }, - (table) => [index("blacklistedServers_serverUrl_idx").on(table.serverUrl)], + "blacklisted_servers", + { + id: text("id").primaryKey(), + serverUrl: text("server_url").notNull(), + createdAt: timestamp("created_at").notNull(), + reason: text("reason").notNull(), + }, + (table) => [index("blacklistedServers_serverUrl_idx").on(table.serverUrl)], ); export const userRelations = relations(user, ({ many }) => ({ - sessions: many(session), - accounts: many(account), - twoFactors: many(twoFactor), - postss: many(posts), - followss: many(follows), - mutess: many(mutes), - blockss: many(blocks), -})); - -export const sessionRelations = relations(session, ({ one }) => ({ - user: one(user, { - fields: [session.userId], - references: [user.id], - }), + accounts: many(account), + twoFactors: many(twoFactor), + postss: many(posts), + mutess: many(mutes), + blockss: many(blocks), })); export const accountRelations = relations(account, ({ one }) => ({ - user: one(user, { - fields: [account.userId], - references: [user.id], - }), + user: one(user, { + fields: [account.userId], + references: [user.id], + }), })); export const twoFactorRelations = relations(twoFactor, ({ one }) => ({ - user: one(user, { - fields: [twoFactor.userId], - references: [user.id], - }), + user: one(user, { + fields: [twoFactor.userId], + references: [user.id], + }), })); export const postsRelations = relations(posts, ({ one }) => ({ - user: one(user, { - fields: [posts.authorId], - references: [user.id], - }), + user: one(user, { + fields: [posts.authorId], + references: [user.id], + }), })); -export const followsFollowerIdRelations = relations(follows, ({ one }) => ({ - user: one(user, { - fields: [follows.followerId], - references: [user.id], - }), -})); +export const followsFollowerServerUrlRelations = relations( + follows, + ({ one }) => ({ + serverRegistry: one(serverRegistry, { + fields: [follows.followerServerUrl], + references: [serverRegistry.url], + }), + }), +); -export const followsFollowingIdRelations = relations(follows, ({ one }) => ({ - user: one(user, { - fields: [follows.followingId], - references: [user.id], - }), -})); +export const followsFollowingServerUrlRelations = relations( + follows, + ({ one }) => ({ + serverRegistry: one(serverRegistry, { + fields: [follows.followingServerUrl], + references: [serverRegistry.url], + }), + }), +); export const mutesUserIdRelations = relations(mutes, ({ one }) => ({ - user: one(user, { - fields: [mutes.userId], - references: [user.id], - }), + user: one(user, { + fields: [mutes.userId], + references: [user.id], + }), })); export const mutesMutedUserIdRelations = relations(mutes, ({ one }) => ({ - user: one(user, { - fields: [mutes.mutedUserId], - references: [user.id], - }), + user: one(user, { + fields: [mutes.mutedUserId], + references: [user.id], + }), })); export const blocksBlockerIdRelations = relations(blocks, ({ one }) => ({ - user: one(user, { - fields: [blocks.blockerId], - references: [user.id], - }), + user: one(user, { + fields: [blocks.blockerId], + references: [user.id], + }), })); export const blocksBlockedUserIdRelations = relations(blocks, ({ one }) => ({ - user: one(user, { - fields: [blocks.blockedUserId], - references: [user.id], - }), + user: one(user, { + fields: [blocks.blockedUserId], + references: [user.id], + }), })); + +export const serverRegistryRelations = relations( + serverRegistry, + ({ many }) => ({ + followss: many(follows), + }), +); diff --git a/src/lib/federation/keytools.ts b/src/lib/federation/keytools.ts index dfc7c68..d95df52 100644 --- a/src/lib/federation/keytools.ts +++ b/src/lib/federation/keytools.ts @@ -84,3 +84,19 @@ export function fingerprintKey(keyBase64: string): string { const hash = createHash("sha256").update(fromBase64(keyBase64)).digest("hex"); return hash; } + +export function getOwnEncryptionPublicKey(): Uint8Array { + return new Uint8Array(Buffer.from(process.env.FEDERATION_ENCRYPTION_PUBLIC_KEY!, "base64")) +} + +export function getOwnSigningPublicKey(): Uint8Array { + return new Uint8Array(Buffer.from(process.env.FEDERATION_PUBLIC_KEY!, "base64")) +} + +export function getOwnSigningSecretKey(): Uint8Array { + return new Uint8Array(Buffer.from(process.env.FEDERATION_PRIVATE_KEY!, "base64")) +} + +export function getOwnEncryptionSecretKey(): Uint8Array { + return new Uint8Array(Buffer.from(process.env.FEDERATION_ENCRYPTION_PRIVATE_KEY!, "base64")) +} \ No newline at end of file diff --git a/src/lib/federation/registry.ts b/src/lib/federation/registry.ts new file mode 100644 index 0000000..d8b7d97 --- /dev/null +++ b/src/lib/federation/registry.ts @@ -0,0 +1,92 @@ +import db from '@/lib/db'; +import { serverRegistry } from '@/lib/db/schema'; +import { assertSafeUrl } from '@/lib/federation/url-guard'; +import createDebug from 'debug'; +import { eq } from 'drizzle-orm'; + +const debug = createDebug('app:federation:registry'); + +export async function upsertServer(url: string, publicKey: string, encryptionPublicKey: string) { + return await db.insert(serverRegistry).values({ + id: crypto.randomUUID(), + url, + publicKey, + encryptionPublicKey, + lastSeen: new Date(), + createdAt: new Date(), + updatedAt: new Date(), + isHealthy: true, + }).onConflictDoNothing(); +} + +export class DiscoveryError extends Error { + constructor(message: string) { + super(message); + this.name = 'DiscoveryError'; + } +} + +/** + * Fetches a remote server's /discover endpoint, registers it locally, + * and POSTs our own info so the remote registers us back (mutual registration). + * Returns the remote server's encryptionPublicKey on success. + */ +export async function discoverAndRegister(serverUrl: string): Promise { + debug('auto-discovering server %s', serverUrl); + + assertSafeUrl(serverUrl); + + let remote: { url?: string; publicKey?: string; encryptionPublicKey?: string }; + try { + const res = await fetch(serverUrl + '/discover', { + signal: AbortSignal.timeout(10_000), + }); + if (!res.ok) { + throw new DiscoveryError(`GET /discover returned ${res.status}`); + } + remote = await res.json(); + } catch (err) { + if (err instanceof DiscoveryError) throw err; + throw new DiscoveryError(`Failed to reach ${serverUrl}/discover: ${err instanceof Error ? err.message : err}`); + } + + if (!remote.publicKey || !remote.encryptionPublicKey) { + throw new DiscoveryError(`Server ${serverUrl} returned incomplete keys`); + } + + const existing = await db + .select({ publicKey: serverRegistry.publicKey }) + .from(serverRegistry) + .where(eq(serverRegistry.url, serverUrl)) + .limit(1); + + if (existing.length > 0 && existing[0].publicKey !== remote.publicKey) { + throw new DiscoveryError( + `Server ${serverUrl} presented a different public key than what we have on record. ` + + `This may indicate a key rotation issue or a compromised server.`, + ); + } + + debug('registering remote server %s locally', serverUrl); + await upsertServer(serverUrl, remote.publicKey, remote.encryptionPublicKey); + + debug('sending mutual REGISTER to %s', serverUrl); + try { + await fetch(serverUrl + '/discover', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + method: 'REGISTER', + url: process.env.BETTER_AUTH_URL!, + publicKey: process.env.FEDERATION_PUBLIC_KEY!, + encryptionPublicKey: process.env.FEDERATION_ENCRYPTION_PUBLIC_KEY!, + }), + signal: AbortSignal.timeout(10_000), + }); + } catch (err) { + debug('mutual REGISTER to %s failed (non-fatal): %s', serverUrl, err instanceof Error ? err.message : err); + } + + debug('auto-discovery of %s complete', serverUrl); + return remote.encryptionPublicKey; +} diff --git a/src/lib/federation/url-guard.ts b/src/lib/federation/url-guard.ts index c4875a5..7065669 100644 --- a/src/lib/federation/url-guard.ts +++ b/src/lib/federation/url-guard.ts @@ -7,8 +7,13 @@ const BLOCKED_HOSTNAMES = new Set([ "0.0.0.0", "[::1]", "[::0]", + "metadata.google.internal", + "metadata.goog", + "169.254.169.254", ]); +const DEV_ALLOWED_HOSTNAMES = new Set(["localhost", "127.0.0.1", process.env.DEV_ALLOWED_HOSTNAMES!]); +debug("DEV_ALLOWED_HOSTNAMES: %o", DEV_ALLOWED_HOSTNAMES); function isPrivateIPv4(hostname: string): boolean { const parts = hostname.split(".").map(Number); if (parts.length !== 4 || parts.some((p) => isNaN(p))) return false; @@ -34,7 +39,9 @@ function isPrivateIPv6(hostname: string): boolean { /** * Throws if the URL points to a private/internal address or uses a - * non-HTTP(S) protocol. Call before any server-side fetch to prevent SSRF. + * non-HTTP(S) protocol. In development, localhost/127.0.0.1 are explicitly + * allowed for local federation testing while all other safety checks + * remain enforced. */ export function assertSafeUrl(url: string): void { let parsed: URL; @@ -50,6 +57,10 @@ export function assertSafeUrl(url: string): void { const hostname = parsed.hostname; + if (process.env.NODE_ENV === "development" && DEV_ALLOWED_HOSTNAMES.has(hostname)) { + return; + } + if (BLOCKED_HOSTNAMES.has(hostname)) { debug("blocked hostname: %s", hostname); throw new UrlGuardError(`Blocked internal address: ${hostname}`); diff --git a/src/lib/plugins/client/social.ts b/src/lib/plugins/client/social.ts index ca68747..e3ab3e3 100644 --- a/src/lib/plugins/client/social.ts +++ b/src/lib/plugins/client/social.ts @@ -112,7 +112,33 @@ export const sipherSocialClientPlugin = () => { } return data.postId; + }, + followUser: async (userId: string, federationUrl?: string) => { + const body: Record = { + method: "INSERT", + userId, + }; + if (federationUrl) { + body.federationUrl = federationUrl; } + + const { data, error } = await $fetch<{ + following: { + id: string; + createdAt: Date; + followerId: string; + followingId: string; + accepted: boolean; + }; + }>("/social/follows", { + method: "POST", + body, + }); + if (error || !data) { + throw new Error("Failed to follow user"); + } + return data.following; + } } }, } satisfies BetterAuthClientPlugin; diff --git a/src/lib/plugins/server/helpers/social/endpoints/follows.ts b/src/lib/plugins/server/helpers/social/endpoints/follows.ts index 408b757..2881433 100644 --- a/src/lib/plugins/server/helpers/social/endpoints/follows.ts +++ b/src/lib/plugins/server/helpers/social/endpoints/follows.ts @@ -1,16 +1,239 @@ -import { createAuthEndpoint } from "better-auth/api" -import { z } from "zod" +import { getFederationQueue } from "@/lib/bull"; +import db from "@/lib/db"; +import { blacklistedServers, deliveryJobs, follows, serverRegistry, user } from "@/lib/db/schema"; +import { decryptPayload, getOwnEncryptionSecretKey, verifySignature } from "@/lib/federation/keytools"; +import { discoverAndRegister, DiscoveryError } from "@/lib/federation/registry"; +import { createAuthEndpoint, getSessionFromCtx } from "better-auth/api"; +import createDebug from "debug"; +import { and, eq } from "drizzle-orm"; +import { z } from "zod"; + +const debug = createDebug("app:plugins:server:helpers:social:follows"); + +const followSchema = z.discriminatedUnion( + "method", [ + z.object({ + method: z.literal("INSERT"), + userId: z.string(), + federationUrl: z.url().optional(), + }), + z.object({ + method: z.literal("FEDERATE"), + signature: z.string(), + payload: z.object({ + ephemeralPublicKey: z.string(), + iv: z.string(), + ciphertext: z.string(), + authTag: z.string(), + }).transform((payload, ctx) => { + try { + const decrypted = decryptPayload(payload, getOwnEncryptionSecretKey()); + const parsedPayload = JSON.parse(decrypted); + + const parsedPayloadSchema = z.object({ + following: z.object({ + id: z.string(), + createdAt: z.coerce.date(), + followerId: z.string(), + followingId: z.string(), + accepted: z.boolean(), + followerServerUrl: z.string().nullable(), + }), + federationUrl: z.string(), + method: z.literal("FEDERATE"), + }).safeParse(parsedPayload); + if (!parsedPayloadSchema.success) { + ctx.addIssue({ code: "custom", message: "Invalid payload" }); + return z.never(); + } + return { ...parsedPayloadSchema.data, _raw: decrypted }; + } catch { + ctx.addIssue({ code: "custom", message: "Invalid payload" }); + return z.never(); + } + }), + }), + z.object({ + method: z.literal("UNFOLLOW"), + userId: z.string(), + }), +], { error: "Invalid follow method" }, +) export const followUser = createAuthEndpoint("/social/follows", { method: "POST", -}, async (context) => { }) + body: followSchema, +}, async (context) => { + debug("FOLLOW – %s", context.body.method); + const { method } = context.body; + switch (method) { + case "INSERT": { + const session = await getSessionFromCtx(context); + debug("FOLLOW – user: %o", session); + if (!session) { + return context.json({ error: "Unauthorized" }, { status: 401 }); + }; -export const unfollowUser = createAuthEndpoint("/social/follows/:id", { - method: "DELETE", - params: z.object({ - id: z.string(), - }), -}, async (context) => { }) + const { userId, federationUrl } = context.body; + const ownUrl = process.env.BETTER_AUTH_URL!; + const isLocal = !federationUrl || federationUrl === ownUrl; + + const [existingFollow] = await db + .select({ id: follows.id }) + .from(follows) + .where(and( + eq(follows.followerId, session.user.id), + eq(follows.followingId, userId), + )) + .limit(1); + + if (existingFollow) { + return context.json({ error: "You are already following this user." }, { status: 409 }); + } + + if (isLocal) { + const [targetUser] = await db + .select({ id: user.id, isPrivate: user.isPrivate }) + .from(user) + .where(eq(user.id, userId)) + .limit(1); + + if (!targetUser) { + return context.json({ error: "User not found." }, { status: 404 }); + } + + const following = await db.insert(follows).values({ + id: crypto.randomUUID(), + followerId: session.user.id, + followingId: userId, + accepted: !targetUser.isPrivate, + createdAt: new Date(), + }).returning(); + + return context.json({ following }, { status: 200 }); + } + + const serverUrl = federationUrl!.toString().replace(/\/+$/, ''); + + const [blacklisted] = await db + .select({ id: blacklistedServers.id }) + .from(blacklistedServers) + .where(eq(blacklistedServers.serverUrl, serverUrl)) + .limit(1); + + if (blacklisted) { + return context.json({ error: "This server has been blocked." }, { status: 403 }); + } + + const [existing] = await db + .select({ url: serverRegistry.url }) + .from(serverRegistry) + .where(eq(serverRegistry.url, serverUrl)) + .limit(1); + + if (!existing) { + try { + debug("FOLLOW – discovering and registering server %s", serverUrl); + await discoverAndRegister(serverUrl); + } catch (err) { + if (err instanceof DiscoveryError) { + debug("discovery failed for %s: %s", serverUrl, err.message); + return context.json({ error: "Could not reach the federation server." }, { status: 502 }); + } + throw err; + } + } + + const following = await db.insert(follows).values({ + id: crypto.randomUUID(), + followerId: session.user.id, + followingId: userId, + accepted: false, + createdAt: new Date(), + followerServerUrl: serverUrl, + }).returning(); + + const job = await db.insert(deliveryJobs).values({ + id: crypto.randomUUID(), + targetUrl: serverUrl + "/api/auth/social/follows", + payload: JSON.stringify({ following: following[0], federationUrl: ownUrl, method: "FEDERATE" }), + attempts: 0, + createdAt: new Date(), + }).returning(); + + await getFederationQueue().add("deliver-follow", { + deliveryJobId: job[0].id, + targetUrl: job[0].targetUrl, + serverUrl, + payload: JSON.stringify({ following: following[0], federationUrl: ownUrl, method: "FEDERATE" }), + }); + + return context.json({ following }, { status: 200 }); + } + case "FEDERATE": { + const { payload, signature } = context.body; + + if (!payload || payload instanceof z.ZodNever || !("following" in payload) || !("federationUrl" in payload)) { + return context.json({ error: "Invalid payload", code: "INVALID_PAYLOAD" }, { status: 400 }); + } + + const { following, federationUrl, _raw } = payload; + + const [server] = await db + .select({ url: serverRegistry.url, publicKey: serverRegistry.publicKey }) + .from(serverRegistry) + .where(eq(serverRegistry.url, federationUrl)) + .limit(1); + + if (!server) { + return context.json({ + error: "Unknown federation server. Please redo the discovery process and try again.", + code: "UNKNOWN_FEDERATION_SERVER_INTERACTION", + }, { status: 403 }); + } + + const senderPublicKey = new Uint8Array(Buffer.from(server.publicKey, "base64")); + if (!verifySignature(_raw, signature, senderPublicKey)) { + return context.json({ + error: "Signature verification failed.", + code: "INVALID_SIGNATURE", + }, { status: 403 }); + } + + const [targetUser] = await db + .select({ id: user.id, isPrivate: user.isPrivate }) + .from(user) + .where(eq(user.id, following.followingId)) + .limit(1); + + if (!targetUser) { + return context.json({ + error: "The user being followed does not exist on this server.", + code: "USER_NOT_FOUND", + }, { status: 404 }); + } + + const accepted = !targetUser.isPrivate; + + await db.insert(follows).values({ + id: crypto.randomUUID(), + followerId: following.followerId, + followingId: following.followingId, + accepted, + createdAt: new Date(), + followingServerUrl: server.url, + }); + + return context.json({ status: "acknowledged", accepted }, { status: 200 }); + } + case "UNFOLLOW": { + return context.json({ error: "Not implemented" }, { status: 501 }); + } + default: { + return context.json({ error: "Invalid method" }, { status: 400 }); + } + } +}) export const getFollows = createAuthEndpoint("/social/follows/following", { method: "GET", diff --git a/src/lib/plugins/server/helpers/social/endpoints/index.ts b/src/lib/plugins/server/helpers/social/endpoints/index.ts index fadbb6d..6a83a13 100644 --- a/src/lib/plugins/server/helpers/social/endpoints/index.ts +++ b/src/lib/plugins/server/helpers/social/endpoints/index.ts @@ -1,7 +1,7 @@ import { createBlock, deleteBlock, getBlocks } from "./blocks"; -import { followUser, getFollowers, getFollows, unfollowUser } from "./follows"; +import { followUser, getFollowers, getFollows } from "./follows"; import { createMute, deleteMute, getMutes } from "./mutes"; import { createPost, getPost, uploadFile } from "./posts"; -export { createBlock, createMute, createPost, deleteBlock, deleteMute, followUser, getBlocks, getFollowers, getFollows, getMutes, getPost, unfollowUser, uploadFile }; +export { createBlock, createMute, createPost, deleteBlock, deleteMute, followUser, getBlocks, getFollowers, getFollows, getMutes, getPost, uploadFile }; diff --git a/src/lib/plugins/server/helpers/social/endpoints/posts.ts b/src/lib/plugins/server/helpers/social/endpoints/posts.ts index 50d6f03..dc860ae 100644 --- a/src/lib/plugins/server/helpers/social/endpoints/posts.ts +++ b/src/lib/plugins/server/helpers/social/endpoints/posts.ts @@ -1,7 +1,9 @@ import db from "@/lib/db"; -import { posts } from "@/lib/db/schema"; +import { deliveryJobs, follows, posts } from "@/lib/db/schema"; +import { getFederationQueue, type FederationDeliveryJob } from "@/lib/bull"; import minioClient from "@/plugins/server/storage/minio.client"; import { createAuthEndpoint, getSessionFromCtx } from "better-auth/api"; +import { and, eq } from "drizzle-orm"; import { z } from "zod"; import { postContentSchema } from "../social"; @@ -16,8 +18,6 @@ export const createPost = createAuthEndpoint("/social/posts", { return context.json({ error: "Unauthorized" }, { status: 401 }); } - - // Create post const post = await db.insert(posts).values({ id: crypto.randomUUID(), @@ -28,6 +28,36 @@ export const createPost = createAuthEndpoint("/social/posts", { createdAt: new Date(), }).returning({ id: posts.id }); + // Enqueue federation delivery jobs for each follower's server + const followers = await db.select().from(follows).where(and(eq(follows.followingId, user.user.id), eq(follows.accepted, true))); + const uniqueUrls = [...new Set(followers.map(f => f.followerServerUrl).filter(Boolean))] as string[]; + const payload = JSON.stringify({ content }); + + const jobRows = uniqueUrls.map(url => ({ + id: crypto.randomUUID(), + targetUrl: url + "/social/posts", + serverUrl: url, + payload, + attempts: 0, + createdAt: new Date(), + })); + + if (jobRows.length > 0) { + await db.insert(deliveryJobs).values(jobRows); + + await getFederationQueue().addBulk( + jobRows.map(row => ({ + name: 'deliver-post' as const, + data: { + deliveryJobId: row.id, + targetUrl: row.targetUrl, + serverUrl: row.serverUrl, + payload: row.payload, + } satisfies FederationDeliveryJob, + })), + ); + } + return context.json({ id: post[0].id }, { status: 200 }); }); diff --git a/src/lib/plugins/server/helpers/social/social.ts b/src/lib/plugins/server/helpers/social/social.ts index 7a7a872..efae50a 100644 --- a/src/lib/plugins/server/helpers/social/social.ts +++ b/src/lib/plugins/server/helpers/social/social.ts @@ -94,6 +94,11 @@ export default { type: "date", required: true, index: false + }, + federationUrl: { + type: "string", + required: false, + index: true, } } }, @@ -103,19 +108,11 @@ export default { type: "string", required: true, index: false, - references: { - model: "user", - field: "id" - } }, followingId: { type: "string", required: true, index: false, - references: { - model: "user", - field: "id" - } }, accepted: { type: "boolean", @@ -127,7 +124,25 @@ export default { type: "date", required: true, index: false - } + }, + followerServerUrl: { + type: "string", + required: false, + index: true, + references: { + model: "serverRegistry", + field: "url" + } + }, + followingServerUrl: { + type: "string", + required: false, + index: true, + references: { + model: "serverRegistry", + field: "url" + } + }, } }, deliveryJobs: { diff --git a/src/lib/redis/index.ts b/src/lib/redis/index.ts new file mode 100644 index 0000000..b97cc22 --- /dev/null +++ b/src/lib/redis/index.ts @@ -0,0 +1,12 @@ +import Redis from "ioredis"; + +let redisClient: Redis | null = null; + +export function getRedisClient(): Redis { + if (!redisClient) { + redisClient = new Redis(process.env.REDIS_URL!); + } + return redisClient; +} + +export default getRedisClient; \ No newline at end of file