Browse Source

init knn framework

master
Chen Lixiang 1 year ago
parent
commit
5cf9f066df
5 changed files with 117 additions and 111 deletions
  1. +8
    -21
      inc/hnsw.h
  2. +16
    -2
      inc/utils.h
  3. +45
    -21
      src/hnsw.c
  4. +13
    -11
      src/test.c
  5. +35
    -56
      src/utils.c

+ 8
- 21
inc/hnsw.h View File

@ -3,40 +3,27 @@
#include <stdio.h>
#include "utils.h"
// Placeholder for one NSW layer; it declares no members yet.
// NOTE(review): ISO C requires at least one member in a struct; the empty form
// only compiles as a GCC/Clang extension, so fill this in before porting.
typedef struct NSWGraph
{
} NSWGraph;
// you can add more structures here or modify existing structures.
// Layered graph: a layer count plus an array of pointers to NSWGraph layers.
typedef struct HNSWGraph
{
size_t layers_num; // number of layers
// presumably holds layers_num entries; nothing in view allocates or fills it — TODO confirm
NSWGraph **layers;
} HNSWGraph;
// Top-level state handed to every HNSW call: dataset shape, vectors, and graph.
typedef struct HNSWContext
{
// stored from the caller's argument; the loader and vec_dist shown in this
// view size their work by GLOBAL_DIM instead of reading this field
size_t dim; // dimension of dataset
size_t len; // size of dataset
// hnsw_init_context allocates len entries for this array
VecData *data; // vectors will be loaded into this array
// NOTE(review): hnsw_init_context as shown never assigns this pointer
HNSWGraph *graph; // graph of HNSW
} HNSWContext;
// you can define some helper functions here
// Bookkeeping for a visited set built from an array of marks and a current mark.
// NOTE(review): how entries are compared against `mark` is not shown in this
// view; presumably visited[i] == mark means "seen in this round" — confirm.
typedef struct {
// visited_list_new allocates `size` zero-initialised entries
unsigned int* visited;
// number of entries in `visited`
unsigned int size;
// starts at 1; visited_list_reset increments it and, when it wraps to 0,
// sets it back to 1 and clears the array
unsigned int mark;
} VisitedList;
// help functions
// public functions
// NOTE(review): no definition of the next three functions is visible in this
// view; they appear to be earlier names for the hnsw_* entry points below.
HNSWContext *init_hnsw_context(const char *file_path, size_t dim, size_t len);
void insert_vec(HNSWContext *ctx, VecData vec);
void approximate_knn(HNSWContext *ctx, VecData *results);
// Allocates a VisitedList with `size` zeroed entries and mark set to 1.
VisitedList* visited_list_new(int size);
// Advances the mark; on wrap-around to 0 it restarts at 1 and zeroes the array.
void visited_list_reset(VisitedList* vl);
// Returns vl->mark.
unsigned int visited_list_get_visit_mark(VisitedList* vl);
// Returns vl->visited (the list keeps ownership of the array).
unsigned int* visited_list_get_visited(VisitedList* vl);
// Frees the mark array and then the list itself.
void visited_list_free(VisitedList* vl);
// Please do not modify these function signatures!
// To simplify our program, we do not consider reclaiming memory space here.
// Please implement these functions according to the HNSW algorithm.
// Allocates a context and loads `len` vectors from `filename` into ctx->data,
// assigning ids 0..len-1 in file order.
HNSWContext *hnsw_init_context(const char *filename, size_t dim, size_t len);
// Writes the ids of the k stored vectors chosen for query q into results[0..k-1].
void hnsw_approximate_knn(HNSWContext *ctx, VecData *q, int *results, int k);

+ 16
- 2
inc/utils.h View File

@ -1,8 +1,22 @@
#pragma once
#include <stdio.h>
#define GLOBAL_DIM 128
// One dataset vector together with its integer id.
typedef struct {
int id;
// NOTE(review): `vector` and `vec` both appear because this view merges the
// lines before and after a rename; fvecs_read and the older vec_dist use
// `vector`, while the loader and the newer vec_dist use `vec`. Presumably only
// `vec` should remain — confirm against the real header.
float* vector;
// buffer of GLOBAL_DIM floats wherever it is allocated in this view
float* vec;
} VecData;
// An open input file plus the metadata kept alongside it.
typedef struct {
// opened with mode "rb" by init_file_context; closed by free_file_context
FILE* stream;
// heap copy of the path passed to init_file_context; freed by free_file_context
char *filename;
// set to 0 on creation; none of the readers in this view update it
int offset;
} FileContext;
// Distance between two vectors over the first GLOBAL_DIM components.
// NOTE(review): this view contains both a sqrt (Euclidean) and a squared
// variant of the body; check which one the built file actually returns.
float vec_dist(VecData x, VecData y);
// Reads a range of vectors from an .fvecs file; returns NULL when the
// requested range is invalid and stores the count in *num when num is non-NULL.
VecData* fvecs_read(const char* filename, int* bounds, int* num);
// Opens `filename` for binary reading; prints an error and exits on failure.
FileContext* init_file_context(const char *filename);
// Reads exactly 4 bytes from the stream into dst.
void read_4bytes(FileContext* ctx, void* dst);
// Consumes a 4-byte header, then reads GLOBAL_DIM 4-byte elements into dst.
void read_vec_data(FileContext* ctx, void* dst);
// Closes the stream and frees the filename copy and the context.
void free_file_context(FileContext* ctx);

+ 45
- 21
src/hnsw.c View File

@ -1,30 +1,54 @@
#include <limits.h>
#include "hnsw.h"
// NOTE(review): in this view the lines of hnsw_init_context are interleaved
// with the visited_list_* helpers (a rendered diff without +/- markers). The
// comments below label which definition each run of lines belongs to.
//
// hnsw_init_context: allocates a context, records dim and len, allocates len
// VecData slots, then fills each slot from the file and returns the context.
// NOTE(review): neither malloc result is checked, and ctx->graph is never set.
HNSWContext *hnsw_init_context(const char *filename, size_t dim, size_t len)
{
HNSWContext *ctx = (HNSWContext *) malloc(sizeof(HNSWContext));
ctx->dim = dim;
ctx->len = len;
ctx->data = (VecData *) malloc(sizeof(VecData) * len);
// visited_list_new: allocates a list with `size` zeroed entries and mark = 1.
// NOTE(review): malloc/calloc results are used without a NULL check.
VisitedList* visited_list_new(int size) {
VisitedList* vl = malloc(sizeof(VisitedList));
vl->size = size;
vl->mark = 1;
vl->visited = calloc(size, sizeof(unsigned int));
return vl;
}
// (hnsw_init_context continues)
// init file context
FileContext* f_ctx = init_file_context(filename);
// visited_list_reset: bumps the mark; if it wrapped to 0, restart at 1 and
// clear every entry so stale marks cannot match.
void visited_list_reset(VisitedList* vl) {
if (++vl->mark == 0) {
vl->mark = 1;
memset(vl->visited, 0, sizeof(unsigned int) * vl->size);
// (hnsw_init_context continues) load len vectors; ids follow file order.
// NOTE(review): `int i` is compared against the size_t `len`, and each buffer
// is sized by GLOBAL_DIM rather than by the `dim` argument.
for (int i = 0; i < len; i++)
{
ctx->data[i].id = i;
ctx->data[i].vec = (float *) malloc(sizeof(float) * GLOBAL_DIM);
read_vec_data(f_ctx, ctx->data[i].vec);
}
}
// visited_list_get_visit_mark: returns the current mark.
unsigned int visited_list_get_visit_mark(VisitedList* vl) {
return vl->mark;
// (hnsw_init_context ends) the file is closed once all vectors are loaded.
free_file_context(f_ctx);
return ctx;
}
// visited_list_get_visited: exposes the mark array without copying it.
unsigned int* visited_list_get_visited(VisitedList* vl) {
return vl->visited;
}
void hnsw_approximate_knn(HNSWContext *ctx, VecData *q, int *results, int k)
{
// sort existing vectors
for (size_t i = 0; i < k && i < ctx->len - 1; i++)
{
float min_dist = vec_dist(*q, ctx->data[i]);
size_t idx = i;
for (size_t j = i + 1; j < ctx->len; j++)
{
float dist = vec_dist(*q, ctx->data[j]);
if (dist < min_dist)
{
min_dist = dist;
idx = i;
}
}
if (idx != i)
{
VecData tmp = ctx->data[idx];
ctx->data[idx] = ctx->data[i];
ctx->data[i] = tmp;
}
}
void visited_list_free(VisitedList* vl) {
free(vl->visited);
free(vl);
}
// copy results
for (int i = 0; i < k; i++)
{
results[i] = ctx->data[i].id;
}
}

+ 13
- 11
src/test.c View File

@ -1,21 +1,23 @@
#include <stdio.h>
#include <stdlib.h>
#include "hnsw.h"
#include "utils.h"
// NOTE(review): two versions of main are interleaved in this view (a rendered
// diff without +/- markers). The comments label which version each run of
// lines belongs to.
//
// Older main: loads the whole sample dataset with fvecs_read and prints a few
// sanity values.
int main() {
// Newer main: streams vectors from the file named by the first argument.
int main(int argc, char *argv[]) {
// (older main)
int num = 0;
VecData* vecs = fvecs_read("../dataset/siftsmall_base.fvecs", NULL, &num);
printf("num: %d\n", num);
printf("id of vector 1: %d\n", vecs[0].id);
printf("test distance: %f\n", vec_dist(vecs[0], vecs[1]));
// Free memory
for (int i = 0; i < num; i++) {
free(vecs[i].vector);
// (newer main)
// NOTE(review): argv[1] is used without checking argc, and neither data.vec
// nor f_ctx is released anywhere in the lines shown.
FileContext* f_ctx = init_file_context(argv[1]);
VecData data;
data.vec = malloc(sizeof(float) * GLOBAL_DIM);
// Print the first 100 vectors, one per line, reusing a single buffer.
for (int i = 0; i < 100; i++)
{
read_vec_data(f_ctx, data.vec);
for (int j = 0; j < GLOBAL_DIM; j++)
{
printf("%f ", data.vec[j]);
}
putchar('\n');
}
// (older main)
free(vecs);
return 0;
}

+ 35
- 56
src/utils.c View File

@ -1,73 +1,52 @@
#include "utils.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
// Distance between x and y over their first GLOBAL_DIM (128) components.
// NOTE(review): two versions of this function are interleaved in this view.
// The older one reads the `vector` field and returns the Euclidean distance;
// the newer one reads `vec` and returns the squared distance, which ranks
// neighbours identically while skipping the sqrt.
float vec_dist(VecData x, VecData y) {
float vec_dist(VecData x, VecData y)
{
float sum = 0.0;
// (older loop)
for (int i = 0; i < 128; i++) {
float diff = x.vector[i] - y.vector[i];
// (newer loop)
for (size_t i = 0; i < GLOBAL_DIM; i++)
{
float diff = x.vec[i] - y.vec[i];
sum += diff * diff;
}
// (older return)
return sqrt(sum);
// (newer return)
return sum;
}
VecData* fvecs_read(const char* filename, int* bounds, int* num) {
FILE* fid = fopen(filename, "rb");
if (fid == NULL) {
// Create a FileContext for reading `filename` in binary mode.
//
// filename  path of the file to open; a private copy is stored in the context
// returns   a heap-allocated context with an open stream and offset 0; the
//           caller releases it with free_file_context
//
// Any failure (out of memory, file cannot be opened) prints a message to
// stderr and terminates the process with EXIT_FAILURE, so the function never
// returns NULL.
FileContext *init_file_context(const char *filename)
{
    // Measure once; the length includes the terminating NUL.
    size_t name_size = strlen(filename) + 1;

    FileContext *ctx = malloc(sizeof *ctx);
    char *name_copy = malloc(name_size);
    if (ctx == NULL || name_copy == NULL)
    {
        fprintf(stderr, "Memory error : Unable to allocate context for %s\n", filename);
        exit(EXIT_FAILURE);
    }
    memcpy(name_copy, filename, name_size);
    ctx->filename = name_copy;

    ctx->stream = fopen(filename, "rb");
    if (ctx->stream == NULL)
    {
        fprintf(stderr, "I/O error : Unable to open the file %s\n", filename);
        exit(EXIT_FAILURE);
    }

    ctx->offset = 0;
    return ctx;
}
// NOTE(review): this is the middle of fvecs_read; its opening and closing
// lines are separated from it by other definitions in this view.
// Read the first 4-byte integer of the file as the vector dimension.
// NOTE(review): the fread result is not checked.
int d;
fread(&d, sizeof(int), 1, fid);
// Measure the file, then rewind.
fseek(fid, 0, SEEK_END);
long file_size = ftell(fid);
fseek(fid, 0, SEEK_SET);
// NOTE(review): vec_size counts only the floats. read_vec_data in this file
// consumes a 4-byte header before every vector, so if each record carries that
// header, both vec_count and the seek offset below are off — confirm.
long vec_size = (long) d * sizeof(float);
long vec_count = (file_size - sizeof(int)) / vec_size;
// Default range is the whole file, 1-based and inclusive.
int a = 1;
int bmax = vec_count;
int b = bmax;
// NOTE(review): bounds[1] is tested as a mode tag (1 or 2) and, in the second
// branch, also assigned to b, which therefore always becomes 2.
if (bounds != NULL) {
if (bounds[1] == 1) {
b = bounds[0];
} else if (bounds[1] == 2) {
a = bounds[0];
b = bounds[1];
}
}
// Reject an empty or out-of-range request.
if (a < 1 || b > bmax || b < a) {
VecData* v = NULL;
fclose(fid);
return v;
}
int n = b - a + 1;
fseek(fid, (a - 1) * vec_size, SEEK_SET);
// Read n vectors
// ids are 1-based positions within the file (i + a).
// NOTE(review): malloc and fread results are not checked.
VecData* v = malloc(n * sizeof(VecData));
for (int i = 0; i < n; i++) {
VecData vec;
vec.id = i + a;
vec.vector = malloc(d * sizeof(float));
fread(vec.vector, sizeof(float), d, fid);
v[i] = vec;
}
// Read exactly 4 bytes from ctx->stream into dst.
//
// ctx  open file context
// dst  destination with room for at least 4 bytes
//
// A short read (end of file or I/O error) prints a message to stderr and
// terminates the process with EXIT_FAILURE.
void read_4bytes(FileContext *ctx, void *dst)
{
    // Use a real branch instead of assert() so the check survives -DNDEBUG.
    if (fread(dst, 4, 1, ctx->stream) != 1)
    {
        fprintf(stderr, "I/O error : Unable to read 4 bytes from %s\n", ctx->filename);
        exit(EXIT_FAILURE);
    }
}
// Read one vector record from ctx->stream into dst.
//
// A record, as consumed here, is a 4-byte header followed by GLOBAL_DIM
// float elements. The header is taken to be the vector's dimension (the
// .fvecs layout) and must equal GLOBAL_DIM.
//
// ctx  open file context positioned at the start of a record
// dst  destination with room for GLOBAL_DIM floats
//
// A dimension mismatch or a short read prints a message to stderr and
// terminates the process with EXIT_FAILURE.
void read_vec_data(FileContext *ctx, void *dst)
{
    // Read the header into a local rather than using dst as scratch space.
    // Assumes a 4-byte int, which is what the header field occupies.
    int dim = 0;
    read_4bytes(ctx, &dim);
    if (dim != GLOBAL_DIM)
    {
        fprintf(stderr, "I/O error : Unexpected vector dimension %d in %s (expected %d)\n",
                dim, ctx->filename, GLOBAL_DIM);
        exit(EXIT_FAILURE);
    }

    // Use a real branch instead of assert() so the check survives -DNDEBUG.
    size_t got = fread(dst, sizeof(float), GLOBAL_DIM, ctx->stream);
    if (got != (size_t) GLOBAL_DIM)
    {
        fprintf(stderr, "I/O error : Truncated vector in %s\n", ctx->filename);
        exit(EXIT_FAILURE);
    }
}
fclose(fid);
if (num != NULL) {
*num = n;
}
return v;
}
// Release a context created by init_file_context: close the stream, then free
// the filename copy and the context itself.
//
// ctx  context to release; NULL is accepted and ignored, mirroring free(NULL)
void free_file_context(FileContext* ctx)
{
    if (ctx == NULL)
    {
        return;
    }
    // fclose(NULL) is undefined, so only close a stream that was opened.
    if (ctx->stream != NULL)
    {
        fclose(ctx->stream);
    }
    free(ctx->filename);
    free(ctx);
}

Loading…
Cancel
Save