import { and, desc, eq, inArray, sql } from 'drizzle-orm';

import type { Task, TaskStatus, TaskType } from '@/lib/types';
import { db } from '@/lib/server/db';
import { taskRun } from '@/lib/server/db/schema';

type TaskRow = typeof taskRun.$inferSelect;

type CreateTaskInput = {
  id: string;
  user_id: string;
  task_type: TaskType;
  payload: Record<string, unknown>;
  priority: number;
  max_attempts: number;
};

/** Page size used by `listRecentTasksForUser` when no (or an invalid) limit is given. */
const DEFAULT_TASK_LIST_LIMIT = 20;
/** Hard upper bound on rows returned by `listRecentTasksForUser`. */
const MAX_TASK_LIST_LIMIT = 200;

/** Maps a raw `task_run` row onto the public `Task` shape, field by field. */
function toTask(row: TaskRow): Task {
  return {
    id: row.id,
    user_id: row.user_id,
    task_type: row.task_type,
    status: row.status,
    priority: row.priority,
    payload: row.payload,
    result: row.result,
    error: row.error,
    attempts: row.attempts,
    max_attempts: row.max_attempts,
    workflow_run_id: row.workflow_run_id,
    created_at: row.created_at,
    updated_at: row.updated_at,
    finished_at: row.finished_at
  };
}

/**
 * Inserts a new task run in the `queued` state with zero attempts and no result/error.
 *
 * @param input - Identity, owner, type, payload, priority and retry budget of the task.
 * @returns The inserted row mapped to a `Task`.
 * @throws Error if the insert returns no row.
 */
export async function createTaskRunRecord(input: CreateTaskInput): Promise<Task> {
  const now = new Date().toISOString();
  const [row] = await db
    .insert(taskRun)
    .values({
      id: input.id,
      user_id: input.user_id,
      task_type: input.task_type,
      status: 'queued',
      priority: input.priority,
      payload: input.payload,
      result: null,
      error: null,
      attempts: 0,
      max_attempts: input.max_attempts,
      workflow_run_id: null,
      created_at: now,
      updated_at: now,
      finished_at: null
    })
    .returning();

  if (!row) {
    throw new Error(`Failed to create task run ${input.id}`);
  }
  return toTask(row);
}

/**
 * Records the workflow run that is executing a task and bumps `updated_at`.
 * No-op if the task id does not exist.
 */
export async function setTaskWorkflowRunId(taskId: string, workflowRunId: string): Promise<void> {
  await db
    .update(taskRun)
    .set({ workflow_run_id: workflowRunId, updated_at: new Date().toISOString() })
    .where(eq(taskRun.id, taskId));
}

/**
 * Fetches a task only if it belongs to the given user.
 *
 * @returns The task, or `null` when it does not exist or is owned by someone else.
 */
export async function getTaskByIdForUser(taskId: string, userId: string): Promise<Task | null> {
  const [row] = await db
    .select()
    .from(taskRun)
    .where(and(eq(taskRun.id, taskId), eq(taskRun.user_id, userId)))
    .limit(1);
  return row ? toTask(row) : null;
}

/**
 * Lists a user's tasks, newest first by `created_at`.
 *
 * @param userId - Owner whose tasks are listed.
 * @param limit - Requested page size; truncated to an integer and clamped to
 *   [1, MAX_TASK_LIST_LIMIT]. `NaN` falls back to the default.
 * @param statuses - Optional status filter; ignored when empty or omitted.
 */
export async function listRecentTasksForUser(
  userId: string,
  limit = DEFAULT_TASK_LIST_LIMIT,
  statuses?: TaskStatus[]
): Promise<Task[]> {
  // NaN passes straight through Math.trunc/Math.max/Math.min and would reach
  // the SQL LIMIT clause, so substitute the default first.
  const requested = Number.isNaN(limit) ? DEFAULT_TASK_LIST_LIMIT : Math.trunc(limit);
  const safeLimit = Math.min(Math.max(requested, 1), MAX_TASK_LIST_LIMIT);

  const filter =
    statuses && statuses.length > 0
      ? and(eq(taskRun.user_id, userId), inArray(taskRun.status, statuses))
      : eq(taskRun.user_id, userId);

  const rows = await db
    .select()
    .from(taskRun)
    .where(filter)
    .orderBy(desc(taskRun.created_at))
    .limit(safeLimit);

  return rows.map(toTask);
}

/**
 * Counts all task runs grouped by status (across every user).
 *
 * @returns A map from status to row count; statuses with no rows are absent.
 */
export async function countTasksByStatus(): Promise<Record<string, number>> {
  const rows = await db
    .select({ status: taskRun.status, count: sql<number>`count(*)` })
    .from(taskRun)
    .groupBy(taskRun.status);

  const queue: Record<string, number> = {};
  for (const row of rows) {
    // Number() normalises drivers that may hand back count(*) as a string/bigint.
    queue[row.status] = Number(row.count);
  }
  return queue;
}

/**
 * Atomically moves a task from `queued` to `running` and increments `attempts`.
 *
 * @returns The claimed task, or `null` if the task does not exist or was not
 *   in the `queued` state (e.g. already claimed by another worker).
 */
export async function claimQueuedTask(taskId: string): Promise<Task | null> {
  const [row] = await db
    .update(taskRun)
    .set({
      status: 'running',
      attempts: sql`${taskRun.attempts} + 1`,
      updated_at: new Date().toISOString()
    })
    // The status predicate makes the claim a compare-and-set: only one caller
    // can flip a given queued row.
    .where(and(eq(taskRun.id, taskId), eq(taskRun.status, 'queued')))
    .returning();
  return row ? toTask(row) : null;
}

/**
 * Marks a task `completed`, stores its result and clears any previous error.
 *
 * @returns The updated task, or `null` if no task has this id.
 */
export async function completeTask(
  taskId: string,
  result: Record<string, unknown>
): Promise<Task | null> {
  // One timestamp so updated_at and finished_at agree exactly.
  const now = new Date().toISOString();
  const [row] = await db
    .update(taskRun)
    .set({
      status: 'completed',
      result,
      error: null,
      updated_at: now,
      finished_at: now
    })
    .where(eq(taskRun.id, taskId))
    .returning();
  return row ? toTask(row) : null;
}

/**
 * Records a failure for a task and decides whether it should be retried.
 *
 * The task is re-queued while `attempts < max_attempts` (`claimQueuedTask`
 * increments `attempts` on every claim); otherwise it is marked `failed` and
 * `finished_at` is set. The error message is stored in either case.
 *
 * @returns The updated task plus `shouldRetry`. When the task cannot be found
 *   (before or during the update), `{ task: null, shouldRetry: false }`.
 */
export async function markTaskFailure(
  taskId: string,
  reason: string
): Promise<{ task: Task | null; shouldRetry: boolean }> {
  const [current] = await db
    .select()
    .from(taskRun)
    .where(eq(taskRun.id, taskId))
    .limit(1);

  if (!current) {
    return { task: null, shouldRetry: false };
  }

  const shouldRetry = current.attempts < current.max_attempts;
  const now = new Date().toISOString();

  // NOTE(review): the SELECT above and this UPDATE are separate statements, so
  // a concurrent writer in between is not detected; confirm callers serialise
  // failure reporting per task.
  const [updated] = await db
    .update(taskRun)
    .set({
      status: shouldRetry ? 'queued' : 'failed',
      error: reason,
      updated_at: now,
      finished_at: shouldRetry ? null : now
    })
    .where(eq(taskRun.id, taskId))
    .returning();

  if (!updated) {
    // Row vanished between read and write: there is nothing left to retry.
    return { task: null, shouldRetry: false };
  }
  return { task: toTask(updated), shouldRetry };
}

/**
 * Fetches a task by id without any ownership check.
 *
 * @returns The task, or `null` if no task has this id.
 */
export async function getTaskById(taskId: string): Promise<Task | null> {
  const [row] = await db
    .select()
    .from(taskRun)
    .where(eq(taskRun.id, taskId))
    .limit(1);
  return row ? toTask(row) : null;
}