tasq/lib/utils/subject_suggestions.dart

238 lines
6.0 KiB
Dart

import 'dart:math' as math;
/// Ranks previously entered subjects against the text a user is typing.
///
/// Every member is static and the only constructor is private, so the class
/// serves purely as a namespace for the suggestion helpers.
class SubjectSuggestionEngine {
  const SubjectSuggestionEngine._();

  // The patterns below are compiled once (static finals are initialised
  // lazily on first use) instead of on every call, because normalisation runs
  // once per stored subject for each suggestion request.

  /// A run of one or more whitespace characters.
  static final RegExp _whitespaceRun = RegExp(r'\s+');

  /// Whitespace sitting directly in front of a punctuation mark.
  static final RegExp _spaceBeforePunctuation = RegExp(r'\s+([,.;:!?])');

  /// A punctuation mark immediately followed by a non-whitespace character.
  static final RegExp _punctuationWithoutSpace = RegExp(r'([,.;:!?])(\S)');

  /// Anything that is not a letter, combining mark, digit or whitespace.
  ///
  /// Unicode-aware so that accented and non-Latin letters survive in keys.
  static final RegExp _nonKeyCharacter = RegExp(
    r'[^\p{L}\p{M}\p{N}\s]',
    unicode: true,
  );

  /// Splits a word into leading symbols, its core, and trailing symbols.
  ///
  /// Unicode-aware so that an accented first letter belongs to the core
  /// instead of being mistaken for leading punctuation.
  static final RegExp _wordParts = RegExp(
    r'^([^\p{L}\p{M}\p{N}]*)(.*?)([^\p{L}\p{M}\p{N}]*)$',
    unicode: true,
  );

  /// Two or more consecutive whitespace characters.
  static final RegExp _repeatedWhitespace = RegExp(r'\s{2,}');

  /// Returns up to [limit] subjects from [existingSubjects] ranked against
  /// [query], best match first.
  ///
  /// Subjects are grouped by [normalizeKey], so spellings that differ only in
  /// case, punctuation or spacing count as one subject; the returned text is
  /// the preferred [normalizeDisplay] form seen for that group. Ties in score
  /// are broken alphabetically, ignoring case.
  ///
  /// Returns an empty list when there are no usable subjects or when [limit]
  /// is zero or negative.
  static List<String> suggest({
    required Iterable<String> existingSubjects,
    required String query,
    int limit = 8,
  }) {
    // `Iterable.take` throws a RangeError for a negative count, so treat any
    // non-positive limit as "no suggestions wanted".
    if (limit <= 0) {
      return const <String>[];
    }

    final statsByKey = <String, _SubjectStats>{};
    for (final raw in existingSubjects) {
      final cleaned = normalizeDisplay(raw);
      final key = normalizeKey(cleaned);
      if (key.isEmpty) {
        // Blank or symbol-only subjects have nothing to match against.
        continue;
      }
      final stats = statsByKey.putIfAbsent(
        key,
        () => _SubjectStats(display: cleaned),
      );
      stats.count += 1;
      if (_isBetterDisplay(cleaned, stats.display)) {
        stats.display = cleaned;
      }
    }
    if (statsByKey.isEmpty) {
      return const <String>[];
    }

    final queryKey = normalizeKey(normalizeDisplay(query));
    final scored = <_ScoredSubject>[];
    for (final entry in statsByKey.entries) {
      final stats = entry.value;
      final score = _score(
        candidateKey: entry.key,
        candidateDisplay: stats.display,
        count: stats.count,
        queryKey: queryKey,
      );
      if (score > 0) {
        scored.add(_ScoredSubject(subject: stats.display, score: score));
      }
    }
    scored.sort((a, b) {
      final byScore = b.score.compareTo(a.score);
      if (byScore != 0) {
        return byScore;
      }
      return a.subject.toLowerCase().compareTo(b.subject.toLowerCase());
    });
    return [for (final entry in scored.take(limit)) entry.subject];
  }

  /// Returns [input] tidied up for display.
  ///
  /// Whitespace is trimmed and collapsed, spaces before `,.;:!?` are removed,
  /// a space is inserted after such a mark when text follows it directly, and
  /// every word is capitalised (words already in all caps are kept as they
  /// are). For example `'  math ,physics '` becomes `'Math, Physics'`.
  ///
  /// Returns an empty string when [input] is blank.
  static String normalizeDisplay(String input) {
    final trimmed = input.trim();
    if (trimmed.isEmpty) {
      return '';
    }
    // `String.replaceAll` inserts its replacement verbatim (a `$1` would end
    // up in the output literally), so capture groups must be substituted
    // through `replaceAllMapped`.
    final punctuationSpacing = trimmed
        .replaceAll(_whitespaceRun, ' ')
        .replaceAllMapped(_spaceBeforePunctuation, (match) => match[1] ?? '')
        .replaceAllMapped(
          _punctuationWithoutSpace,
          (match) => '${match[1] ?? ''} ${match[2] ?? ''}',
        )
        .replaceAll(_whitespaceRun, ' ')
        .trim();
    final sentence = punctuationSpacing.split(' ').map(_correctWord).join(' ');
    if (sentence.isEmpty) {
      return '';
    }
    return '${sentence[0].toUpperCase()}${sentence.substring(1)}';
  }

  /// Returns the comparison key for [input].
  ///
  /// The key is lower-cased, every character that is not a letter, combining
  /// mark, digit or whitespace is replaced by a space, and whitespace is then
  /// collapsed and trimmed. Letters of any script are kept, so accented and
  /// non-Latin subjects produce non-empty keys.
  static String normalizeKey(String input) {
    return input
        .toLowerCase()
        .replaceAll(_nonKeyCharacter, ' ')
        .replaceAll(_whitespaceRun, ' ')
        .trim();
  }

  /// Scores one candidate against [queryKey]; higher means a better match.
  ///
  /// With an empty [queryKey] the score is `0.5` plus the popularity term.
  /// Otherwise it is the sum of twice the token cosine similarity, `1.2` for
  /// a prefix match or `0.5` for a non-prefix substring match on
  /// [candidateKey], `0.25` when the lower-cased [candidateDisplay] contains
  /// [queryKey], and the popularity term `ln(count + 1) * 0.1`.
  static double _score({
    required String candidateKey,
    required String candidateDisplay,
    required int count,
    required String queryKey,
  }) {
    final popularity = math.log(count + 1) * 0.1;
    if (queryKey.isEmpty) {
      return 0.5 + popularity;
    }
    final isPrefix = candidateKey.startsWith(queryKey);
    final startsWith = isPrefix ? 1.2 : 0.0;
    final contains = !isPrefix && candidateKey.contains(queryKey) ? 0.5 : 0.0;
    final vectorSimilarity = _cosineSimilarity(
      _tokenVector(candidateKey),
      _tokenVector(queryKey),
    );
    // `queryKey` is already lower-case because it comes from `normalizeKey`.
    final displayBoost =
        candidateDisplay.toLowerCase().contains(queryKey) ? 0.25 : 0.0;
    // NOTE(review): `suggest` always passes a count of at least 1, so the
    // popularity term is strictly positive and every candidate scores above
    // zero. The `score > 0` filter in `suggest` therefore never removes a
    // candidate; confirm whether unrelated subjects should be suppressed.
    return (vectorSimilarity * 2.0) +
        startsWith +
        contains +
        displayBoost +
        popularity;
  }

  /// Counts how often each space-separated token occurs in [input].
  static Map<String, int> _tokenVector(String input) {
    final vector = <String, int>{};
    for (final token in input.split(' ')) {
      if (token.isEmpty) {
        continue;
      }
      vector[token] = (vector[token] ?? 0) + 1;
    }
    return vector;
  }

  /// Returns the cosine similarity of two token-count vectors.
  ///
  /// The result is `0` when either vector is empty.
  static double _cosineSimilarity(Map<String, int> a, Map<String, int> b) {
    if (a.isEmpty || b.isEmpty) {
      return 0;
    }
    var dot = 0.0;
    var normA = 0.0;
    var normB = 0.0;
    for (final entry in a.entries) {
      final av = entry.value.toDouble();
      normA += av * av;
      final bv = b[entry.key]?.toDouble() ?? 0.0;
      dot += av * bv;
    }
    for (final count in b.values) {
      final bv = count.toDouble();
      normB += bv * bv;
    }
    final denominator = math.sqrt(normA) * math.sqrt(normB);
    if (denominator == 0) {
      return 0;
    }
    return dot / denominator;
  }

  /// Capitalises a single word while keeping surrounding symbols in place.
  ///
  /// A core longer than one character that equals its own upper-case form is
  /// treated as an acronym and left untouched; any other core gets an
  /// upper-case first character and a lower-case remainder. Words without a
  /// core (empty or symbols only) are returned unchanged.
  static String _correctWord(String rawWord) {
    final parts = _wordParts.firstMatch(rawWord);
    final core = parts?.group(2) ?? '';
    if (parts == null || core.isEmpty) {
      return rawWord;
    }
    final leading = parts.group(1) ?? '';
    final trailing = parts.group(3) ?? '';
    final isAcronym = core.length > 1 && core == core.toUpperCase();
    final correctedCore = isAcronym
        ? core
        : '${core[0].toUpperCase()}${core.substring(1).toLowerCase()}';
    return '$leading$correctedCore$trailing';
  }

  /// Whether [candidate] should replace [current] as the display text.
  ///
  /// The lower [_displayPenalty] wins; on equal penalties the shorter string
  /// wins. Identical strings never replace each other.
  static bool _isBetterDisplay(String candidate, String current) {
    if (candidate == current) {
      return false;
    }
    final candidatePenalty = _displayPenalty(candidate);
    final currentPenalty = _displayPenalty(current);
    if (candidatePenalty != currentPenalty) {
      return candidatePenalty < currentPenalty;
    }
    return candidate.length < current.length;
  }

  /// Returns a penalty for [value]: `2` for repeated whitespace plus `1` when
  /// the text equals its own upper-case form.
  static int _displayPenalty(String value) {
    var penalty = 0;
    // Strings produced by `normalizeDisplay` have their whitespace collapsed,
    // so this branch only fires for text that skipped that step.
    if (value.contains(_repeatedWhitespace)) {
      penalty += 2;
    }
    if (value == value.toUpperCase()) {
      penalty += 1;
    }
    return penalty;
  }
}
/// Mutable running tally kept for one subject while suggestions are built.
class _SubjectStats {
  /// Display text currently chosen for the subject; may be reassigned.
  String display;

  /// Occurrence counter; starts at zero and is incremented by the caller.
  int count;

  /// Creates a tally with the given [display] text and a [count] of zero.
  _SubjectStats({required this.display}) : count = 0;
}
/// Immutable pairing of a subject's display text with its ranking score.
class _ScoredSubject {
  /// Display text of the subject.
  final String subject;

  /// Ranking score assigned to [subject].
  final double score;

  /// Creates a pairing of [subject] and [score].
  _ScoredSubject({required this.subject, required this.score});
}