feat: Ai会话 & 删除文章

This commit is contained in:
Shu Guang 2025-05-17 22:13:15 +08:00
parent 1b98b1bd66
commit 618a072215
11 changed files with 8527 additions and 8 deletions

View File

@ -0,0 +1,73 @@
package com.example.system.common;
import java.util.Date;
/**
* AI竞赛助手会话实体类
*/
public class AiChat {
private Long chatId; // 会话ID
private Long userId; // 用户ID
private String userQuestion; // 用户问题
private String aiResponse; // AI回复内容
private Date chatTime; // 会话时间
private String competitionType;// 竞赛类型
private Integer isDeleted; // 是否删除(0-未删除1-已删除)
// Getter Setter 方法
public Long getChatId() {
return chatId;
}
public void setChatId(Long chatId) {
this.chatId = chatId;
}
public Long getUserId() {
return userId;
}
public void setUserId(Long userId) {
this.userId = userId;
}
public String getUserQuestion() {
return userQuestion;
}
public void setUserQuestion(String userQuestion) {
this.userQuestion = userQuestion;
}
public String getAiResponse() {
return aiResponse;
}
public void setAiResponse(String aiResponse) {
this.aiResponse = aiResponse;
}
public Date getChatTime() {
return chatTime;
}
public void setChatTime(Date chatTime) {
this.chatTime = chatTime;
}
public String getCompetitionType() {
return competitionType;
}
public void setCompetitionType(String competitionType) {
this.competitionType = competitionType;
}
public Integer getIsDeleted() {
return isDeleted;
}
public void setIsDeleted(Integer isDeleted) {
this.isDeleted = isDeleted;
}
}

View File

@ -0,0 +1,44 @@
package com.example.system.common;
/**
* 通用响应结果类
*/
public class AiResult<T> {
private Integer code;
private String message;
private T data;
private AiResult(Integer code, String message, T data) {
this.code = code;
this.message = message;
this.data = data;
}
public static <T> AiResult<T> success(T data) {
return new AiResult<>(200, "操作成功", data);
}
public static <T> AiResult<T> success(T data, String message) {
return new AiResult<>(200, message, data);
}
public static <T> AiResult<T> error(String message) {
return new AiResult<>(500, message, null);
}
public static <T> AiResult<T> error(Integer code, String message) {
return new AiResult<>(code, message, null);
}
// Getter 方法
public Integer getCode() {
return code;
}
public String getMessage() {
return message;
}
public T getData() {
return data;
}
}

View File

@ -0,0 +1,52 @@
package com.example.system.controller;
import com.example.system.common.AiResult;
import com.example.system.dto.AiChatRequestDTO;
import com.example.system.common.AiChat;
import com.example.system.service.AiChatService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.bind.annotation.*;
import javax.annotation.Resource;
import java.util.List;
/**
* AI竞赛助手控制器
*/
@RestController
@CrossOrigin
@RequestMapping("/api/ai-chat")
public class AiChatController {
private static final Logger logger = LoggerFactory.getLogger(AiChatController.class);
@Resource
private AiChatService aiChatService;
/**
* 与AI助手对话
*
* @param requestDTO 请求DTO
* @return 对话结果
*/
@PostMapping("/chat")
public AiResult<String> chat(@RequestBody AiChatRequestDTO requestDTO) {
logger.info("接收到AI对话请求: userId={}", requestDTO.getUserId());
return aiChatService.chat(requestDTO);
}
/**
* 获取用户最近的会话记录
*
* @param userId 用户ID
* @return 会话记录列表
*/
@GetMapping("/history/{userId}")
public AiResult<List<AiChat>> getRecentChats(@PathVariable Long userId) {
logger.info("获取会话历史: userId={}", userId);
return aiChatService.getRecentChats(userId);
}
}

View File

@ -0,0 +1,4 @@
package com.example.system.controller;
public class AiController {
}

View File

@ -8,10 +8,7 @@ import com.example.system.utils.JWTUtil;
import io.jsonwebtoken.Claims; import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.CrossOrigin; import org.springframework.web.bind.annotation.*;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import java.sql.Date; import java.sql.Date;
@ -71,14 +68,14 @@ public class ArticleController {
} }
//删除文章(根据文章id) //删除文章(根据文章id)
@RequestMapping("/deleteArticleById") @RequestMapping(value = "/deleteArticleById", method = RequestMethod.POST)
public Object deleteArticleById(Integer id, HttpServletRequest request) { public Object deleteArticleById(@RequestParam Integer id) {
log.info("删除自己的文章(根据文章id) id: {}", id); log.info("删除自己的文章(根据文章id) id: {}", id);
if (id <= 0) { if (id == null || id <= 0) {
return Result.fail("文章id不合法"); return Result.fail("文章id不合法");
} }
Boolean ret = articleService.deleteArticleById(id); Boolean ret = articleService.deleteArticleById(id);
return ret; return ret ? Result.success("删除成功") : Result.fail("删除失败");
} }
//根据文章id修改文章只能修改自己的文章 //根据文章id修改文章只能修改自己的文章

View File

@ -0,0 +1,35 @@
package com.example.system.dto;
/**
* AI竞赛助手会话请求DTO
*/
public class AiChatRequestDTO {
private Long userId;
private String question;
private String competitionType;
public Long getUserId() {
return userId;
}
public void setUserId(Long userId) {
this.userId = userId;
}
public String getQuestion() {
return question;
}
public void setQuestion(String question) {
this.question = question;
}
public String getCompetitionType() {
return competitionType;
}
public void setCompetitionType(String competitionType) {
this.competitionType = competitionType;
}
}

View File

@ -0,0 +1,39 @@
package com.example.system.mapper;
import com.example.system.common.AiChat;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
/**
* AI竞赛助手Mapper接口
*/
@Mapper
public interface AiChatMapper {
/**
* 插入会话记录
*
* @param aiChat 会话记录
* @return 影响的行数
*/
int insert(AiChat aiChat);
/**
* 根据会话ID查询会话记录
*
* @param chatId 会话ID
* @return 会话记录
*/
AiChat selectById(Long chatId);
/**
* 查询用户最近的会话记录
*
* @param userId 用户ID
* @param limit 限制条数
* @return 会话记录列表
*/
List<AiChat> selectRecentByUserId(Long userId, int limit);
}

View File

@ -0,0 +1,31 @@
package com.example.system.service;
import com.example.system.common.AiResult;
import com.example.system.dto.AiChatRequestDTO;
import com.example.system.common.AiChat;
import java.util.List;
/**
* AI竞赛助手服务接口
*/
public interface AiChatService {
/**
* 与AI助手对话
*
* @param requestDTO 请求DTO
* @return 对话结果
*/
AiResult<String> chat(AiChatRequestDTO requestDTO);
/**
* 获取用户最近的会话记录
*
* @param userId 用户ID
* @return 会话记录列表
*/
AiResult<List<AiChat>> getRecentChats(Long userId);
}

View File

@ -0,0 +1,220 @@
package com.example.system.service.impl;
import com.example.system.common.AiResult;
import com.example.system.dto.AiChatRequestDTO;
import com.example.system.common.AiChat;
import com.example.system.mapper.AiChatMapper;
import com.example.system.service.AiChatService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import javax.annotation.Resource;
import java.util.*;
import java.util.concurrent.*;
/**
* AI竞赛助手服务实现类
*/
@Service
public class AiChatServiceImpl implements AiChatService {
private static final Logger logger = LoggerFactory.getLogger(AiChatServiceImpl.class);
@Resource
private AiChatMapper aiChatMapper;
// AI服务相关配置
private static final String API_URL = "https://openai.933999.xyz/v1/chat/completions";
private static final String API_KEY = "sk-1PBIyxIdJ42yyC11XRNqbEXYDt2eZRNVNbd8XxmKjnPXGh5S";
private static final String MODEL = "gpt-4o-mini";
private static final int TIMEOUT_SECONDS = 60;
// 创建RestTemplate
private final RestTemplate restTemplate;
public AiChatServiceImpl() {
org.springframework.http.client.SimpleClientHttpRequestFactory factory =
new org.springframework.http.client.SimpleClientHttpRequestFactory();
factory.setConnectTimeout(30000); // 30秒连接超时
factory.setReadTimeout(60000); // 60秒读取超时
this.restTemplate = new RestTemplate(factory);
}
@Override
public AiResult<String> chat(AiChatRequestDTO requestDTO) {
// 参数校验
if (requestDTO == null || requestDTO.getUserId() == null) {
return AiResult.error("用户ID不能为空");
}
String question = requestDTO.getQuestion();
if (question == null || question.trim().isEmpty()) {
return AiResult.error("问题内容不能为空");
}
// 限制问题长度避免资源浪费
if (question.length() > 2000) {
return AiResult.error("问题内容过长请控制在2000字以内");
}
Long userId = requestDTO.getUserId();
try {
// 1. 获取最近会话记录作为上下文
List<AiChat> recentChats = aiChatMapper.selectRecentByUserId(userId, 5);
// 2. 构建AI请求
String aiResponse = callAiService(question, recentChats);
// 3. 保存会话记录
AiChat chat = new AiChat();
chat.setUserId(userId);
chat.setUserQuestion(question);
chat.setAiResponse(aiResponse);
chat.setChatTime(new Date());
chat.setCompetitionType(requestDTO.getCompetitionType());
chat.setIsDeleted(0);
int rows = aiChatMapper.insert(chat);
if (rows <= 0) {
logger.warn("保存会话记录失败: userId={}", userId);
}
return AiResult.success(aiResponse, "AI助手回复成功");
} catch (Exception e) {
logger.error("AI对话失败", e);
return AiResult.error("AI服务出现错误请稍后再试");
}
}
@Override
public AiResult<List<AiChat>> getRecentChats(Long userId) {
if (userId == null) {
return AiResult.error("用户ID不能为空");
}
try {
List<AiChat> chats = aiChatMapper.selectRecentByUserId(userId, 5);
// 按照时间正序排列旧消息在前新消息在后
Collections.reverse(chats);
return AiResult.success(chats);
} catch (Exception e) {
logger.error("获取会话记录失败", e);
return AiResult.error("获取会话记录失败");
}
}
/**
* 调用AI服务
*
* @param question 用户问题
* @param history 历史会话记录
* @return AI回复
*/
private String callAiService(String question, List<AiChat> history) {
// 系统提示词
String systemPrompt = "你是一个竞赛助手AI专注于帮助学生解决各类竞赛问题。提供专业的解答使用清晰、简洁的语言。" +
"如果涉及代码,请提供详细注释和解释。对于不确定的内容,坦诚表明并提供可能的思路。";
// 构建消息列表
List<Map<String, String>> messages = new ArrayList<>();
// 添加系统消息
Map<String, String> systemMessage = new HashMap<>();
systemMessage.put("role", "system");
systemMessage.put("content", systemPrompt);
messages.add(systemMessage);
// 添加历史记录
if (history != null && !history.isEmpty()) {
for (AiChat chat : history) {
// 用户消息
Map<String, String> userMessage = new HashMap<>();
userMessage.put("role", "user");
userMessage.put("content", chat.getUserQuestion());
messages.add(userMessage);
// AI回复
Map<String, String> aiMessage = new HashMap<>();
aiMessage.put("role", "assistant");
aiMessage.put("content", chat.getAiResponse());
messages.add(aiMessage);
}
}
// 添加当前问题
Map<String, String> currentQuestion = new HashMap<>();
currentQuestion.put("role", "user");
currentQuestion.put("content", question);
messages.add(currentQuestion);
// 构建请求体
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("model", MODEL);
requestBody.put("messages", messages);
requestBody.put("max_tokens", 2000);
requestBody.put("temperature", 0.7);
// 设置HTTP请求头
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.set("Authorization", "Bearer " + API_KEY);
HttpEntity<Map<String, Object>> requestEntity = new HttpEntity<>(requestBody, headers);
// 创建线程池执行请求支持超时控制
ExecutorService executor = Executors.newSingleThreadExecutor();
Future<String> future = executor.submit(() -> {
try {
// 发送请求
long startTime = System.currentTimeMillis();
logger.info("开始请求AI服务: {}", API_URL);
Map<String, Object> responseBody = restTemplate.postForObject(API_URL, requestEntity, Map.class);
long endTime = System.currentTimeMillis();
logger.info("AI服务请求完成耗时: {}ms", (endTime - startTime));
if (responseBody == null) {
return "AI服务无响应请稍后再试";
}
// 解析响应
@SuppressWarnings("unchecked")
List<Map<String, Object>> choices = (List<Map<String, Object>>) responseBody.get("choices");
if (choices != null && !choices.isEmpty()) {
Map<String, Object> choice = choices.get(0);
@SuppressWarnings("unchecked")
Map<String, String> messageObj = (Map<String, String>) choice.get("message");
return messageObj.get("content");
}
return "未能获取有效回复";
} catch (Exception e) {
logger.error("AI服务调用异常", e);
throw new RuntimeException("AI服务调用异常: " + e.getMessage());
}
});
try {
// 设置超时时间
return future.get(TIMEOUT_SECONDS, TimeUnit.SECONDS);
} catch (TimeoutException e) {
logger.error("AI服务请求超时", e);
future.cancel(true);
return "AI服务响应超时请稍后再试";
} catch (Exception e) {
logger.error("获取AI响应失败", e);
future.cancel(true);
return "AI服务暂时不可用请稍后再试";
} finally {
executor.shutdownNow();
}
}
}

View File

@ -0,0 +1,22 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.example.system.mapper.AiChatMapper">
<insert id="insert" parameterType="com.example.system.common.AiChat">
INSERT INTO ai_competition_chat
(chat_id, user_id, user_question, ai_response, chat_time, competition_type, is_deleted)
VALUES
(#{chatId}, #{userId}, #{userQuestion}, #{aiResponse}, #{chatTime}, #{competitionType}, #{isDeleted})
</insert>
<select id="selectById" resultType="com.example.system.common.AiChat">
SELECT * FROM ai_competition_chat WHERE chat_id = #{chatId}
</select>
<select id="selectRecentByUserId" resultType="com.example.system.common.AiChat">
SELECT * FROM ai_competition_chat
WHERE user_id = #{userId}
ORDER BY chat_time DESC
LIMIT #{limit}
</select>
</mapper>

File diff suppressed because it is too large Load Diff