Add hybrid research copilot workspace

This commit is contained in:
2026-03-14 19:32:00 -04:00
parent 7a42d73a48
commit 2ee9a549a3
27 changed files with 2864 additions and 323 deletions

View File

@@ -0,0 +1,165 @@
import {
afterAll,
beforeAll,
beforeEach,
describe,
expect,
it
} from 'bun:test';
import { mock } from 'bun:test';
import { mkdtempSync, readFileSync, rmSync } from 'node:fs';
import { tmpdir } from 'node:os';
import { join } from 'node:path';
import { Database } from 'bun:sqlite';
const TEST_USER_ID = 'copilot-user';
let tempDir: string | null = null;
let sqliteClient: Database | null = null;
let copilotRepo: typeof import('./research-copilot') | null = null;
async function loadRepoModule() {
const moduleUrl = new URL(`./research-copilot.ts?test=${Date.now()}`, import.meta.url).href;
return await import(moduleUrl) as typeof import('./research-copilot');
}
function resetDbSingletons() {
const globalState = globalThis as typeof globalThis & {
__fiscalSqliteClient?: Database;
__fiscalDrizzleDb?: unknown;
};
globalState.__fiscalSqliteClient?.close();
globalState.__fiscalSqliteClient = undefined;
globalState.__fiscalDrizzleDb = undefined;
}
function applyMigration(client: Database, fileName: string) {
const sql = readFileSync(join(process.cwd(), 'drizzle', fileName), 'utf8');
client.exec(sql);
}
function ensureUser(client: Database) {
const now = Date.now();
client.exec(`
INSERT OR REPLACE INTO user (id, name, email, emailVerified, image, createdAt, updatedAt, role, banned, banReason, banExpires)
VALUES ('${TEST_USER_ID}', 'Copilot User', 'copilot@example.com', 1, NULL, ${now}, ${now}, NULL, 0, NULL, NULL);
`);
}
describe('research copilot repo', () => {
beforeAll(async () => {
mock.restore();
tempDir = mkdtempSync(join(tmpdir(), 'fiscal-copilot-repo-'));
process.env.DATABASE_URL = `file:${join(tempDir, 'repo.sqlite')}`;
(process.env as Record<string, string | undefined>).NODE_ENV = 'test';
resetDbSingletons();
sqliteClient = new Database(join(tempDir, 'repo.sqlite'), { create: true });
sqliteClient.exec('PRAGMA foreign_keys = ON;');
applyMigration(sqliteClient, '0000_cold_silver_centurion.sql');
applyMigration(sqliteClient, '0008_research_workspace.sql');
applyMigration(sqliteClient, '0013_research_copilot.sql');
ensureUser(sqliteClient);
const globalState = globalThis as typeof globalThis & {
__fiscalSqliteClient?: Database;
__fiscalDrizzleDb?: unknown;
};
globalState.__fiscalSqliteClient = sqliteClient;
globalState.__fiscalDrizzleDb = undefined;
copilotRepo = await loadRepoModule();
});
afterAll(() => {
mock.restore();
sqliteClient?.close();
resetDbSingletons();
if (tempDir) {
rmSync(tempDir, { recursive: true, force: true });
}
});
beforeEach(() => {
sqliteClient?.exec('DELETE FROM research_copilot_message;');
sqliteClient?.exec('DELETE FROM research_copilot_session;');
});
it('creates and reloads ticker-scoped sessions', async () => {
if (!copilotRepo) {
throw new Error('repo not initialized');
}
const session = await copilotRepo.getOrCreateResearchCopilotSession({
userId: TEST_USER_ID,
ticker: 'msft',
selectedSources: ['documents', 'research'],
pinnedArtifactIds: [2, 2, 5]
});
const loaded = await copilotRepo.getResearchCopilotSessionByTicker(TEST_USER_ID, 'MSFT');
expect(session.ticker).toBe('MSFT');
expect(session.selected_sources).toEqual(['documents', 'research']);
expect(session.pinned_artifact_ids).toEqual([2, 5]);
expect(loaded?.id).toBe(session.id);
});
it('appends messages and updates session state', async () => {
if (!copilotRepo) {
throw new Error('repo not initialized');
}
const session = await copilotRepo.getOrCreateResearchCopilotSession({
userId: TEST_USER_ID,
ticker: 'NVDA'
});
await copilotRepo.appendResearchCopilotMessage({
userId: TEST_USER_ID,
sessionId: session.id,
role: 'user',
contentMarkdown: 'What changed in the latest filing?',
selectedSources: ['filings'],
pinnedArtifactIds: [7],
memoSection: 'thesis'
});
await copilotRepo.appendResearchCopilotMessage({
userId: TEST_USER_ID,
sessionId: session.id,
role: 'assistant',
contentMarkdown: 'Demand remained strong [1]',
citations: [{
index: 1,
label: 'NVDA 10-K [1]',
chunkId: 1,
href: '/filings?ticker=NVDA',
source: 'filings',
sourceKind: 'filing_brief',
sourceRef: '0001',
title: '10-K brief',
ticker: 'NVDA',
accessionNumber: '0001',
filingDate: '2026-01-01',
excerpt: 'Demand remained strong.',
artifactId: 3
}]
});
const updated = await copilotRepo.upsertResearchCopilotSessionState({
userId: TEST_USER_ID,
ticker: 'NVDA',
title: 'NVDA demand update',
selectedSources: ['filings'],
pinnedArtifactIds: [7]
});
expect(updated.title).toBe('NVDA demand update');
expect(updated.messages).toHaveLength(2);
expect(updated.messages[0]?.selected_sources).toEqual(['filings']);
expect(updated.messages[0]?.memo_section).toBe('thesis');
expect(updated.messages[1]?.citations[0]?.artifactId).toBe(3);
});
});

View File

@@ -0,0 +1,229 @@
import { and, asc, eq } from 'drizzle-orm';
import type {
ResearchCopilotCitation,
ResearchCopilotMessage,
ResearchCopilotSession,
ResearchCopilotSuggestedAction,
ResearchMemoSection,
SearchSource
} from '@/lib/types';
import { db } from '@/lib/server/db';
import {
researchCopilotMessage,
researchCopilotSession
} from '@/lib/server/db/schema';
type ResearchCopilotSessionRow = typeof researchCopilotSession.$inferSelect;
type ResearchCopilotMessageRow = typeof researchCopilotMessage.$inferSelect;
const DEFAULT_SELECTED_SOURCES: SearchSource[] = ['documents', 'filings', 'research'];
function normalizeTicker(ticker: string) {
return ticker.trim().toUpperCase();
}
function normalizeSources(value?: SearchSource[] | null) {
const unique = new Set<SearchSource>();
for (const source of value ?? DEFAULT_SELECTED_SOURCES) {
if (source === 'documents' || source === 'filings' || source === 'research') {
unique.add(source);
}
}
return unique.size > 0 ? [...unique] : [...DEFAULT_SELECTED_SOURCES];
}
function normalizePinnedArtifactIds(value?: number[] | null) {
const unique = new Set<number>();
for (const id of value ?? []) {
const normalized = Math.trunc(Number(id));
if (Number.isInteger(normalized) && normalized > 0) {
unique.add(normalized);
}
}
return [...unique];
}
function normalizeOptionalString(value?: string | null) {
const normalized = value?.trim();
return normalized ? normalized : null;
}
function toCitationArray(value: unknown): ResearchCopilotCitation[] {
return Array.isArray(value) ? value as ResearchCopilotCitation[] : [];
}
function toActionArray(value: unknown): ResearchCopilotSuggestedAction[] {
return Array.isArray(value) ? value as ResearchCopilotSuggestedAction[] : [];
}
function toFollowUps(value: unknown) {
return Array.isArray(value)
? value.filter((entry): entry is string => typeof entry === 'string' && entry.trim().length > 0)
: [];
}
function toMessage(row: ResearchCopilotMessageRow): ResearchCopilotMessage {
return {
id: row.id,
session_id: row.session_id,
user_id: row.user_id,
role: row.role,
content_markdown: row.content_markdown,
citations: toCitationArray(row.citations),
follow_ups: toFollowUps(row.follow_ups),
suggested_actions: toActionArray(row.suggested_actions),
selected_sources: normalizeSources(row.selected_sources),
pinned_artifact_ids: normalizePinnedArtifactIds(row.pinned_artifact_ids),
memo_section: row.memo_section ?? null,
created_at: row.created_at
};
}
function toSession(row: ResearchCopilotSessionRow, messages: ResearchCopilotMessage[]): ResearchCopilotSession {
return {
id: row.id,
user_id: row.user_id,
ticker: row.ticker,
title: row.title ?? null,
selected_sources: normalizeSources(row.selected_sources),
pinned_artifact_ids: normalizePinnedArtifactIds(row.pinned_artifact_ids),
created_at: row.created_at,
updated_at: row.updated_at,
messages
};
}
async function listMessagesForSession(sessionId: number) {
const rows = await db
.select()
.from(researchCopilotMessage)
.where(eq(researchCopilotMessage.session_id, sessionId))
.orderBy(asc(researchCopilotMessage.created_at), asc(researchCopilotMessage.id));
return rows.map(toMessage);
}
async function getSessionRowByTicker(userId: string, ticker: string) {
const [row] = await db
.select()
.from(researchCopilotSession)
.where(and(
eq(researchCopilotSession.user_id, userId),
eq(researchCopilotSession.ticker, normalizeTicker(ticker))
))
.limit(1);
return row ?? null;
}
export async function getResearchCopilotSessionByTicker(userId: string, ticker: string) {
const row = await getSessionRowByTicker(userId, ticker);
if (!row) {
return null;
}
return toSession(row, await listMessagesForSession(row.id));
}
export async function getOrCreateResearchCopilotSession(input: {
userId: string;
ticker: string;
title?: string | null;
selectedSources?: SearchSource[] | null;
pinnedArtifactIds?: number[] | null;
}) {
const normalizedTicker = normalizeTicker(input.ticker);
if (!normalizedTicker) {
throw new Error('ticker is required');
}
const existing = await getSessionRowByTicker(input.userId, normalizedTicker);
if (existing) {
const messages = await listMessagesForSession(existing.id);
return toSession(existing, messages);
}
const now = new Date().toISOString();
const [created] = await db
.insert(researchCopilotSession)
.values({
user_id: input.userId,
ticker: normalizedTicker,
title: normalizeOptionalString(input.title),
selected_sources: normalizeSources(input.selectedSources),
pinned_artifact_ids: normalizePinnedArtifactIds(input.pinnedArtifactIds),
created_at: now,
updated_at: now
})
.returning();
return toSession(created, []);
}
export async function upsertResearchCopilotSessionState(input: {
userId: string;
ticker: string;
title?: string | null;
selectedSources?: SearchSource[] | null;
pinnedArtifactIds?: number[] | null;
}) {
const session = await getOrCreateResearchCopilotSession(input);
const [updated] = await db
.update(researchCopilotSession)
.set({
title: input.title === undefined ? session.title : normalizeOptionalString(input.title),
selected_sources: input.selectedSources === undefined
? session.selected_sources
: normalizeSources(input.selectedSources),
pinned_artifact_ids: input.pinnedArtifactIds === undefined
? session.pinned_artifact_ids
: normalizePinnedArtifactIds(input.pinnedArtifactIds),
updated_at: new Date().toISOString()
})
.where(eq(researchCopilotSession.id, session.id))
.returning();
return toSession(updated, await listMessagesForSession(updated.id));
}
export async function appendResearchCopilotMessage(input: {
userId: string;
sessionId: number;
role: ResearchCopilotMessage['role'];
contentMarkdown: string;
citations?: ResearchCopilotCitation[] | null;
followUps?: string[] | null;
suggestedActions?: ResearchCopilotSuggestedAction[] | null;
selectedSources?: SearchSource[] | null;
pinnedArtifactIds?: number[] | null;
memoSection?: ResearchMemoSection | null;
}) {
const now = new Date().toISOString();
const [created] = await db
.insert(researchCopilotMessage)
.values({
session_id: input.sessionId,
user_id: input.userId,
role: input.role,
content_markdown: input.contentMarkdown.trim(),
citations: input.citations ?? [],
follow_ups: input.followUps ?? [],
suggested_actions: input.suggestedActions ?? [],
selected_sources: input.selectedSources ? normalizeSources(input.selectedSources) : null,
pinned_artifact_ids: input.pinnedArtifactIds ? normalizePinnedArtifactIds(input.pinnedArtifactIds) : null,
memo_section: input.memoSection ?? null,
created_at: now
})
.returning();
await db
.update(researchCopilotSession)
.set({ updated_at: now })
.where(eq(researchCopilotSession.id, input.sessionId));
return toMessage(created);
}

View File

@@ -25,6 +25,7 @@ import {
researchMemo,
researchMemoEvidence
} from '@/lib/server/db/schema';
import { getResearchCopilotSessionByTicker } from '@/lib/server/repos/research-copilot';
import { getFilingByAccession, listFilingsRecords } from '@/lib/server/repos/filings';
import { getWatchlistItemByTicker } from '@/lib/server/repos/watchlist';
@@ -374,6 +375,26 @@ async function getArtifactByIdForUser(id: number, userId: string) {
return row ?? null;
}
export async function getResearchArtifactsByIdsForUser(userId: string, ids: number[]) {
const normalizedIds = [...new Set(ids.map((id) => Math.trunc(id)).filter((id) => Number.isInteger(id) && id > 0))];
if (normalizedIds.length === 0) {
return [];
}
const rows = await db
.select()
.from(researchArtifact)
.where(and(
eq(researchArtifact.user_id, userId),
sql`${researchArtifact.id} in (${sql.join(normalizedIds.map((id) => sql`${id}`), sql`, `)})`
));
const order = new Map(normalizedIds.map((id, index) => [id, index]));
return rows
.sort((left, right) => (order.get(left.id) ?? Number.MAX_SAFE_INTEGER) - (order.get(right.id) ?? Number.MAX_SAFE_INTEGER))
.map((row) => toResearchArtifact(row));
}
async function getMemoByIdForUser(id: number, userId: string) {
const [row] = await db
.select()
@@ -902,12 +923,13 @@ export async function getResearchPacket(userId: string, ticker: string): Promise
export async function getResearchWorkspace(userId: string, ticker: string): Promise<ResearchWorkspace> {
const normalizedTicker = normalizeTicker(ticker);
const [coverage, memo, library, packet, latestFiling] = await Promise.all([
const [coverage, memo, library, packet, latestFiling, copilotSession] = await Promise.all([
getWatchlistItemByTicker(userId, normalizedTicker),
getResearchMemoByTicker(userId, normalizedTicker),
listResearchArtifacts(userId, { ticker: normalizedTicker, limit: 40 }),
getResearchPacket(userId, normalizedTicker),
listFilingsRecords({ ticker: normalizedTicker, limit: 1 })
listFilingsRecords({ ticker: normalizedTicker, limit: 1 }),
getResearchCopilotSessionByTicker(userId, normalizedTicker)
]);
return {
@@ -918,7 +940,8 @@ export async function getResearchWorkspace(userId: string, ticker: string): Prom
memo,
library: library.artifacts,
packet,
availableTags: library.availableTags
availableTags: library.availableTags,
copilotSession
};
}
@@ -1119,4 +1142,3 @@ export async function getResearchArtifactFileResponse(userId: string, id: number
export function rebuildResearchArtifactIndex() {
rebuildArtifactSearchIndex();
}