embed method

Future<List<double>> embed(String text)

Generates a 384-dimensional embedding for `text`.

Returns an empty list if the model is not initialized, or if tokenization or inference fails (the failure is logged).

Implementation

/// Generates a 384-dim embedding for [text].
///
/// The first model output is mean-pooled over tokens (weighted by the
/// tokenizer's attention mask) when it is token-level, and is then passed
/// through `_normalize`.
///
/// Returns an empty list if the model is not initialized, or if any step
/// (tokenization, tensor creation, inference, pooling) throws; failures are
/// logged under the `LocalEmbedding` name. Every input and output tensor that
/// was created is disposed, including on failure paths.
Future<List<double>> embed(String text) async {
  // Copy the nullable fields to locals so they promote after the null check
  // instead of being re-read and force-unwrapped across the awaits below.
  final tokenizer = _tokenizer;
  final session = _session;
  if (!_initialized || tokenizer == null || session == null) return [];

  try {
    // NOTE(review): truncation/padding come from however the tokenizer was
    // configured outside this method — confirm it caps the sequence length.
    final encoding = tokenizer.encode(text);
    final inputIds = encoding.ids; // Int32List
    final attentionMask = encoding.attentionMask; // Uint8List
    final seqLen = inputIds.length;

    // Each tensor is recorded as soon as it exists so the `finally` below can
    // release it even if a later tensor, or inference itself, throws.
    final inputs = <String, OrtValue>{};
    try {
      inputs['input_ids'] = await OrtValue.fromList(inputIds, [1, seqLen]);
      // The mask and type ids are widened to Int32List to match input_ids.
      inputs['attention_mask'] = await OrtValue.fromList(
        Int32List.fromList([for (final m in attentionMask) m.toInt()]),
        [1, seqLen],
      );
      inputs['token_type_ids'] = await OrtValue.fromList(
        Int32List.fromList([for (final t in encoding.typeIds) t.toInt()]),
        [1, seqLen],
      );

      final outputs = await session.run(inputs);
      try {
        final rawOutput = await outputs.values.first.asList();
        final first = rawOutput.first;

        final List<double> embedding;
        if (first is List && (first.isEmpty || first.first is List)) {
          // [1, seqLen, dim]: token-level embeddings that need mean pooling.
          embedding = _meanPool(first.cast<List<dynamic>>(), attentionMask);
        } else if (first is List) {
          // [1, dim]: the model already produced a sentence-level vector.
          embedding = [for (final v in first) (v as num).toDouble()];
        } else {
          // [dim]: already flat.
          embedding = [for (final v in rawOutput) (v as num).toDouble()];
        }
        return _normalize(embedding);
      } finally {
        // Attempt every dispose; the first error (if any) is rethrown and
        // handled by the outer catch.
        await Future.wait(outputs.values.map((v) => v.dispose()));
      }
    } finally {
      await Future.wait(inputs.values.map((v) => v.dispose()));
    }
  } catch (e, st) {
    dev.log(
      'Embedding failed: $e',
      name: 'LocalEmbedding',
      error: e,
      stackTrace: st,
    );
    return [];
  }
}