2025-04-13 21:53:03 +08:00

86 lines
4.4 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.yupi.springbootinit.manager;
import com.yupi.springbootinit.common.ErrorCode;
import com.yupi.springbootinit.exception.BusinessException;
import io.github.briqt.spark4j.SparkClient;
import io.github.briqt.spark4j.constant.SparkApiVersion;
import io.github.briqt.spark4j.model.SparkMessage;
import io.github.briqt.spark4j.model.request.SparkRequest;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@Service
@Slf4j
public class AiManager {
@Resource
private SparkClient sparkClient;
/**
* 向 AI 发送请求
*
* @param isNeedTemplate 是否使用模板,进行 AI 生成; true 使用 、false 不使用 false 的情况是只想用 AI 不只是生成前端代码
* @param content 内容
* 分析需求:
* 分析网站用户的增长情况
* 原始数据:
* 日期,用户数
* 1号,10
* 2号,20
* 3号,30
* @return AI 返回的内容
* '【【【【'
* <p>
* '【【【【'
*/
public String sendMsgToXingHuo(boolean isNeedTemplate, String content) {
List<SparkMessage> messages = new ArrayList<>();
if (isNeedTemplate) {
// AI 生成问题的预设条件
String predefinedInformation = "请严格按照下面的输出格式生成结果,且不得添加任何多余内容(例如无关文字、注释、代码块标记或反引号):\n" +
"\n" +
"'【【【【'" +
"{ 生成 Echarts V5 的 option 配置对象 JSON 代码,要求为合法 JSON 格式且不含任何额外内容(如注释或多余字符) } '【【【【' 结论: {\n" +
"提供对数据的详细分析结论,内容应尽可能准确、详细,不允许添加其他无关文字或注释 }\n" +
"\n" +
"示例: 输入: 分析需求: 分析网站用户增长情况,请使用柱状图展示 原始数据: 日期,用户数 1号,10 2号,20 3号,30\n" +
"\n" +
"期望输出: '【【【【' { \"title\": { \"text\": \"分析网站用户增长情况\" }, \"xAxis\": { \"type\": \"category\", \"data\": [\"1号\", \"2号\", \"3号\"] }, \"yAxis\": { \"type\": \"value\" }, \"series\": [ { \"name\": \"用户数\", \"type\": \"bar\", \"data\": [10, 20, 30] } ] } '【【【【' 从数据看网站用户数由1号的10人增长到2号的20人再到3号的30人呈现出明显的上升趋势。这表明在这段时间内网站用户吸引力增强可能与推广活动、内容更新或其他外部因素有关。";
messages.add(SparkMessage.systemContent(predefinedInformation + "\n" + "----------------------------------"));
}
messages.add(SparkMessage.userContent(content));
// 构造请求
SparkRequest sparkRequest = SparkRequest.builder()
// 消息列表
.messages(messages)
// 模型回答的tokens的最大长度,非必传,取值为[1,4096],默认为2048
.maxTokens(2048)
// 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高 非必传,取值为[0,1],默认为0.5
.temperature(0.6)
// 指定请求版本
.apiVersion(SparkApiVersion.V4_0)
.build();
// 同步调用
String responseContent = sparkClient.chatSync(sparkRequest).getContent().trim();
if (!isNeedTemplate) {
return responseContent;
}
log.info("星火 AI 返回的结果 {}", responseContent);
AtomicInteger atomicInteger = new AtomicInteger(1);
while (responseContent.split("'【【【【'").length < 3) {
responseContent = sparkClient.chatSync(sparkRequest).getContent().trim();
if (atomicInteger.incrementAndGet() >= 4) {
throw new BusinessException(ErrorCode.SYSTEM_ERROR, "星火 AI 生成失败");
}
}
return responseContent;
}
}