// 230 lines · 7.0 KiB · TypeScript
import { and, asc, eq } from 'drizzle-orm';
|
|
import type {
|
|
ResearchCopilotCitation,
|
|
ResearchCopilotMessage,
|
|
ResearchCopilotSession,
|
|
ResearchCopilotSuggestedAction,
|
|
ResearchMemoSection,
|
|
SearchSource
|
|
} from '@/lib/types';
|
|
import { db } from '@/lib/server/db';
|
|
import {
|
|
researchCopilotMessage,
|
|
researchCopilotSession
|
|
} from '@/lib/server/db/schema';
|
|
|
|
/** Row shape produced by `db.select()` on the copilot session table. */
type ResearchCopilotSessionRow = (typeof researchCopilotSession)['$inferSelect'];

/** Row shape produced by `db.select()` on the copilot message table. */
type ResearchCopilotMessageRow = (typeof researchCopilotMessage)['$inferSelect'];

/**
 * Search sources used when a caller or stored row supplies none (or none that
 * are recognized). Treat as read-only: normalizeSources hands out copies.
 */
const DEFAULT_SELECTED_SOURCES: SearchSource[] = [
  'documents',
  'filings',
  'research'
];
|
|
|
|
/**
 * Canonicalizes a ticker symbol for storage and lookup by stripping
 * surrounding whitespace and upper-casing it.
 */
function normalizeTicker(ticker: string) {
  const trimmed = ticker.trim();
  return trimmed.toUpperCase();
}
|
|
|
|
/**
 * Normalizes a list of search sources into a de-duplicated array of known
 * sources ('documents' | 'filings' | 'research'), preserving first-seen order.
 *
 * Unrecognized entries are dropped. When the input is missing, is not an
 * array, or yields no recognized source, a fresh copy of
 * DEFAULT_SELECTED_SOURCES is returned. A copy is always returned, so callers
 * can never mutate the shared default.
 */
function normalizeSources(value?: SearchSource[] | null) {
  // Values read back from JSON columns are typed, not validated. A non-array
  // would make `for...of` throw (or iterate a string character by character),
  // so anything that is not an array falls back to the defaults.
  const candidates = Array.isArray(value) ? value : DEFAULT_SELECTED_SOURCES;
  const unique = new Set<SearchSource>();

  for (const source of candidates) {
    if (source === 'documents' || source === 'filings' || source === 'research') {
      unique.add(source);
    }
  }

  return unique.size > 0 ? [...unique] : [...DEFAULT_SELECTED_SOURCES];
}
|
|
|
|
/**
 * Normalizes pinned artifact ids into a de-duplicated list of positive
 * integers, preserving first-seen order.
 *
 * Each entry is coerced with `Number` and truncated toward zero. Entries that
 * end up non-finite, non-positive, or outside the safe-integer range are
 * dropped; in that range distinct ids would collapse onto the same double.
 * A missing or non-array input yields an empty list.
 */
function normalizePinnedArtifactIds(value?: number[] | null): number[] {
  // Stored JSON is not validated; a non-iterable value would make `for...of` throw.
  if (!Array.isArray(value)) {
    return [];
  }

  const unique = new Set<number>();
  for (const id of value) {
    const normalized = Math.trunc(Number(id));
    if (Number.isSafeInteger(normalized) && normalized > 0) {
      unique.add(normalized);
    }
  }

  return [...unique];
}
|
|
|
|
/**
 * Trims an optional string. null, undefined, and empty or whitespace-only
 * input all map to null.
 */
function normalizeOptionalString(value?: string | null) {
  if (value === null || value === undefined) {
    return null;
  }

  const trimmed = value.trim();
  return trimmed.length > 0 ? trimmed : null;
}
|
|
|
|
/**
 * Reads a stored citations value. Any non-array becomes an empty list. An
 * array is returned as-is; its elements' shape is NOT validated here.
 */
function toCitationArray(value: unknown): ResearchCopilotCitation[] {
  if (!Array.isArray(value)) {
    return [];
  }
  return value as ResearchCopilotCitation[];
}
|
|
|
|
/**
 * Reads a stored suggested-actions value. Any non-array becomes an empty
 * list. An array is returned as-is; its elements' shape is NOT validated here.
 */
function toActionArray(value: unknown): ResearchCopilotSuggestedAction[] {
  if (!Array.isArray(value)) {
    return [];
  }
  return value as ResearchCopilotSuggestedAction[];
}
|
|
|
|
/**
 * Reads a stored follow-ups value. Only string entries with non-whitespace
 * content are kept, in their original untrimmed form. Any non-array becomes
 * an empty list.
 */
function toFollowUps(value: unknown): string[] {
  if (!Array.isArray(value)) {
    return [];
  }

  const isNonBlankString = (entry: unknown): entry is string =>
    typeof entry === 'string' && entry.trim() !== '';

  return value.filter(isNonBlankString);
}
|
|
|
|
/**
 * Maps a stored message row to the API-facing message shape.
 *
 * JSON columns pass through the to* and normalize* helpers. A row whose
 * selected_sources is null therefore reports the default sources, and a
 * missing memo_section becomes null.
 */
function toMessage(row: ResearchCopilotMessageRow): ResearchCopilotMessage {
  const {
    id,
    session_id,
    user_id,
    role,
    content_markdown,
    citations,
    follow_ups,
    suggested_actions,
    selected_sources,
    pinned_artifact_ids,
    memo_section,
    created_at
  } = row;

  return {
    id,
    session_id,
    user_id,
    role,
    content_markdown,
    citations: toCitationArray(citations),
    follow_ups: toFollowUps(follow_ups),
    suggested_actions: toActionArray(suggested_actions),
    selected_sources: normalizeSources(selected_sources),
    pinned_artifact_ids: normalizePinnedArtifactIds(pinned_artifact_ids),
    memo_section: memo_section ?? null,
    created_at
  };
}
|
|
|
|
/**
 * Builds the API-facing session object from a stored session row plus its
 * already-mapped messages.
 *
 * Sources and pinned ids are re-normalized on read, and a missing title
 * becomes null.
 */
function toSession(row: ResearchCopilotSessionRow, messages: ResearchCopilotMessage[]): ResearchCopilotSession {
  const selectedSources = normalizeSources(row.selected_sources);
  const pinnedArtifactIds = normalizePinnedArtifactIds(row.pinned_artifact_ids);

  return {
    id: row.id,
    user_id: row.user_id,
    ticker: row.ticker,
    title: row.title ?? null,
    selected_sources: selectedSources,
    pinned_artifact_ids: pinnedArtifactIds,
    created_at: row.created_at,
    updated_at: row.updated_at,
    messages
  };
}
|
|
|
|
/**
 * Loads every message of one session and maps the rows to API messages.
 *
 * Rows are ordered by created_at ascending, with id ascending as a
 * tie-breaker for messages sharing a timestamp.
 */
async function listMessagesForSession(sessionId: number) {
  const belongsToSession = eq(researchCopilotMessage.session_id, sessionId);
  const oldestFirst = asc(researchCopilotMessage.created_at);
  const idTieBreak = asc(researchCopilotMessage.id);

  const rows = await db
    .select()
    .from(researchCopilotMessage)
    .where(belongsToSession)
    .orderBy(oldestFirst, idTieBreak);

  return rows.map((row) => toMessage(row));
}
|
|
|
|
/**
 * Fetches the raw session row for a user's ticker, or null when none exists.
 *
 * The ticker is normalized here, so callers may pass any case or surrounding
 * whitespace.
 *
 * Results are ordered by id before LIMIT 1. Without an ORDER BY the database
 * may return any matching row. getOrCreateResearchCopilotSession checks and
 * then inserts in two steps, so duplicate (user_id, ticker) rows can exist
 * unless the schema enforces uniqueness (not visible here). The ordering makes
 * every caller resolve to the same (lowest-id) row.
 */
async function getSessionRowByTicker(userId: string, ticker: string) {
  const [row] = await db
    .select()
    .from(researchCopilotSession)
    .where(and(
      eq(researchCopilotSession.user_id, userId),
      eq(researchCopilotSession.ticker, normalizeTicker(ticker))
    ))
    .orderBy(asc(researchCopilotSession.id))
    .limit(1);

  return row ?? null;
}
|
|
|
|
/**
 * Returns the user's copilot session for `ticker`, including all of its
 * messages, or null when the user has no session for that ticker.
 */
export async function getResearchCopilotSessionByTicker(userId: string, ticker: string) {
  const row = await getSessionRowByTicker(userId, ticker);
  return row ? toSession(row, await listMessagesForSession(row.id)) : null;
}
|
|
|
|
/**
 * Returns the user's session for `input.ticker` with its messages, creating
 * the session when none exists.
 *
 * `title`, `selectedSources` and `pinnedArtifactIds` are only applied when a
 * new session is inserted. An existing session is returned unchanged.
 *
 * @throws Error('ticker is required') when the ticker is blank after trimming.
 */
export async function getOrCreateResearchCopilotSession(input: {
  userId: string;
  ticker: string;
  title?: string | null;
  selectedSources?: SearchSource[] | null;
  pinnedArtifactIds?: number[] | null;
}) {
  const { userId, title, selectedSources, pinnedArtifactIds } = input;
  const ticker = normalizeTicker(input.ticker);
  if (ticker === '') {
    throw new Error('ticker is required');
  }

  const existingRow = await getSessionRowByTicker(userId, ticker);
  if (existingRow !== null) {
    return toSession(existingRow, await listMessagesForSession(existingRow.id));
  }

  // Same timestamp for created_at and updated_at on a brand-new row.
  const timestamp = new Date().toISOString();
  const insertedRows = await db
    .insert(researchCopilotSession)
    .values({
      user_id: userId,
      ticker,
      title: normalizeOptionalString(title),
      selected_sources: normalizeSources(selectedSources),
      pinned_artifact_ids: normalizePinnedArtifactIds(pinnedArtifactIds),
      created_at: timestamp,
      updated_at: timestamp
    })
    .returning();

  // A freshly inserted session cannot have messages yet.
  return toSession(insertedRows[0], []);
}
|
|
|
|
/**
 * Ensures the user's session for `input.ticker` exists, then overwrites its
 * title, selected sources and pinned artifact ids and bumps updated_at.
 *
 * A field left `undefined` keeps the session's current (normalized) value.
 * An explicit `null` is normalized: the title becomes null, sources fall back
 * to the defaults, and pins become an empty list.
 *
 * @returns the updated session with all of its messages.
 */
export async function upsertResearchCopilotSessionState(input: {
  userId: string;
  ticker: string;
  title?: string | null;
  selectedSources?: SearchSource[] | null;
  pinnedArtifactIds?: number[] | null;
}) {
  const session = await getOrCreateResearchCopilotSession(input);

  const patch = {
    title: input.title !== undefined
      ? normalizeOptionalString(input.title)
      : session.title,
    selected_sources: input.selectedSources !== undefined
      ? normalizeSources(input.selectedSources)
      : session.selected_sources,
    pinned_artifact_ids: input.pinnedArtifactIds !== undefined
      ? normalizePinnedArtifactIds(input.pinnedArtifactIds)
      : session.pinned_artifact_ids,
    updated_at: new Date().toISOString()
  };

  const [updatedRow] = await db
    .update(researchCopilotSession)
    .set(patch)
    .where(eq(researchCopilotSession.id, session.id))
    .returning();

  const messages = await listMessagesForSession(updatedRow.id);
  return toSession(updatedRow, messages);
}
|
|
|
|
/**
 * Inserts one message into a session and bumps the session's updated_at to
 * the message's created_at.
 *
 * Content is trimmed. Citations, follow-ups and suggested actions default to
 * empty lists. Per-message selected sources and pinned ids are stored
 * normalized when provided, and as null when omitted.
 *
 * NOTE(review): this function does not itself check that `sessionId` belongs
 * to `userId`, and the two writes are not wrapped in a transaction. Confirm
 * that callers handle both.
 *
 * @returns the stored message in API shape.
 */
export async function appendResearchCopilotMessage(input: {
  userId: string;
  sessionId: number;
  role: ResearchCopilotMessage['role'];
  contentMarkdown: string;
  citations?: ResearchCopilotCitation[] | null;
  followUps?: string[] | null;
  suggestedActions?: ResearchCopilotSuggestedAction[] | null;
  selectedSources?: SearchSource[] | null;
  pinnedArtifactIds?: number[] | null;
  memoSection?: ResearchMemoSection | null;
}) {
  const timestamp = new Date().toISOString();
  const { selectedSources, pinnedArtifactIds } = input;

  // null here means "no per-message override"; it is not replaced by the defaults.
  const storedSources = selectedSources ? normalizeSources(selectedSources) : null;
  const storedPins = pinnedArtifactIds ? normalizePinnedArtifactIds(pinnedArtifactIds) : null;

  const [insertedRow] = await db
    .insert(researchCopilotMessage)
    .values({
      session_id: input.sessionId,
      user_id: input.userId,
      role: input.role,
      content_markdown: input.contentMarkdown.trim(),
      citations: input.citations ?? [],
      follow_ups: input.followUps ?? [],
      suggested_actions: input.suggestedActions ?? [],
      selected_sources: storedSources,
      pinned_artifact_ids: storedPins,
      memo_section: input.memoSection ?? null,
      created_at: timestamp
    })
    .returning();

  // Keep the parent session's activity timestamp in step with its newest message.
  await db
    .update(researchCopilotSession)
    .set({ updated_at: timestamp })
    .where(eq(researchCopilotSession.id, input.sessionId));

  return toMessage(insertedRow);
}
|