-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample.lua
More file actions
71 lines (63 loc) · 2.05 KB
/
example.lua
File metadata and controls
71 lines (63 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
-- good resources
-- https://opensearch.org/blog/improving-document-retrieval-with-sparse-semantic-encoders/
-- https://huggingface.co/opensearch-project/opensearch-neural-sparse-encoding-v1
--
-- run with
-- text-embeddings-router --model-id opensearch-project/opensearch-neural-sparse-encoding-v1 --pooling splade
local cjson = require("cjson")
local http = require("socket.http")
local ltn12 = require("ltn12")
local pgmoon = require("pgmoon")
local pgvector = require("./src/pgvector")
-- Connect to the "pgvector_example" database; the user name is read from the
-- USER environment variable.
local pg = pgmoon.new({
  user = os.getenv("USER"),
  database = "pgvector_example"
})
assert(pg:connect())

-- Rebuild the schema on every run. Statements execute in order and the script
-- aborts on the first one that fails.
for _, statement in ipairs({
  "CREATE EXTENSION IF NOT EXISTS vector",
  "DROP TABLE IF EXISTS documents",
  "CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))"
}) do
  assert(pg:query(statement))
end
--- Request sparse embeddings for a batch of inputs from the local embedding server.
-- Sends `{"inputs": inputs}` as JSON in a POST to
-- http://localhost:3000/embed_sparse and converts every item of the decoded
-- response (a list of entries carrying `index` and `value` fields) into a
-- table keyed by `index + 1`.
-- @param inputs value sent as the "inputs" field (callers pass an array of strings)
-- @return array with one table per response item, mapping `index + 1` to `value`
-- Raises an error when the request fails or the status code is not 200.
function embed(inputs)
  local payload = cjson.encode({ inputs = inputs })
  local chunks = {}

  -- The generic form of http.request does not compute the body length itself;
  -- without an explicit Content-Length LuaSocket falls back to chunked framing.
  local _, code = http.request {
    method = "POST",
    url = "http://localhost:3000/embed_sparse",
    headers = {
      ["Content-Type"] = "application/json",
      ["Content-Length"] = tostring(#payload)
    },
    source = ltn12.source.string(payload),
    sink = ltn12.sink.table(chunks)
  }

  -- On a transport failure `code` holds the error message instead of a status,
  -- so a single comparison covers both failure modes and reports the cause.
  if code ~= 200 then
    error("embed request failed (" .. tostring(code) .. "): " .. table.concat(chunks))
  end

  local res = cjson.decode(table.concat(chunks))
  local embeddings = {}
  for i, item in ipairs(res) do
    local embedding = {}
    for _, entry in ipairs(item) do
      -- Shift by one for Lua's 1-based tables (presumably the server's
      -- indices are zero-based — confirm against the server's response format).
      embedding[entry["index"] + 1] = entry["value"]
    end
    embeddings[i] = embedding
  end
  return embeddings
end
-- Sample texts to index.
local documents = {
  "The dog is barking",
  "The cat is purring",
  "The bear is growling"
}

-- Embed every document with a single call, then store each text next to its
-- sparse vector.
local embeddings = embed(documents)
for idx = 1, #documents do
  local vector = pgvector.sparsevec(embeddings[idx], 30522)
  assert(pg:query("INSERT INTO documents (content, embedding) VALUES ($1, $2)", documents[idx], vector))
end
-- Embed the search text and print up to five stored documents, ordered by
-- pgvector's `<#>` operator applied to the stored and query vectors.
local query = "forest"
local embedding = embed({query})[1]
local res = assert(pg:query("SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5", pgvector.sparsevec(embedding, 30522)))
for _, row in ipairs(res) do
  print(row["content"])
end

-- Close the database connection explicitly instead of leaving the socket open
-- until the process exits.
pg:disconnect()