238 lines
6.0 KiB
Dart
238 lines
6.0 KiB
Dart
import 'dart:math' as math;
|
|
|
|
/// Ranks previously entered subject strings against a free-text query.
///
/// Every member is static and the private constructor prevents
/// instantiation. [suggest] rebuilds its statistics from the subjects it is
/// given on each call, so nothing is retained between calls.
class SubjectSuggestionEngine {
  const SubjectSuggestionEngine._();

  /// Whitespace directly before `, . ; : ! ?` (removed by [normalizeDisplay]).
  static final RegExp _spaceBeforePunctuation = RegExp(r'\s+([,.;:!?])');

  /// `, . ; : ! ?` immediately followed by a non-space character.
  static final RegExp _punctuationFollowedByText = RegExp(r'([,.;:!?])(\S)');

  static final RegExp _whitespaceRun = RegExp(r'\s+');

  static final RegExp _repeatedWhitespace = RegExp(r'\s{2,}');

  /// Anything that is not a letter, combining mark, digit or whitespace.
  ///
  /// Unicode-aware so that subjects such as "Économie" or "数学" keep their
  /// letters in the key instead of being mangled or dropped.
  static final RegExp _nonKeyCharacter =
      RegExp(r'[^\p{L}\p{M}\p{N}\s]', unicode: true);

  /// Splits one word into leading punctuation, core text and trailing
  /// punctuation. Every group may be empty, so the pattern always matches.
  static final RegExp _wordParts = RegExp(
    r'^([^\p{L}\p{M}\p{N}]*)(.*?)([^\p{L}\p{M}\p{N}]*)$',
    unicode: true,
  );

  /// Returns up to [limit] distinct subjects from [existingSubjects], ranked
  /// by relevance to [query].
  ///
  /// Subjects are grouped by the [normalizeKey] of their [normalizeDisplay]
  /// form. Each group is represented by its preferred display variant and
  /// weighted by how often it occurs.
  ///
  /// When [query] normalizes to an empty key, every subject is returned,
  /// ordered by popularity. Otherwise only subjects that match the query are
  /// returned. Ties are broken case-insensitively by subject text. A [limit]
  /// of zero or less yields an empty list.
  static List<String> suggest({
    required Iterable<String> existingSubjects,
    required String query,
    int limit = 8,
  }) {
    // Iterable.take throws on a negative count, and a non-positive limit can
    // never return anything anyway.
    if (limit <= 0) {
      return const <String>[];
    }

    final statsByKey = _collectStats(existingSubjects);
    if (statsByKey.isEmpty) {
      return const <String>[];
    }

    // Normalize the query exactly like candidates so the keys are comparable.
    final queryKey = normalizeKey(normalizeDisplay(query));
    // Loop-invariant: build the query's token vector once, not per candidate.
    final queryVector = _tokenVector(queryKey);

    final scored = statsByKey.entries
        .map(
          (entry) => _ScoredSubject(
            subject: entry.value.display,
            score: _score(
              candidateKey: entry.key,
              candidateDisplay: entry.value.display,
              count: entry.value.count,
              queryKey: queryKey,
              queryVector: queryVector,
            ),
          ),
        )
        .where((entry) => entry.score > 0)
        .toList()
      ..sort(_compareScored);

    return scored.take(limit).map((entry) => entry.subject).toList();
  }

  /// Cleans up a free-text subject for display.
  ///
  /// The input is processed in this order:
  /// 1. Trim it and collapse every whitespace run to a single space.
  /// 2. Remove spaces before `, . ; : ! ?` and insert one space after them.
  ///    A separator between two digits is left as-is, so "2.0", "1,000" and
  ///    "10:30" survive.
  /// 3. Give each word an upper-case first letter and lower-case the rest,
  ///    unless the word's core is all upper-case and longer than one
  ///    character (treated as an acronym and left unchanged).
  ///
  /// Returns an empty string when nothing but whitespace remains.
  static String normalizeDisplay(String input) {
    final trimmed = input.trim();
    if (trimmed.isEmpty) {
      return '';
    }

    final compact = trimmed.replaceAll(_whitespaceRun, ' ');
    // String.replaceAll inserts its replacement literally (no `$1`
    // substitution), so rewrites that reuse capture groups must go through
    // replaceAllMapped.
    final tightened = compact.replaceAllMapped(
      _spaceBeforePunctuation,
      (match) => match.group(1) ?? '',
    );
    final padded = tightened.replaceAllMapped(
      _punctuationFollowedByText,
      (match) {
        final mark = match.group(1) ?? '';
        final next = match.group(2) ?? '';
        // Keep decimal points, digit grouping and times intact.
        if (_isAsciiDigitAt(tightened, match.start - 1) &&
            _isAsciiDigitAt(next, 0)) {
          return '$mark$next';
        }
        return '$mark $next';
      },
    );
    final spaced = padded.replaceAll(_whitespaceRun, ' ').trim();

    final sentence = spaced.split(' ').map(_correctWord).join(' ').trim();
    if (sentence.isEmpty) {
      return '';
    }

    return sentence[0].toUpperCase() + sentence.substring(1);
  }

  /// Produces the grouping/matching key for [input].
  ///
  /// The key is lower-cased, with every character that is not a letter,
  /// combining mark, digit or whitespace replaced by a space. Whitespace
  /// runs are collapsed to one space and the ends are trimmed.
  static String normalizeKey(String input) {
    return input
        .toLowerCase()
        .replaceAll(_nonKeyCharacter, ' ')
        .replaceAll(_whitespaceRun, ' ')
        .trim();
  }

  /// Groups [subjects] by normalized key.
  ///
  /// Occurrences are counted per key, and the preferred display variant is
  /// kept (see [_isBetterDisplay]). Subjects whose display form or key is
  /// empty are skipped.
  static Map<String, _SubjectStats> _collectStats(Iterable<String> subjects) {
    final statsByKey = <String, _SubjectStats>{};
    for (final raw in subjects) {
      final cleaned = normalizeDisplay(raw);
      if (cleaned.isEmpty) {
        continue;
      }
      final key = normalizeKey(cleaned);
      if (key.isEmpty) {
        continue;
      }
      final stats = statsByKey.putIfAbsent(
        key,
        () => _SubjectStats(display: cleaned),
      );
      stats.count += 1;
      if (_isBetterDisplay(cleaned, stats.display)) {
        stats.display = cleaned;
      }
    }
    return statsByKey;
  }

  /// Orders by descending score, then case-insensitively by subject text.
  static int _compareScored(_ScoredSubject a, _ScoredSubject b) {
    final byScore = b.score.compareTo(a.score);
    if (byScore != 0) {
      return byScore;
    }
    return a.subject.toLowerCase().compareTo(b.subject.toLowerCase());
  }

  /// Scores one candidate against the query; 0 means "not a match".
  ///
  /// With an empty [queryKey], every candidate scores
  /// `0.5 + log(count + 1) * 0.1`. Otherwise the relevance is the sum of:
  /// - 2 × token cosine similarity;
  /// - 1.2 when the key starts with the query;
  /// - 0.5 when the key contains the query but does not start with it;
  /// - 0.25 when the lower-cased display text contains the query.
  ///
  /// The popularity term is added only when that relevance is positive.
  /// [queryVector] must be `_tokenVector(queryKey)`.
  static double _score({
    required String candidateKey,
    required String candidateDisplay,
    required int count,
    required String queryKey,
    required Map<String, int> queryVector,
  }) {
    final popularity = math.log(count + 1) * 0.1;

    if (queryKey.isEmpty) {
      return 0.5 + popularity;
    }

    final isPrefix = candidateKey.startsWith(queryKey);
    final startsWith = isPrefix ? 1.2 : 0.0;
    final contains = !isPrefix && candidateKey.contains(queryKey) ? 0.5 : 0.0;

    final vectorSimilarity = _cosineSimilarity(
      _tokenVector(candidateKey),
      queryVector,
    );

    // queryKey is already lower-case (normalizeKey lower-cases first).
    final editLikeBoost =
        candidateDisplay.toLowerCase().contains(queryKey) ? 0.25 : 0.0;

    final relevance =
        (vectorSimilarity * 2.0) + startsWith + contains + editLikeBoost;

    // Popularity only re-orders matches. Adding it to a non-match would give
    // every subject a positive score and defeat the `score > 0` filter in
    // [suggest].
    if (relevance <= 0) {
      return 0;
    }

    return relevance + popularity;
  }

  /// Counts how often each space-separated token occurs in [input].
  static Map<String, int> _tokenVector(String input) {
    final vector = <String, int>{};
    for (final token in input.split(' ')) {
      if (token.isEmpty) {
        continue;
      }
      vector.update(token, (count) => count + 1, ifAbsent: () => 1);
    }
    return vector;
  }

  /// Cosine similarity of two sparse count vectors.
  ///
  /// Returns 0 when either vector is empty.
  static double _cosineSimilarity(Map<String, int> a, Map<String, int> b) {
    if (a.isEmpty || b.isEmpty) {
      return 0;
    }

    var dot = 0.0;
    var normA = 0.0;
    for (final entry in a.entries) {
      final av = entry.value.toDouble();
      normA += av * av;
      dot += av * (b[entry.key]?.toDouble() ?? 0.0);
    }

    var normB = 0.0;
    for (final bv in b.values) {
      normB += bv * bv;
    }

    final denominator = math.sqrt(normA) * math.sqrt(normB);
    if (denominator == 0) {
      return 0;
    }

    return dot / denominator;
  }

  /// Whether the UTF-16 code unit at [index] of [text] is an ASCII digit.
  ///
  /// Out-of-range indices yield false.
  static bool _isAsciiDigitAt(String text, int index) {
    if (index < 0 || index >= text.length) {
      return false;
    }
    final unit = text.codeUnitAt(index);
    return unit >= 0x30 && unit <= 0x39; // '0'..'9'
  }

  /// Re-cases the core of [rawWord], keeping surrounding punctuation.
  ///
  /// Leading and trailing characters that are not letters, marks or digits
  /// are kept verbatim. A core longer than one character that is already
  /// all upper-case is treated as an acronym and left alone. Any other core
  /// gets an upper-case first character and a lower-case remainder.
  static String _correctWord(String rawWord) {
    if (rawWord.isEmpty) {
      return rawWord;
    }

    final parts = _wordParts.firstMatch(rawWord);
    if (parts == null) {
      return rawWord;
    }

    final leading = parts.group(1) ?? '';
    final core = parts.group(2) ?? '';
    final trailing = parts.group(3) ?? '';

    if (core.isEmpty) {
      return rawWord;
    }

    final isAcronym = core.length > 1 && core == core.toUpperCase();
    final correctedCore = isAcronym
        ? core
        : core[0].toUpperCase() + core.substring(1).toLowerCase();

    return '$leading$correctedCore$trailing';
  }

  /// Whether [candidate] should replace [current] as a key's display text.
  ///
  /// The variant with the lower [_displayPenalty] wins. On equal penalties
  /// the strictly shorter one wins, otherwise [current] is kept.
  static bool _isBetterDisplay(String candidate, String current) {
    if (candidate == current) {
      return false;
    }

    final candidatePenalty = _displayPenalty(candidate);
    final currentPenalty = _displayPenalty(current);
    if (candidatePenalty != currentPenalty) {
      return candidatePenalty < currentPenalty;
    }

    return candidate.length < current.length;
  }

  /// Penalizes display variants that are less pleasant to show.
  ///
  /// Repeated whitespace costs 2, and text identical to its upper-case form
  /// costs 1. Values from [normalizeDisplay] never contain repeated
  /// whitespace, so the first check is purely defensive.
  static int _displayPenalty(String value) {
    var penalty = 0;
    if (value.contains(_repeatedWhitespace)) {
      penalty += 2;
    }
    if (value == value.toUpperCase()) {
      penalty += 1;
    }
    return penalty;
  }
}
|
|
|
|
/// Running aggregate for one normalized subject key.
class _SubjectStats {
  _SubjectStats({required this.display});

  /// How many input subjects mapped onto this key; starts at zero.
  var count = 0;

  /// Spelling currently chosen to represent the key in suggestions.
  String display;
}
|
|
|
|
/// A candidate subject paired with the relevance score used to rank it.
class _ScoredSubject {
  /// Display text handed back to the caller if this entry makes the cut.
  final String subject;

  /// Ranking weight; larger values are ordered first.
  final double score;

  _ScoredSubject({required this.subject, required this.score});
}
|