// TypeScript, 250 lines, 6.0 KiB
import { embedMany, generateText } from 'ai';
|
|
import { createZhipu } from 'zhipu-ai-provider';
|
|
|
|
/**
 * Kind of AI task being run. `runAiAnalysis` selects the extraction config
 * for `'extraction'` and the report config otherwise.
 */
type AiWorkload = 'report' | 'extraction';

/** Supported model providers; `'zhipu'` is the only member. */
type AiProvider = 'zhipu';
|
|
|
|
/** Resolved settings for a text-generation request. */
type AiConfig = {
  /** Provider that serves the request. */
  provider: AiProvider;
  /** API credential; absent when no key is configured. */
  apiKey?: string;
  /** Endpoint handed to the provider client as its base URL. */
  baseUrl: string;
  /** Model identifier passed to the provider. */
  model: string;
  /** Sampling temperature forwarded to the generate call. */
  temperature: number;
};
|
|
|
|
/** Map of environment variable names to their (possibly missing) values. */
type EnvSource = Record<string, string | undefined>;

/** Overrides accepted by the config getters. */
type GetAiConfigOptions = {
  /** Environment to read from; the getters fall back to `process.env`. */
  env?: EnvSource;
  /** Sink for configuration warnings; the getters fall back to `console.warn`. */
  warn?: (message: string) => void;
};
|
|
|
|
/** Arguments passed to a text-generation function (see `RunAiAnalysisOptions.generate`). */
type AiGenerateInput = {
  /** Model handle produced by `createModel`; kept opaque (`unknown`) at this layer. */
  model: unknown;
  /** Optional system prompt. */
  system?: string;
  /** User prompt. */
  prompt: string;
  /** Sampling temperature. */
  temperature: number;
  /** Retry budget; `defaultGenerate` uses 0 when this is omitted. */
  maxRetries?: number;
};
|
|
|
|
/** Result of a text-generation function. */
type AiGenerateOutput = {
  /** Generated text as returned by the generator. */
  text: string;
};

/** Result of an embedding function. */
type AiEmbedOutput = {
  /** Embedding vectors returned by the embedder. */
  embeddings: number[][];
};
|
|
|
|
/** Options for `runAiAnalysis`: config overrides plus injectable collaborators. */
type RunAiAnalysisOptions = GetAiConfigOptions & {
  /** Which config to resolve; `runAiAnalysis` treats a missing value as `'report'`. */
  workload?: AiWorkload;
  /** Replacement model factory; `defaultCreateModel` is used when omitted. */
  createModel?: (config: AiConfig) => unknown;
  /** Replacement generator; `defaultGenerate` is used when omitted. */
  generate?: (input: AiGenerateInput) => Promise<AiGenerateOutput>;
};
|
|
|
|
/** Resolved settings for an embedding request; model and dimensions are fixed literals. */
type EmbeddingConfig = {
  /** Provider that serves the request. */
  provider: AiProvider;
  /** API credential; absent when no key is configured. */
  apiKey?: string;
  /** Endpoint handed to the provider client as its base URL. */
  baseUrl: string;
  /** Embedding model identifier. */
  model: 'embedding-3';
  /** Requested vector size. */
  dimensions: 256;
};
|
|
|
|
/** Options for `runAiEmbeddings`: config overrides plus injectable collaborators. */
type RunAiEmbeddingsOptions = GetAiConfigOptions & {
  /** Replacement model factory; `defaultCreateEmbeddingModel` is used when omitted. */
  createModel?: (config: EmbeddingConfig) => unknown;
  /** Replacement embedder; `defaultEmbed` is used when omitted. */
  embed?: (input: { model: unknown; values: string[] }) => Promise<AiEmbedOutput>;
};
|
|
|
|
/** Endpoint used for every request; a configured `ZHIPU_BASE_URL` does not override it. */
const CODING_API_BASE_URL = 'https://api.z.ai/api/coding/paas/v4';

/** Embedding model used by the embedding config. */
const SEARCH_EMBEDDING_MODEL = 'embedding-3';

/** Vector size requested from the embedding model. */
const SEARCH_EMBEDDING_DIMENSIONS = 256;

/** Set once the "ZHIPU_BASE_URL is ignored" warning has been emitted, to avoid repeating it. */
let warnedIgnoredZhipuBaseUrl = false;
|
|
|
|
/**
 * Reads an environment variable, treating blank values as missing.
 *
 * @param name - Variable name to look up.
 * @param env - Map to read from; defaults to `process.env`.
 * @returns The trimmed value, or `undefined` when the variable is unset,
 *   empty, or whitespace-only.
 */
function envValue(name: string, env: EnvSource = process.env): string | undefined {
  const trimmed = env[name]?.trim();
  // An empty string is falsy, so unset and blank both collapse to undefined.
  return trimmed ? trimmed : undefined;
}
|
|
|
|
/**
 * Converts a raw temperature setting into a usable sampling temperature.
 *
 * @param value - Raw textual setting, or `undefined` when not configured.
 * @returns The parsed number clamped to the range 0..2, or 0.2 when the
 *   input is missing, blank, or not a finite number.
 */
function parseTemperature(value: string | undefined): number {
  const fallback = 0.2;

  // Number('') and Number('   ') evaluate to 0, so blank input must be
  // rejected explicitly or it would silently select temperature 0.
  if (value === undefined || value.trim().length === 0) {
    return fallback;
  }

  const parsed = Number(value);
  if (!Number.isFinite(parsed)) {
    return fallback;
  }

  return Math.min(Math.max(parsed, 0), 2);
}
|
|
|
|
/**
 * Emits a one-time warning when `ZHIPU_BASE_URL` is set, since the endpoint
 * is fixed to `CODING_API_BASE_URL` and the variable has no effect.
 *
 * @param env - Environment map to inspect.
 * @param warn - Sink that receives the warning message.
 */
function warnIgnoredZhipuBaseUrl(env: EnvSource, warn: (message: string) => void): void {
  if (warnedIgnoredZhipuBaseUrl) {
    return;
  }

  const configuredBaseUrl = envValue('ZHIPU_BASE_URL', env);
  if (!configuredBaseUrl) {
    return;
  }

  // Flip the flag before calling the sink so later calls stay silent.
  warnedIgnoredZhipuBaseUrl = true;
  warn(
    `[AI SDK] ZHIPU_BASE_URL is ignored. The Coding API endpoint is hardcoded to ${CODING_API_BASE_URL}.`
  );
}
|
|
|
|
/**
 * Builds a Zhipu provider client from the config and returns the model
 * handle for `config.model`.
 *
 * @param config - Resolved generation settings (key, base URL, model name).
 * @returns The value produced by calling the provider with the model name.
 */
function defaultCreateModel(config: AiConfig) {
  const zhipu = createZhipu({
    apiKey: config.apiKey,
    baseURL: config.baseUrl
  });

  return zhipu(config.model);
}
|
|
|
|
/**
 * Default generator: forwards the input to the AI SDK `generateText` call
 * and keeps only the resulting text.
 *
 * @param input - Model handle, prompts, temperature, and optional retry budget.
 * @returns The generated text wrapped in an `AiGenerateOutput`.
 */
async function defaultGenerate(input: AiGenerateInput): Promise<AiGenerateOutput> {
  const result = await generateText({
    // The handle is opaque at this layer; the cast defers typing to the SDK call.
    model: input.model as never,
    system: input.system,
    prompt: input.prompt,
    temperature: input.temperature,
    // A missing retry budget is passed on as 0.
    maxRetries: input.maxRetries ?? 0
  });

  return { text: result.text };
}
|
|
|
|
/**
 * Builds a Zhipu provider client from the config and returns its text
 * embedding model for `config.model`, requesting `config.dimensions`.
 *
 * @param config - Resolved embedding settings.
 * @returns The value produced by the provider's `textEmbeddingModel` call.
 */
function defaultCreateEmbeddingModel(config: EmbeddingConfig) {
  const zhipu = createZhipu({
    apiKey: config.apiKey,
    baseURL: config.baseUrl
  });

  return zhipu.textEmbeddingModel(config.model, {
    dimensions: config.dimensions
  });
}
|
|
|
|
/**
 * Default embedder: forwards the values to the AI SDK `embedMany` call with
 * `maxRetries` set to 0 and returns the resulting vectors.
 *
 * @param input - Model handle and the strings to embed.
 * @returns The embeddings wrapped in an `AiEmbedOutput`.
 */
async function defaultEmbed(input: {
  model: unknown;
  values: string[];
}): Promise<AiEmbedOutput> {
  const result = await embedMany({
    // The handle is opaque at this layer; the cast defers typing to the SDK call.
    model: input.model as never,
    values: input.values,
    maxRetries: 0
  });

  return { embeddings: result.embeddings as number[][] };
}
|
|
|
|
/**
 * Alias of `getReportAiConfig`.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns The report generation config.
 */
export function getAiConfig(options?: GetAiConfigOptions) {
  return getReportAiConfig(options);
}
|
|
|
|
/**
 * Resolves the generation config for report workloads from the environment.
 *
 * Reads `ZHIPU_API_KEY`, `ZHIPU_MODEL` (falling back to `'glm-4.6'`), and
 * `AI_TEMPERATURE`. The base URL is always `CODING_API_BASE_URL`; a
 * configured `ZHIPU_BASE_URL` only triggers the one-time warning.
 *
 * @param options - Optional environment (default `process.env`) and warning
 *   sink (default `console.warn`).
 * @returns A config object that satisfies `AiConfig`.
 */
export function getReportAiConfig(options?: GetAiConfigOptions) {
  const env = options?.env ?? process.env;
  warnIgnoredZhipuBaseUrl(env, options?.warn ?? console.warn);

  return {
    provider: 'zhipu',
    apiKey: envValue('ZHIPU_API_KEY', env),
    baseUrl: CODING_API_BASE_URL,
    model: envValue('ZHIPU_MODEL', env) ?? 'glm-4.6',
    temperature: parseTemperature(envValue('AI_TEMPERATURE', env))
  } satisfies AiConfig;
}
|
|
|
|
/**
 * Resolves the generation config for extraction workloads: the report
 * config with the temperature overridden to 0.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns The report config with `temperature: 0`.
 */
export function getExtractionAiConfig(options?: GetAiConfigOptions) {
  return {
    ...getReportAiConfig(options),
    temperature: 0
  };
}
|
|
|
|
/**
 * Resolves the embedding config from the environment.
 *
 * Only `ZHIPU_API_KEY` is read; the base URL, model, and dimensions come
 * from module constants. A configured `ZHIPU_BASE_URL` only triggers the
 * one-time warning.
 *
 * @param options - Optional environment (default `process.env`) and warning
 *   sink (default `console.warn`).
 * @returns A config object that satisfies `EmbeddingConfig`.
 */
export function getEmbeddingAiConfig(options?: GetAiConfigOptions) {
  const env = options?.env ?? process.env;
  warnIgnoredZhipuBaseUrl(env, options?.warn ?? console.warn);

  return {
    provider: 'zhipu',
    apiKey: envValue('ZHIPU_API_KEY', env),
    baseUrl: CODING_API_BASE_URL,
    model: SEARCH_EMBEDDING_MODEL,
    dimensions: SEARCH_EMBEDDING_DIMENSIONS
  } satisfies EmbeddingConfig;
}
|
|
|
|
/**
 * Reports whether an API key is available for AI calls.
 *
 * Resolving the config may emit the one-time `ZHIPU_BASE_URL` warning as a
 * side effect.
 *
 * @param options - Optional environment and warning-sink overrides.
 * @returns `true` when the resolved config has a non-empty `apiKey`.
 */
export function isAiConfigured(options?: GetAiConfigOptions) {
  const config = getReportAiConfig(options);
  return Boolean(config.apiKey);
}
|
|
|
|
/**
 * Runs a single text-generation request and returns the trimmed response.
 *
 * @param prompt - User prompt sent to the model.
 * @param systemPrompt - Optional system prompt.
 * @param options - Workload selection (default `'report'`), environment
 *   overrides, and injectable `createModel` / `generate` implementations.
 * @returns The provider, model name, and trimmed response text.
 * @throws Error when no API key is configured.
 * @throws Error when the response is empty after trimming.
 */
export async function runAiAnalysis(prompt: string, systemPrompt?: string, options?: RunAiAnalysisOptions) {
  const workload = options?.workload ?? 'report';
  const config = workload === 'extraction'
    ? getExtractionAiConfig(options)
    : getReportAiConfig(options);

  if (!config.apiKey) {
    throw new Error('ZHIPU_API_KEY is required for AI workloads');
  }

  const createModel = options?.createModel ?? defaultCreateModel;
  const generate = options?.generate ?? defaultGenerate;
  const model = createModel(config);

  const result = await generate({
    model,
    system: systemPrompt,
    prompt,
    temperature: config.temperature,
    maxRetries: 0
  });

  // A whitespace-only response is treated the same as no response.
  const text = result.text.trim();
  if (!text) {
    throw new Error('AI SDK returned an empty response');
  }

  return {
    provider: config.provider,
    model: config.model,
    text
  };
}
|
|
|
|
/**
 * Embeds the given strings and returns one numeric vector per embedded value.
 *
 * Inputs are trimmed and blank entries are dropped before the request, so
 * the result corresponds to the non-blank inputs in order, not to the
 * original indices of `values`. When nothing remains, an empty array is
 * returned without resolving the config or requiring an API key.
 *
 * @param values - Strings to embed.
 * @param options - Environment overrides and injectable `createModel` /
 *   `embed` implementations.
 * @returns The embeddings from the embedder with every component passed
 *   through `Number(...)`.
 * @throws Error when there is at least one non-blank value and no API key
 *   is configured.
 */
export async function runAiEmbeddings(values: string[], options?: RunAiEmbeddingsOptions): Promise<number[][]> {
  const sanitizedValues = values
    .map((value) => value.trim())
    .filter((value) => value.length > 0);

  if (sanitizedValues.length === 0) {
    return [];
  }

  const config = getEmbeddingAiConfig(options);
  if (!config.apiKey) {
    throw new Error('ZHIPU_API_KEY is required for AI workloads');
  }

  const createModel = options?.createModel ?? defaultCreateEmbeddingModel;
  const embed = options?.embed ?? defaultEmbed;
  const model = createModel(config);
  const result = await embed({
    model,
    values: sanitizedValues
  });

  return result.embeddings.map((embedding) => embedding.map((value) => Number(value)));
}
|
|
|
|
/**
 * Test helper: clears the "already warned" flag so the `ZHIPU_BASE_URL`
 * warning can be emitted again.
 */
export function __resetAiWarningsForTests(): void {
  warnedIgnoredZhipuBaseUrl = false;
}
|