talk-llama : add new example + sync ggml from llama.cpp (#664)

* talk-llama : talk with LLaMA AI
* talk-llama : disable EOS token
* talk-llama : add README instructions
* ggml : fix build in debug
- .gitignore +3 -0
- Makefile +8 -5
- examples/CMakeLists.txt +1 -0
- examples/talk-llama/.gitignore +2 -0
- examples/talk-llama/CMakeLists.txt +10 -0
- examples/talk-llama/README.md +32 -0
- examples/talk-llama/llama.cpp +1865 -0
- examples/talk-llama/llama.h +153 -0
- examples/talk-llama/speak.sh +20 -0
- examples/talk-llama/talk-llama.cpp +529 -0
- examples/talk/speak.sh +4 -1
- ggml.c +0 -0
- ggml.h +26 -1
- whisper.cpp +22 -35
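
The new `examples/talk-llama/talk-llama.cpp` (+529 lines) is not rendered in this view, but the idea of the example is: whisper.cpp transcribes the microphone input, the transcription is fed to the bundled LLaMA inference code, and the reply is spoken via `speak.sh`. Below is a rough, hypothetical sketch (not the actual example code) of driving just the LLaMA side through the bundled `llama.h` API of this vintage; the model path, prompt, thread count, and sampling parameters are placeholder assumptions:

```cpp
// sketch.cpp - hypothetical driver for the bundled llama.h API; not part of this commit
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_context_params lparams = llama_context_default_params();
    lparams.n_ctx = 512;

    // placeholder model path
    llama_context * ctx = llama_init_from_file("models/13B/ggml-model-q4_0.bin", lparams);
    if (ctx == nullptr) {
        return 1;
    }

    // in talk-llama, this text would come from whisper's live transcription
    const std::string text = " Hello, how are you today?";

    // tokenize the prompt (worst case: one token per byte, plus BOS)
    std::vector<llama_token> tokens(text.size() + 1);
    const int n = llama_tokenize(ctx, text.c_str(), tokens.data(), (int) tokens.size(), true);
    if (n < 0) {
        return 1;
    }
    tokens.resize(n);

    // evaluate the prompt, then generate a short reply token by token;
    // nothing stops on the EOS token here, in the spirit of the
    // "disable EOS token" item from the commit message
    int n_past = 0;
    llama_eval(ctx, tokens.data(), (int) tokens.size(), n_past, 4);
    n_past += (int) tokens.size();

    for (int i = 0; i < 64; ++i) {
        const llama_token id = llama_sample_top_p_top_k(ctx, nullptr, 0, 40, 0.95f, 0.80f, 1.10f);
        printf("%s", llama_token_to_str(ctx, id));
        fflush(stdout);

        llama_eval(ctx, &id, 1, n_past, 4);
        n_past++;
    }

    llama_free(ctx);
    return 0;
}
```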
.gitignore
CHANGED
@@ -18,6 +18,7 @@ build-sanitize-thread/
 /stream
 /command
 /talk
+/talk-llama
 /bench
 
 arm_neon.h
@@ -32,3 +33,5 @@ examples/whisper.objc/whisper.objc.xcodeproj/xcuserdata/
 examples/whisper.objc/whisper.objc.xcodeproj/project.xcworkspace/xcuserdata
 
 extra/bench-gg.txt
+
+*.mlmodel*
Makefile
CHANGED
@@ -36,7 +36,7 @@ LDFLAGS  =
 
 # ref: https://github.com/ggerganov/whisper.cpp/issues/37
 ifneq ($(wildcard /usr/include/musl/*),)
-CFLAGS
+CFLAGS   += -D_POSIX_SOURCE -D_GNU_SOURCE
 CXXFLAGS += -D_POSIX_SOURCE -D_GNU_SOURCE
 endif
 
@@ -178,7 +178,7 @@ $(info I CC: $(CCV))
 $(info I CXX: $(CXXV))
 $(info )
 
-default: main
+default: main bench
 
 #
 # Build library
@@ -197,7 +197,7 @@ libwhisper.so: ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
 
 clean:
-	rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
+	rm -f *.o main stream command talk talk-llama bench libwhisper.a libwhisper.so
 
 #
 # Examples
@@ -212,6 +212,9 @@ main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
 	./main -h
 
+bench: examples/bench/bench.cpp ggml.o whisper.o
+	$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
+
 stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
 
@@ -221,8 +224,8 @@ command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whi
 talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
 	$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
 
-
-	$(CXX) $(CXXFLAGS) examples/
+talk-llama: examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
+	$(CXX) $(CXXFLAGS) examples/talk-llama/talk-llama.cpp examples/talk-llama/llama.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk-llama $(CC_SDL) $(LDFLAGS)
 
 #
 # Audio samples
examples/CMakeLists.txt
CHANGED
@@ -63,4 +63,5 @@ else()
     add_subdirectory(command)
     add_subdirectory(bench)
     add_subdirectory(talk)
+    add_subdirectory(talk-llama)
 endif()
examples/talk-llama/.gitignore
ADDED
@@ -0,0 +1,2 @@
eleven-labs.py
audio.mp3
examples/talk-llama/CMakeLists.txt
ADDED
@@ -0,0 +1,10 @@
if (WHISPER_SUPPORT_SDL2)
    # talk-llama
    set(TARGET talk-llama)

    add_executable(${TARGET} talk-llama.cpp llama.cpp)

    include(DefaultTargetOptions)

    target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
endif ()
examples/talk-llama/README.md
ADDED
@@ -0,0 +1,32 @@
# talk-llama

Talk with a LLaMA AI in your terminal

[Demo Talk](https://user-images.githubusercontent.com/1991296/228024237-848f998c-c334-46a6-bef8-3271590da83b.mp4)

## Building

The `talk-llama` tool depends on the SDL2 library to capture audio from the microphone. You can build it like this:

```bash
# Install SDL2 on Linux
sudo apt-get install libsdl2-dev

# Install SDL2 on Mac OS
brew install sdl2

# Build the "talk-llama" executable
make talk-llama

# Run it
./talk-llama -mw ./models/ggml-small.en.bin -ml ../llama.cpp/models/13B/ggml-model-q4_0.bin -p "Georgi" -t 8
```

- The `-mw` argument specifies the Whisper model that you would like to use. `base` or `small` is recommended for a real-time experience
- The `-ml` argument specifies the LLaMA model that you would like to use. Read the instructions in https://github.com/ggerganov/llama.cpp for information about how to obtain a `ggml`-compatible LLaMA model

## TTS

For the best experience, this example needs a TTS tool to convert the generated text responses to voice.
You can use any TTS engine that you would like - simply edit the [speak.sh](speak.sh) script to your needs.
By default, it is configured to use macOS's `say`, but you can use whatever you wish.
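
`speak.sh` itself (+20 lines) is not rendered in this view. As a rough illustration of the TTS hand-off the README describes, a program can shell out to such a script with the generated text; the argument convention below is an assumption, not taken from the diff:

```cpp
#include <cstdlib>
#include <string>

// hand a generated reply to an external TTS script; blocks until it finishes
static void speak(const std::string & text) {
    // "1" is a hypothetical voice id; the quoting here is naive and a real
    // program should sanitize the text before passing it to a shell
    const std::string cmd = "./examples/talk-llama/speak.sh 1 \"" + text + "\"";
    std::system(cmd.c_str());
}
```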
examples/talk-llama/llama.cpp
ADDED
@@ -0,0 +1,1865 @@
| 1 |
+
#include "llama.h"
|
| 2 |
+
|
| 3 |
+
#include "ggml.h"
|
| 4 |
+
|
| 5 |
+
#include <cinttypes>
|
| 6 |
+
#include <fstream>
|
| 7 |
+
#include <random>
|
| 8 |
+
#include <map>
|
| 9 |
+
#include <unordered_map>
|
| 10 |
+
#include <queue>
|
| 11 |
+
#include <regex>
|
| 12 |
+
#include <cassert>
|
| 13 |
+
#include <cstring>
|
| 14 |
+
|
| 15 |
+
#define LLAMA_USE_SCRATCH
|
| 16 |
+
#define LLAMA_MAX_SCRATCH_BUFFERS 16
|
| 17 |
+
|
| 18 |
+
#define LLAMA_ASSERT(x) \
|
| 19 |
+
do { \
|
| 20 |
+
if (!(x)) { \
|
| 21 |
+
fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
|
| 22 |
+
abort(); \
|
| 23 |
+
} \
|
| 24 |
+
} while (0)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
// determine number of model parts based on the dimension
|
| 28 |
+
static const std::unordered_map<int, int> LLAMA_N_PARTS = {
|
| 29 |
+
{ 4096, 1 },
|
| 30 |
+
{ 5120, 2 },
|
| 31 |
+
{ 6656, 4 },
|
| 32 |
+
{ 8192, 8 },
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
+
// available llama models
|
| 36 |
+
enum e_model {
|
| 37 |
+
MODEL_UNKNOWN,
|
| 38 |
+
MODEL_7B,
|
| 39 |
+
MODEL_13B,
|
| 40 |
+
MODEL_30B,
|
| 41 |
+
MODEL_65B,
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
static const size_t MB = 1024*1024;
|
| 45 |
+
|
| 46 |
+
// computed for n_ctx == 2048
|
| 47 |
+
// TODO: dynamically determine these sizes
|
| 48 |
+
// needs modifications in ggml
|
| 49 |
+
|
| 50 |
+
static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
|
| 51 |
+
{ MODEL_7B, 512ull*MB },
|
| 52 |
+
{ MODEL_13B, 512ull*MB },
|
| 53 |
+
{ MODEL_30B, 512ull*MB },
|
| 54 |
+
{ MODEL_65B, 512ull*MB },
|
| 55 |
+
};
|
| 56 |
+
|
| 57 |
+
static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
|
| 58 |
+
{ MODEL_7B, 512ull*MB },
|
| 59 |
+
{ MODEL_13B, 512ull*MB },
|
| 60 |
+
{ MODEL_30B, 512ull*MB },
|
| 61 |
+
{ MODEL_65B, 512ull*MB },
|
| 62 |
+
};
|
| 63 |
+
|
| 64 |
+
// 2*n_embd*n_ctx*n_layer*sizeof(float16)
|
| 65 |
+
static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
|
| 66 |
+
{ MODEL_7B, 1026ull*MB },
|
| 67 |
+
{ MODEL_13B, 1608ull*MB },
|
| 68 |
+
{ MODEL_30B, 3124ull*MB },
|
| 69 |
+
{ MODEL_65B, 5120ull*MB },
|
| 70 |
+
};
|
| 71 |
+
|
| 72 |
+
// this is mostly needed for temporary mul_mat buffers to dequantize the data
|
| 73 |
+
// not actually needed if BLAS is disabled
|
| 74 |
+
static const std::map<e_model, size_t> MEM_REQ_EVAL = {
|
| 75 |
+
{ MODEL_7B, 768ull*MB },
|
| 76 |
+
{ MODEL_13B, 1024ull*MB },
|
| 77 |
+
{ MODEL_30B, 1280ull*MB },
|
| 78 |
+
{ MODEL_65B, 1536ull*MB },
|
| 79 |
+
};
|
| 80 |
+
|
| 81 |
+
// default hparams (LLaMA 7B)
|
| 82 |
+
struct llama_hparams {
|
| 83 |
+
int32_t n_vocab = 32000;
|
| 84 |
+
int32_t n_ctx = 512; // this is provided as user input?
|
| 85 |
+
int32_t n_embd = 4096;
|
| 86 |
+
int32_t n_mult = 256;
|
| 87 |
+
int32_t n_head = 32;
|
| 88 |
+
int32_t n_layer = 32;
|
| 89 |
+
int32_t n_rot = 64;
|
| 90 |
+
int32_t f16 = 1;
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
struct llama_layer {
|
| 94 |
+
// normalization
|
| 95 |
+
struct ggml_tensor * attention_norm;
|
| 96 |
+
|
| 97 |
+
// attention
|
| 98 |
+
struct ggml_tensor * wq;
|
| 99 |
+
struct ggml_tensor * wk;
|
| 100 |
+
struct ggml_tensor * wv;
|
| 101 |
+
struct ggml_tensor * wo;
|
| 102 |
+
|
| 103 |
+
// normalization
|
| 104 |
+
struct ggml_tensor * ffn_norm;
|
| 105 |
+
|
| 106 |
+
// ff
|
| 107 |
+
struct ggml_tensor * w1;
|
| 108 |
+
struct ggml_tensor * w2;
|
| 109 |
+
struct ggml_tensor * w3;
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
struct llama_kv_cache {
|
| 113 |
+
struct ggml_tensor * k;
|
| 114 |
+
struct ggml_tensor * v;
|
| 115 |
+
|
| 116 |
+
struct ggml_context * ctx;
|
| 117 |
+
|
| 118 |
+
std::vector<uint8_t> buf;
|
| 119 |
+
|
| 120 |
+
int n; // number of tokens currently in the cache
|
| 121 |
+
};
|
| 122 |
+
|
| 123 |
+
struct llama_model {
|
| 124 |
+
e_model type = MODEL_UNKNOWN;
|
| 125 |
+
|
| 126 |
+
llama_hparams hparams;
|
| 127 |
+
|
| 128 |
+
struct ggml_tensor * tok_embeddings;
|
| 129 |
+
|
| 130 |
+
struct ggml_tensor * norm;
|
| 131 |
+
struct ggml_tensor * output;
|
| 132 |
+
|
| 133 |
+
std::vector<llama_layer> layers;
|
| 134 |
+
|
| 135 |
+
// context
|
| 136 |
+
struct ggml_context * ctx;
|
| 137 |
+
|
| 138 |
+
// key + value cache for the self attention
|
| 139 |
+
// TODO: move to llama_state
|
| 140 |
+
struct llama_kv_cache kv_self;
|
| 141 |
+
|
| 142 |
+
// the model memory buffer
|
| 143 |
+
std::vector<uint8_t> buf;
|
| 144 |
+
|
| 145 |
+
// tensors
|
| 146 |
+
int n_loaded;
|
| 147 |
+
std::unordered_map<std::string, struct ggml_tensor *> tensors;
|
| 148 |
+
};
|
| 149 |
+
|
| 150 |
+
struct llama_vocab {
|
| 151 |
+
using id = int32_t;
|
| 152 |
+
using token = std::string;
|
| 153 |
+
|
| 154 |
+
struct token_score {
|
| 155 |
+
token tok;
|
| 156 |
+
float score;
|
| 157 |
+
};
|
| 158 |
+
|
| 159 |
+
std::unordered_map<token, id> token_to_id;
|
| 160 |
+
std::vector<token_score> id_to_token;
|
| 161 |
+
};
|
| 162 |
+
|
| 163 |
+
struct llama_context {
|
| 164 |
+
std::mt19937 rng;
|
| 165 |
+
|
| 166 |
+
int64_t t_load_us = 0;
|
| 167 |
+
int64_t t_start_us = 0;
|
| 168 |
+
|
| 169 |
+
int64_t t_sample_us = 0;
|
| 170 |
+
int64_t t_eval_us = 0;
|
| 171 |
+
int64_t t_p_eval_us = 0;
|
| 172 |
+
|
| 173 |
+
int32_t n_sample = 0; // number of tokens sampled
|
| 174 |
+
int32_t n_eval = 0; // number of eval calls
|
| 175 |
+
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
| 176 |
+
|
| 177 |
+
llama_model model;
|
| 178 |
+
llama_vocab vocab;
|
| 179 |
+
|
| 180 |
+
size_t mem_per_token = 0;
|
| 181 |
+
|
| 182 |
+
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
| 183 |
+
std::vector<float> logits;
|
| 184 |
+
bool logits_all = false;
|
| 185 |
+
|
| 186 |
+
// input embedding (1-dimensional array: [n_embd])
|
| 187 |
+
std::vector<float> embedding;
|
| 188 |
+
|
| 189 |
+
// memory buffers used to evaluate the model
|
| 190 |
+
// TODO: move in llama_state
|
| 191 |
+
std::vector<uint8_t> buf_compute;
|
| 192 |
+
std::vector<uint8_t> buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
|
| 193 |
+
|
| 194 |
+
int buf_last = 0;
|
| 195 |
+
size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
|
| 196 |
+
|
| 197 |
+
void use_buf(struct ggml_context * ctx, int i) {
|
| 198 |
+
#if defined(LLAMA_USE_SCRATCH)
|
| 199 |
+
size_t last_size = 0;
|
| 200 |
+
|
| 201 |
+
if (i == -1) {
|
| 202 |
+
last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
|
| 203 |
+
} else {
|
| 204 |
+
auto & buf = buf_scratch[i];
|
| 205 |
+
last_size = ggml_set_scratch(ctx, { 0, buf.size(), buf.data(), });
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
if (buf_last >= 0) {
|
| 209 |
+
buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
buf_last = i;
|
| 213 |
+
#else
|
| 214 |
+
(void) i;
|
| 215 |
+
(void) ctx;
|
| 216 |
+
#endif
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
size_t get_buf_max_mem(int i) const {
|
| 220 |
+
#if defined(LLAMA_USE_SCRATCH)
|
| 221 |
+
return buf_max_size[i];
|
| 222 |
+
#else
|
| 223 |
+
(void) i;
|
| 224 |
+
return 0;
|
| 225 |
+
#endif
|
| 226 |
+
}
|
| 227 |
+
};
|
| 228 |
+
|
| 229 |
+
//
|
| 230 |
+
// kv cache
|
| 231 |
+
//
|
| 232 |
+
|
| 233 |
+
static bool kv_cache_init(
|
| 234 |
+
const struct llama_hparams & hparams,
|
| 235 |
+
struct llama_kv_cache & cache,
|
| 236 |
+
ggml_type wtype,
|
| 237 |
+
int n_ctx) {
|
| 238 |
+
const int n_embd = hparams.n_embd;
|
| 239 |
+
const int n_layer = hparams.n_layer;
|
| 240 |
+
|
| 241 |
+
const int n_mem = n_layer*n_ctx;
|
| 242 |
+
const int n_elements = n_embd*n_mem;
|
| 243 |
+
|
| 244 |
+
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
| 245 |
+
|
| 246 |
+
struct ggml_init_params params;
|
| 247 |
+
params.mem_size = cache.buf.size();
|
| 248 |
+
params.mem_buffer = cache.buf.data();
|
| 249 |
+
|
| 250 |
+
cache.ctx = ggml_init(params);
|
| 251 |
+
|
| 252 |
+
if (!cache.ctx) {
|
| 253 |
+
fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
|
| 254 |
+
return false;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
| 258 |
+
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
| 259 |
+
|
| 260 |
+
return true;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
static void kv_cache_free(struct llama_kv_cache & cache) {
|
| 264 |
+
if (cache.ctx) {
|
| 265 |
+
ggml_free(cache.ctx);
|
| 266 |
+
cache.ctx = nullptr;
|
| 267 |
+
}
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
struct llama_context_params llama_context_default_params() {
|
| 271 |
+
struct llama_context_params result = {
|
| 272 |
+
/*.n_ctx =*/ 512,
|
| 273 |
+
/*.n_parts =*/ -1,
|
| 274 |
+
/*.seed =*/ 0,
|
| 275 |
+
/*.f16_kv =*/ false,
|
| 276 |
+
/*.logits_all =*/ false,
|
| 277 |
+
/*.vocab_only =*/ false,
|
| 278 |
+
/*.use_mlock =*/ false,
|
| 279 |
+
/*.embedding =*/ false,
|
| 280 |
+
/*.progress_callback =*/ nullptr,
|
| 281 |
+
/*.progress_callback_user_data =*/ nullptr,
|
| 282 |
+
};
|
| 283 |
+
|
| 284 |
+
return result;
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
//
|
| 288 |
+
// model loading
|
| 289 |
+
//
|
| 290 |
+
|
| 291 |
+
static bool llama_model_load(
|
| 292 |
+
const std::string & fname,
|
| 293 |
+
llama_context & lctx,
|
| 294 |
+
int n_ctx,
|
| 295 |
+
int n_parts,
|
| 296 |
+
ggml_type memory_type,
|
| 297 |
+
bool vocab_only,
|
| 298 |
+
llama_progress_callback progress_callback,
|
| 299 |
+
void *progress_callback_user_data) {
|
| 300 |
+
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
| 301 |
+
|
| 302 |
+
const int64_t t_start_us = ggml_time_us();
|
| 303 |
+
|
| 304 |
+
lctx.t_start_us = t_start_us;
|
| 305 |
+
|
| 306 |
+
std::vector<char> f_buf(1024*1024);
|
| 307 |
+
|
| 308 |
+
auto & model = lctx.model;
|
| 309 |
+
auto & vocab = lctx.vocab;
|
| 310 |
+
|
| 311 |
+
auto fin = std::ifstream(fname, std::ios::binary);
|
| 312 |
+
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
| 313 |
+
if (!fin) {
|
| 314 |
+
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
|
| 315 |
+
return false;
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
// verify magic
|
| 319 |
+
{
|
| 320 |
+
uint32_t magic;
|
| 321 |
+
fin.read((char *) &magic, sizeof(magic));
|
| 322 |
+
if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
|
| 323 |
+
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
|
| 324 |
+
__func__, fname.c_str());
|
| 325 |
+
return false;
|
| 326 |
+
}
|
| 327 |
+
if (magic != LLAMA_FILE_MAGIC) {
|
| 328 |
+
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
|
| 329 |
+
return false;
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
uint32_t format_version;
|
| 333 |
+
fin.read((char *) &format_version, sizeof(format_version));
|
| 334 |
+
|
| 335 |
+
if (format_version != LLAMA_FILE_VERSION) {
|
| 336 |
+
fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
|
| 337 |
+
__func__, fname.c_str(), format_version, LLAMA_FILE_VERSION);
|
| 338 |
+
return false;
|
| 339 |
+
}
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
int n_ff = 0;
|
| 343 |
+
|
| 344 |
+
// load hparams
|
| 345 |
+
{
|
| 346 |
+
auto & hparams = model.hparams;
|
| 347 |
+
|
| 348 |
+
fin.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
|
| 349 |
+
//fin.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
|
| 350 |
+
fin.read((char *) &hparams.n_embd, sizeof(hparams.n_embd));
|
| 351 |
+
fin.read((char *) &hparams.n_mult, sizeof(hparams.n_mult));
|
| 352 |
+
fin.read((char *) &hparams.n_head, sizeof(hparams.n_head));
|
| 353 |
+
fin.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
|
| 354 |
+
fin.read((char *) &hparams.n_rot, sizeof(hparams.n_rot));
|
| 355 |
+
fin.read((char *) &hparams.f16, sizeof(hparams.f16));
|
| 356 |
+
|
| 357 |
+
hparams.n_ctx = n_ctx;
|
| 358 |
+
|
| 359 |
+
n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
|
| 360 |
+
|
| 361 |
+
if (n_parts < 1) {
|
| 362 |
+
n_parts = LLAMA_N_PARTS.at(hparams.n_embd);
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
// temp warning to tell the user to use "--n_parts"
|
| 366 |
+
if (hparams.f16 == 4 && n_parts != 1) {
|
| 367 |
+
fprintf(stderr, "%s: GPTQ model detected - are you sure n_parts should be %d? we normally expect it to be 1\n", __func__, n_parts);
|
| 368 |
+
fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
if (hparams.n_layer == 32) {
|
| 372 |
+
model.type = e_model::MODEL_7B;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
if (hparams.n_layer == 40) {
|
| 376 |
+
model.type = e_model::MODEL_13B;
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
if (hparams.n_layer == 60) {
|
| 380 |
+
model.type = e_model::MODEL_30B;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
if (hparams.n_layer == 80) {
|
| 384 |
+
model.type = e_model::MODEL_65B;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab);
|
| 388 |
+
fprintf(stderr, "%s: n_ctx = %d\n", __func__, hparams.n_ctx);
|
| 389 |
+
fprintf(stderr, "%s: n_embd = %d\n", __func__, hparams.n_embd);
|
| 390 |
+
fprintf(stderr, "%s: n_mult = %d\n", __func__, hparams.n_mult);
|
| 391 |
+
fprintf(stderr, "%s: n_head = %d\n", __func__, hparams.n_head);
|
| 392 |
+
fprintf(stderr, "%s: n_layer = %d\n", __func__, hparams.n_layer);
|
| 393 |
+
fprintf(stderr, "%s: n_rot = %d\n", __func__, hparams.n_rot);
|
| 394 |
+
fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16);
|
| 395 |
+
fprintf(stderr, "%s: n_ff = %d\n", __func__, n_ff);
|
| 396 |
+
fprintf(stderr, "%s: n_parts = %d\n", __func__, n_parts);
|
| 397 |
+
fprintf(stderr, "%s: type = %d\n", __func__, model.type);
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
// load vocab
|
| 401 |
+
{
|
| 402 |
+
std::string word;
|
| 403 |
+
vocab.id_to_token.resize(model.hparams.n_vocab);
|
| 404 |
+
std::vector<char> tmp(64);
|
| 405 |
+
|
| 406 |
+
for (int i = 0; i < model.hparams.n_vocab; i++) {
|
| 407 |
+
uint32_t len;
|
| 408 |
+
fin.read((char *) &len, sizeof(len));
|
| 409 |
+
|
| 410 |
+
word.resize(len);
|
| 411 |
+
if (len > 0) {
|
| 412 |
+
tmp.resize(len);
|
| 413 |
+
fin.read(tmp.data(), len);
|
| 414 |
+
word.assign(tmp.data(), len);
|
| 415 |
+
} else {
|
| 416 |
+
word.clear();
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
float score;
|
| 420 |
+
fin.read((char *) &score, sizeof(score));
|
| 421 |
+
|
| 422 |
+
vocab.token_to_id[word] = i;
|
| 423 |
+
|
| 424 |
+
auto &tok_score = vocab.id_to_token[i];
|
| 425 |
+
tok_score.tok = word;
|
| 426 |
+
tok_score.score = score;
|
| 427 |
+
}
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
if (vocab_only) {
|
| 431 |
+
return true;
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
// for the big tensors, we have the option to store the data in 16-bit floats or quantized
|
| 435 |
+
// in order to save memory and also to speed up the computation
|
| 436 |
+
// wtype is for per-layer weights, while vtype is for other weights
|
| 437 |
+
ggml_type wtype, vtype;
|
| 438 |
+
switch (model.hparams.f16) {
|
| 439 |
+
case 0: wtype = vtype = GGML_TYPE_F32; break;
|
| 440 |
+
case 1: wtype = vtype = GGML_TYPE_F16; break;
|
| 441 |
+
case 2: wtype = vtype = GGML_TYPE_Q4_0; break;
|
| 442 |
+
case 3: wtype = vtype = GGML_TYPE_Q4_1; break;
|
| 443 |
+
case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break;
|
| 444 |
+
default:
|
| 445 |
+
{
|
| 446 |
+
fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
|
| 447 |
+
__func__, fname.c_str(), model.hparams.f16);
|
| 448 |
+
return false;
|
| 449 |
+
}
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
auto & ctx = model.ctx;
|
| 453 |
+
|
| 454 |
+
size_t ctx_size = 0;
|
| 455 |
+
|
| 456 |
+
{
|
| 457 |
+
const auto & hparams = model.hparams;
|
| 458 |
+
|
| 459 |
+
const int n_embd = hparams.n_embd;
|
| 460 |
+
const int n_layer = hparams.n_layer;
|
| 461 |
+
const int n_ctx = hparams.n_ctx;
|
| 462 |
+
const int n_vocab = hparams.n_vocab;
|
| 463 |
+
|
| 464 |
+
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
|
| 465 |
+
|
| 466 |
+
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
|
| 467 |
+
|
| 468 |
+
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
|
| 469 |
+
|
| 470 |
+
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
|
| 471 |
+
|
| 472 |
+
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
|
| 473 |
+
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
|
| 474 |
+
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
|
| 475 |
+
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
|
| 476 |
+
|
| 477 |
+
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
|
| 478 |
+
|
| 479 |
+
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
|
| 480 |
+
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
|
| 481 |
+
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
|
| 482 |
+
|
| 483 |
+
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
|
| 484 |
+
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
|
| 485 |
+
|
| 486 |
+
ctx_size += (5 + 10*n_layer)*256; // object overhead
|
| 487 |
+
|
| 488 |
+
fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
// print memory requirements
|
| 492 |
+
{
|
| 493 |
+
const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
|
| 494 |
+
|
| 495 |
+
// this is the total memory required to run the inference
|
| 496 |
+
const size_t mem_required =
|
| 497 |
+
ctx_size +
|
| 498 |
+
MEM_REQ_SCRATCH0.at(model.type) +
|
| 499 |
+
MEM_REQ_SCRATCH1.at(model.type) +
|
| 500 |
+
MEM_REQ_EVAL.at (model.type);
|
| 501 |
+
|
| 502 |
+
// this is the memory required by one llama_state
|
| 503 |
+
const size_t mem_required_state =
|
| 504 |
+
scale*MEM_REQ_KV_SELF.at(model.type);
|
| 505 |
+
|
| 506 |
+
fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
|
| 507 |
+
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
|
| 508 |
+
}
|
| 509 |
+
|
| 510 |
+
// create the ggml context
|
| 511 |
+
{
|
| 512 |
+
lctx.model.buf.resize(ctx_size);
|
| 513 |
+
|
| 514 |
+
struct ggml_init_params params = {
|
| 515 |
+
/*.mem_size =*/ lctx.model.buf.size(),
|
| 516 |
+
/*.mem_buffer =*/ lctx.model.buf.data(),
|
| 517 |
+
};
|
| 518 |
+
|
| 519 |
+
model.ctx = ggml_init(params);
|
| 520 |
+
if (!model.ctx) {
|
| 521 |
+
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
|
| 522 |
+
return false;
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
// prepare memory for the weights
|
| 527 |
+
{
|
| 528 |
+
const auto & hparams = model.hparams;
|
| 529 |
+
|
| 530 |
+
const int n_embd = hparams.n_embd;
|
| 531 |
+
const int n_layer = hparams.n_layer;
|
| 532 |
+
const int n_vocab = hparams.n_vocab;
|
| 533 |
+
|
| 534 |
+
model.layers.resize(n_layer);
|
| 535 |
+
|
| 536 |
+
model.tok_embeddings = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);
|
| 537 |
+
|
| 538 |
+
model.norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
| 539 |
+
model.output = ggml_new_tensor_2d(ctx, vtype, n_embd, n_vocab);
|
| 540 |
+
|
| 541 |
+
// map by name
|
| 542 |
+
model.tensors["tok_embeddings.weight"] = model.tok_embeddings;
|
| 543 |
+
|
| 544 |
+
model.tensors["norm.weight"] = model.norm;
|
| 545 |
+
model.tensors["output.weight"] = model.output;
|
| 546 |
+
|
| 547 |
+
for (int i = 0; i < n_layer; ++i) {
|
| 548 |
+
auto & layer = model.layers[i];
|
| 549 |
+
|
| 550 |
+
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
| 551 |
+
|
| 552 |
+
layer.wq = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
| 553 |
+
layer.wk = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
| 554 |
+
layer.wv = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
| 555 |
+
layer.wo = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
|
| 556 |
+
|
| 557 |
+
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
| 558 |
+
|
| 559 |
+
layer.w1 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
|
| 560 |
+
layer.w2 = ggml_new_tensor_2d(ctx, wtype, n_ff, n_embd);
|
| 561 |
+
layer.w3 = ggml_new_tensor_2d(ctx, wtype, n_embd, n_ff);
|
| 562 |
+
|
| 563 |
+
// map by name
|
| 564 |
+
model.tensors["layers." + std::to_string(i) + ".attention_norm.weight"] = layer.attention_norm;
|
| 565 |
+
|
| 566 |
+
model.tensors["layers." + std::to_string(i) + ".attention.wq.weight"] = layer.wq;
|
| 567 |
+
model.tensors["layers." + std::to_string(i) + ".attention.wk.weight"] = layer.wk;
|
| 568 |
+
model.tensors["layers." + std::to_string(i) + ".attention.wv.weight"] = layer.wv;
|
| 569 |
+
model.tensors["layers." + std::to_string(i) + ".attention.wo.weight"] = layer.wo;
|
| 570 |
+
|
| 571 |
+
model.tensors["layers." + std::to_string(i) + ".ffn_norm.weight"] = layer.ffn_norm;
|
| 572 |
+
|
| 573 |
+
model.tensors["layers." + std::to_string(i) + ".feed_forward.w1.weight"] = layer.w1;
|
| 574 |
+
model.tensors["layers." + std::to_string(i) + ".feed_forward.w2.weight"] = layer.w2;
|
| 575 |
+
model.tensors["layers." + std::to_string(i) + ".feed_forward.w3.weight"] = layer.w3;
|
| 576 |
+
}
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
const size_t file_offset = fin.tellg();
|
| 580 |
+
|
| 581 |
+
fin.close();
|
| 582 |
+
|
| 583 |
+
std::vector<uint8_t> tmp;
|
| 584 |
+
|
| 585 |
+
if (progress_callback) {
|
| 586 |
+
progress_callback(0.0, progress_callback_user_data);
|
| 587 |
+
}
|
| 588 |
+
|
| 589 |
+
for (int i = 0; i < n_parts; ++i) {
|
| 590 |
+
const int part_id = i;
|
| 591 |
+
//const int part_id = n_parts - i - 1;
|
| 592 |
+
|
| 593 |
+
std::string fname_part = fname;
|
| 594 |
+
if (i > 0) {
|
| 595 |
+
fname_part += "." + std::to_string(i);
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
|
| 599 |
+
|
| 600 |
+
fin = std::ifstream(fname_part, std::ios::binary);
|
| 601 |
+
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
|
| 602 |
+
|
| 603 |
+
fin.seekg(0, fin.end);
|
| 604 |
+
const size_t file_size = fin.tellg();
|
| 605 |
+
|
| 606 |
+
fin.seekg(file_offset);
|
| 607 |
+
|
| 608 |
+
// load weights
|
| 609 |
+
{
|
| 610 |
+
size_t total_size = 0;
|
| 611 |
+
|
| 612 |
+
model.n_loaded = 0;
|
| 613 |
+
|
| 614 |
+
fprintf(stderr, "%s: ", __func__);
|
| 615 |
+
|
| 616 |
+
while (true) {
|
| 617 |
+
int32_t n_dims;
|
| 618 |
+
int32_t length;
|
| 619 |
+
int32_t ftype;
|
| 620 |
+
|
| 621 |
+
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
| 622 |
+
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
| 623 |
+
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
| 624 |
+
|
| 625 |
+
if (fin.eof()) {
|
| 626 |
+
break;
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
int32_t nelements = 1;
|
| 630 |
+
int32_t ne[2] = { 1, 1 };
|
| 631 |
+
for (int i = 0; i < n_dims; ++i) {
|
| 632 |
+
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
| 633 |
+
nelements *= ne[i];
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
std::string name(length, 0);
|
| 637 |
+
fin.read(&name[0], length);
|
| 638 |
+
|
| 639 |
+
if (model.tensors.find(name.data()) == model.tensors.end()) {
|
| 640 |
+
fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
|
| 641 |
+
return false;
|
| 642 |
+
}
|
| 643 |
+
|
| 644 |
+
// split_type = 0: split by columns
|
| 645 |
+
// split_type = 1: split by rows
|
| 646 |
+
int split_type = 0;
|
| 647 |
+
|
| 648 |
+
// split_type = 0:
|
| 649 |
+
// regex:
|
| 650 |
+
// - tok_embeddings.*
|
| 651 |
+
// - layers.*.attention.wo.weight
|
| 652 |
+
// - layers.*.feed_forward.w2.weight
|
| 653 |
+
|
| 654 |
+
// split_type = 1:
|
| 655 |
+
// regex:
|
| 656 |
+
// - output.*
|
| 657 |
+
// - layers.*.attention.wq.weight
|
| 658 |
+
// - layers.*.attention.wk.weight
|
| 659 |
+
// - layers.*.attention.wv.weight
|
| 660 |
+
// - layers.*.feed_forward.w1.weight
|
| 661 |
+
// - layers.*.feed_forward.w3.weight
|
| 662 |
+
if (name.find("tok_embeddings") != std::string::npos) {
|
| 663 |
+
split_type = 0;
|
| 664 |
+
} else if (name.find("layers") != std::string::npos) {
|
| 665 |
+
if (name.find("attention.wo.weight") != std::string::npos) {
|
| 666 |
+
split_type = 0;
|
| 667 |
+
} else if (name.find("feed_forward.w2.weight") != std::string::npos) {
|
| 668 |
+
split_type = 0;
|
| 669 |
+
} else {
|
| 670 |
+
split_type = 1;
|
| 671 |
+
}
|
| 672 |
+
} else if (name.find("output") != std::string::npos) {
|
| 673 |
+
split_type = 1;
|
| 674 |
+
}
|
| 675 |
+
|
| 676 |
+
auto tensor = model.tensors[name.data()];
|
| 677 |
+
|
| 678 |
+
if (n_dims == 1) {
|
| 679 |
+
if (ggml_nelements(tensor) != nelements) {
|
| 680 |
+
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
| 681 |
+
return false;
|
| 682 |
+
}
|
| 683 |
+
} else {
|
| 684 |
+
if (ggml_nelements(tensor)/n_parts != nelements) {
|
| 685 |
+
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
|
| 686 |
+
return false;
|
| 687 |
+
}
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
if (n_dims == 1) {
|
| 691 |
+
if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
|
| 692 |
+
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
| 693 |
+
__func__, name.data(), tensor->ne[0], tensor->ne[1], ne[0], ne[1]);
|
| 694 |
+
return false;
|
| 695 |
+
}
|
| 696 |
+
} else {
|
| 697 |
+
if (split_type == 0) {
|
| 698 |
+
if (tensor->ne[0]/n_parts != ne[0] || tensor->ne[1] != ne[1]) {
|
| 699 |
+
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
| 700 |
+
__func__, name.data(), tensor->ne[0]/n_parts, tensor->ne[1], ne[0], ne[1]);
|
| 701 |
+
return false;
|
| 702 |
+
}
|
| 703 |
+
} else {
|
| 704 |
+
if (tensor->ne[0] != ne[0] || tensor->ne[1]/n_parts != ne[1]) {
|
| 705 |
+
fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
|
| 706 |
+
__func__, name.data(), tensor->ne[0], tensor->ne[1]/n_parts, ne[0], ne[1]);
|
| 707 |
+
return false;
|
| 708 |
+
}
|
| 709 |
+
}
|
| 710 |
+
}
|
| 711 |
+
|
| 712 |
+
if (0) {
|
| 713 |
+
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
| 714 |
+
fprintf(stderr, "%24s - [%5d, %5d], type = %6s, split = %d\n", name.data(), ne[0], ne[1], ftype_str[ftype], split_type);
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
size_t bpe = 0;
|
| 718 |
+
|
| 719 |
+
switch (ftype) {
|
| 720 |
+
case 0: bpe = ggml_type_size(GGML_TYPE_F32); break;
|
| 721 |
+
case 1: bpe = ggml_type_size(GGML_TYPE_F16); break;
|
| 722 |
+
case 2: bpe = ggml_type_size(GGML_TYPE_Q4_0); assert(ne[0] % 64 == 0); break;
|
| 723 |
+
case 3: bpe = ggml_type_size(GGML_TYPE_Q4_1); assert(ne[0] % 64 == 0); break;
|
| 724 |
+
default:
|
| 725 |
+
{
|
| 726 |
+
fprintf(stderr, "%s: unknown ftype %d in model file\n", __func__, ftype);
|
| 727 |
+
return false;
|
| 728 |
+
}
|
| 729 |
+
};
|
| 730 |
+
|
| 731 |
+
if (n_dims == 1 || n_parts == 1) {
|
| 732 |
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
|
| 733 |
+
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
| 734 |
+
__func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
|
| 735 |
+
return false;
|
| 736 |
+
}
|
| 737 |
+
|
| 738 |
+
if (part_id == 0) {
|
| 739 |
+
fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
|
| 740 |
+
} else {
|
| 741 |
+
fin.seekg(ggml_nbytes(tensor), std::ios::cur);
|
| 742 |
+
}
|
| 743 |
+
|
| 744 |
+
total_size += ggml_nbytes(tensor);
|
| 745 |
+
} else {
|
| 746 |
+
if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)/n_parts) {
|
| 747 |
+
fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
|
| 748 |
+
__func__, name.data(), ggml_nbytes(tensor)/n_parts, nelements*bpe);
|
| 749 |
+
return false;
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
if (split_type == 0) {
|
| 753 |
+
const int np0 = ne[0];
|
| 754 |
+
|
| 755 |
+
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
| 756 |
+
assert(row_size == tensor->nb[1]);
|
| 757 |
+
|
| 758 |
+
for (int i1 = 0; i1 < ne[1]; ++i1) {
|
| 759 |
+
const size_t offset_row = i1*row_size;
|
| 760 |
+
const size_t offset = offset_row + ((part_id*np0)/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
| 761 |
+
fin.read(reinterpret_cast<char *>(tensor->data) + offset, row_size/n_parts);
|
| 762 |
+
}
|
| 763 |
+
} else {
|
| 764 |
+
const int np1 = ne[1];
|
| 765 |
+
|
| 766 |
+
const size_t row_size = (tensor->ne[0]/ggml_blck_size(tensor->type))*ggml_type_size(tensor->type);
|
| 767 |
+
|
| 768 |
+
for (int i1 = 0; i1 < ne[1]; ++i1) {
|
| 769 |
+
const size_t offset_row = (i1 + part_id*np1)*row_size;
|
| 770 |
+
fin.read(reinterpret_cast<char *>(tensor->data) + offset_row, row_size);
|
| 771 |
+
}
|
| 772 |
+
}
|
| 773 |
+
|
| 774 |
+
total_size += ggml_nbytes(tensor)/n_parts;
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
//fprintf(stderr, "%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
| 778 |
+
model.n_loaded++;
|
| 779 |
+
|
| 780 |
+
// progress
|
| 781 |
+
if (progress_callback) {
|
| 782 |
+
double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
|
| 783 |
+
double current_progress = (double(i) + current_file_progress) / double(n_parts);
|
| 784 |
+
progress_callback(current_progress, progress_callback_user_data);
|
| 785 |
+
}
|
| 786 |
+
if (model.n_loaded % 8 == 0) {
|
| 787 |
+
fprintf(stderr, ".");
|
| 788 |
+
fflush(stderr);
|
| 789 |
+
}
|
| 790 |
+
}
|
| 791 |
+
|
| 792 |
+
fprintf(stderr, " done\n");
|
| 793 |
+
|
| 794 |
+
fprintf(stderr, "%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, model.n_loaded);
|
| 795 |
+
if (model.n_loaded == 0) {
|
| 796 |
+
fprintf(stderr, "%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
| 797 |
+
} else if (model.n_loaded != (int) model.tensors.size()) {
|
| 798 |
+
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded);
|
| 799 |
+
return false;
|
| 800 |
+
}
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
fin.close();
|
| 804 |
+
}
|
| 805 |
+
|
| 806 |
+
lctx.t_load_us = ggml_time_us() - t_start_us;
|
| 807 |
+
|
| 808 |
+
if (progress_callback) {
|
| 809 |
+
progress_callback(1.0, progress_callback_user_data);
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
return true;
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
// evaluate the transformer
|
| 816 |
+
//
|
| 817 |
+
// - lctx: llama context
|
| 818 |
+
// - tokens: new batch of tokens to process
|
| 819 |
+
// - n_past: the context size so far
|
| 820 |
+
// - n_threads: number of threads to use
|
| 821 |
+
//
|
| 822 |
+
static bool llama_eval_internal(
|
| 823 |
+
llama_context & lctx,
|
| 824 |
+
const llama_token * tokens,
|
| 825 |
+
const int n_tokens,
|
| 826 |
+
const int n_past,
|
| 827 |
+
const int n_threads) {
|
| 828 |
+
const int64_t t_start_us = ggml_time_us();
|
| 829 |
+
|
| 830 |
+
const int N = n_tokens;
|
| 831 |
+
|
| 832 |
+
const auto & model = lctx.model;
|
| 833 |
+
const auto & hparams = model.hparams;
|
| 834 |
+
|
| 835 |
+
auto & kv_self = model.kv_self;
|
| 836 |
+
|
| 837 |
+
LLAMA_ASSERT(!!kv_self.ctx);
|
| 838 |
+
|
| 839 |
+
const int n_embd = hparams.n_embd;
|
| 840 |
+
const int n_layer = hparams.n_layer;
|
| 841 |
+
const int n_ctx = hparams.n_ctx;
|
| 842 |
+
const int n_head = hparams.n_head;
|
| 843 |
+
const int n_vocab = hparams.n_vocab;
|
| 844 |
+
const int n_rot = hparams.n_embd/hparams.n_head;
|
| 845 |
+
|
| 846 |
+
auto & mem_per_token = lctx.mem_per_token;
|
| 847 |
+
auto & buf_compute = lctx.buf_compute;
|
| 848 |
+
|
| 849 |
+
struct ggml_init_params params = {
|
| 850 |
+
/*.mem_size =*/ buf_compute.size(),
|
| 851 |
+
/*.mem_buffer =*/ buf_compute.data(),
|
| 852 |
+
};
|
| 853 |
+
|
| 854 |
+
struct ggml_context * ctx0 = ggml_init(params);
|
| 855 |
+
|
| 856 |
+
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
| 857 |
+
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
| 858 |
+
ggml_cgraph gf = {};
|
| 859 |
+
gf.n_threads = N > 255 && ggml_cpu_has_blas() ? 1 : n_threads;
|
| 860 |
+
|
| 861 |
+
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
| 862 |
+
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
| 863 |
+
|
| 864 |
+
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
|
| 865 |
+
|
| 866 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 867 |
+
struct ggml_tensor * inpSA = inpL;
|
| 868 |
+
|
| 869 |
+
struct ggml_tensor * cur;
|
| 870 |
+
|
| 871 |
+
lctx.use_buf(ctx0, 0);
|
| 872 |
+
|
| 873 |
+
// norm
|
| 874 |
+
{
|
| 875 |
+
cur = ggml_rms_norm(ctx0, inpL);
|
| 876 |
+
|
| 877 |
+
// cur = attention_norm*cur
|
| 878 |
+
cur = ggml_mul(ctx0,
|
| 879 |
+
ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
|
| 880 |
+
cur);
|
| 881 |
+
}
|
| 882 |
+
|
| 883 |
+
// self-attention
|
| 884 |
+
{
|
| 885 |
+
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
| 886 |
+
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
| 887 |
+
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
| 888 |
+
|
| 889 |
+
// store key and value to memory
|
| 890 |
+
if (N >= 1) {
|
| 891 |
+
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
| 892 |
+
struct ggml_tensor * v = ggml_view_1d(ctx0, kv_self.v, N*n_embd, (ggml_element_size(kv_self.v)*n_embd)*(il*n_ctx + n_past));
|
| 893 |
+
|
| 894 |
+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
| 895 |
+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
| 899 |
+
struct ggml_tensor * Q =
|
| 900 |
+
ggml_permute(ctx0,
|
| 901 |
+
ggml_rope(ctx0,
|
| 902 |
+
ggml_cpy(ctx0,
|
| 903 |
+
Qcur,
|
| 904 |
+
                            ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
                        n_past, n_rot, 0),
                    0, 2, 1, 3);

            // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_rope(ctx0,
                            ggml_reshape_3d(ctx0,
                                ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
                                n_embd/n_head, n_head, n_past + N),
                            n_past, n_rot, 1),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor * KQ_scaled =
                ggml_scale(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));

            // KQ_masked = mask_past(KQ_scaled)
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

            // KQ = soft_max(KQ_masked)
            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
            struct ggml_tensor * V_trans =
                ggml_cpy(ctx0,
                    ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
                            n_embd/n_head, n_head, n_past + N),
                        1, 2, 0, 3),
                    ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));

            // KQV = transpose(V) * KQ_soft_max
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_embd, N)
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

            // projection (no bias)
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].wo,
                    cur);
        }

        lctx.use_buf(ctx0, 1);

        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);

        // feed-forward network
        {
            // norm
            {
                cur = ggml_rms_norm(ctx0, inpFF);

                // cur = ffn_norm*cur
                cur = ggml_mul(ctx0,
                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
                        cur);
            }

            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                    model.layers[il].w3,
                    cur);

            cur = ggml_mul_mat(ctx0,
                    model.layers[il].w1,
                    cur);

            // SILU activation
            cur = ggml_silu(ctx0, cur);

            cur = ggml_mul(ctx0, cur, tmp);

            cur = ggml_mul_mat(ctx0,
                    model.layers[il].w2,
                    cur);
        }

        cur = ggml_add(ctx0, cur, inpFF);

        // input for next layer
        inpL = cur;
    }

    lctx.use_buf(ctx0, 0);

    // used at the end to optionally extract the embeddings
    struct ggml_tensor * embeddings = NULL;

    // norm
    {

        inpL = ggml_rms_norm(ctx0, inpL);

        // inpL = norm*inpL
        inpL = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.norm, inpL),
                    inpL);

        embeddings = inpL;
    }

    // lm_head
    inpL = ggml_mul_mat(ctx0, model.output, inpL);

    lctx.use_buf(ctx0, -1);

    // logits -> probs
    //inpL = ggml_soft_max(ctx0, inpL);

    // run the computation
    ggml_build_forward_expand(&gf, inpL);
    ggml_graph_compute       (ctx0, &gf);

    //if (n_past%100 == 0) {
    //    ggml_graph_print   (&gf);
    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
    //}

    //embd_w.resize(n_vocab*N);
    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

    // extract logits
    {
        auto & logits_out = lctx.logits;

        if (lctx.logits_all) {
            logits_out.resize(n_vocab * N);
            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
        } else {
            // return result for just the last token
            logits_out.resize(n_vocab);
            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
        }
    }

    // extract embeddings
    if (lctx.embedding.size()) {
        auto & embedding_out = lctx.embedding;

        embedding_out.resize(n_embd);
        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
    }

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }

#if 0
    printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
            ggml_used_mem(ctx0)/1024.0/1024.0,
            lctx.get_buf_max_mem(0)/1024.0/1024.0,
            lctx.get_buf_max_mem(1)/1024.0/1024.0);
#endif

    ggml_free(ctx0);

    // measure the performance only for the single-token evals
    if (N == 1) {
        lctx.t_eval_us += ggml_time_us() - t_start_us;
        lctx.n_eval++;
    }
    else if (N > 1) {
        lctx.t_p_eval_us += ggml_time_us() - t_start_us;
        lctx.n_p_eval += N;
    }

    return true;
}
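For orientation, the graph assembled above is the standard LLaMA decoder block. As a sketch in math form (with $d_{head} = n\_embd/n\_head$, RoPE applied to $Q$ and $K$, and $\mathrm{rmsnorm}$ denoting ggml_rms_norm followed by the elementwise weight):

    $$ \mathrm{attn}(x) = \mathrm{softmax}\!\left(\mathrm{mask}\!\left(\frac{Q K^{\top}}{\sqrt{d_{head}}}\right)\right) V $$
    $$ \mathrm{inpFF} = W_o\,\mathrm{attn}(\mathrm{rmsnorm}(x)) + x, \qquad
       \mathrm{out} = W_2\big(\mathrm{silu}(W_1 h) \odot W_3 h\big) + \mathrm{inpFF}, \quad h = \mathrm{rmsnorm}(\mathrm{inpFF}) $$

The $1/\sqrt{d_{head}}$ factor is the ggml_scale call, the causal mask is ggml_diag_mask_inf, and $W_1$, $W_2$, $W_3$ are the w1/w2/w3 feed-forward weights used above.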

//
// tokenizer
//

static size_t utf8_len(char src) {
    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
    return lookup[highbits];
}

struct llama_sp_symbol {
    using index = int;
    index prev;
    index next;
    const char * text;
    size_t n;
};

struct llama_sp_bigram {
    struct comparator {
        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
        }
    };
    using queue_storage = std::vector<llama_sp_bigram>;
    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
    llama_sp_symbol::index left;
    llama_sp_symbol::index right;
    float score;
    size_t size;
};

// original implementation:
// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4
struct llama_tokenizer {
    llama_tokenizer(const llama_vocab & vocab): vocab_(vocab) {}

    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
        // split string into utf8 chars
        int index = 0;
        size_t offs = 0;
        while (offs < text.size()) {
            llama_sp_symbol sym;
            size_t char_len = std::min(text.size() - offs, utf8_len(text[offs]));
            sym.text = text.c_str() + offs;
            sym.n = char_len;
            offs += char_len;
            sym.prev = index - 1;
            sym.next = offs == text.size() ? -1 : index + 1;
            index++;
            symbols_.emplace_back(std::move(sym));
        }

        // seed the work queue with all possible 2-character tokens.
        for (size_t i = 1; i < symbols_.size(); ++i) {
            try_add_bigram(i - 1, i);
        }

        // keep substituting the highest frequency pairs for as long as we can.
        while (!work_queue_.empty()) {
            auto bigram = work_queue_.top();
            work_queue_.pop();

            auto & left_sym = symbols_[bigram.left];
            auto & right_sym = symbols_[bigram.right];

            // if one of the symbols already got merged, skip it.
            if (left_sym.n == 0 || right_sym.n == 0 ||
                left_sym.n + right_sym.n != bigram.size) {
                continue;
            }

            // merge the right sym into the left one
            left_sym.n += right_sym.n;
            right_sym.n = 0;

            //printf("left = '%*s' size = %zu\n", (int) left_sym.n, left_sym.text, bigram.size);

            // remove the right sym from the chain
            left_sym.next = right_sym.next;
            if (right_sym.next >= 0) {
                symbols_[right_sym.next].prev = bigram.left;
            }

            // find more substitutions
            try_add_bigram(left_sym.prev, bigram.left);
            try_add_bigram(bigram.left, left_sym.next);
        }

        for (int i = 0; i != -1; i = symbols_[i].next) {
            auto & symbol = symbols_[i];
            auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));

            if (token == vocab_.token_to_id.end()) {
                // output any symbols that did not form tokens as bytes.
                for (int j = 0; j < (int) symbol.n; ++j) {
                    llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
                    output.push_back(token_id);
                }
            } else {
                output.push_back((*token).second);
            }
        }
    }

private:
    void try_add_bigram(int left, int right) {
        if (left == -1 || right == -1) {
            return;
        }

        const std::string text = std::string(symbols_[left].text, symbols_[left].n + symbols_[right].n);
        auto token = vocab_.token_to_id.find(text);

        if (token == vocab_.token_to_id.end()) {
            return;
        }

        if (static_cast<size_t>((*token).second) >= vocab_.id_to_token.size()) {
            return;
        }

        const auto &tok_score = vocab_.id_to_token[(*token).second];

        llama_sp_bigram bigram;
        bigram.left = left;
        bigram.right = right;
        bigram.score = tok_score.score;
        bigram.size = text.size();
        work_queue_.push(bigram);
    }

    const llama_vocab & vocab_;
    std::vector<llama_sp_symbol> symbols_;
    llama_sp_bigram::queue work_queue_;
};

static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
    llama_tokenizer tokenizer(vocab);
    std::vector<llama_vocab::id> output;

    if (text.size() == 0) {
        return output;
    }

    if (bos) {
        output.push_back(1);
    }

    tokenizer.tokenize(text, output);
    return output;
}
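To make the merge order concrete, here is a minimal standalone sketch of the same greedy highest-score merging on a toy, made-up vocabulary. The vocab entries and scores below are invented for illustration only; the real tokenizer above uses a priority queue and falls back to byte tokens (id = byte + 3) for symbols that never form a vocab entry:

    // toy_merge.cpp - greedy highest-score bigram merging, illustrative only
    #include <cstdio>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
        // hypothetical SentencePiece-style vocab: higher score = preferred merge
        std::map<std::string, float> vocab = {
            {"h", 0.1f}, {"e", 0.1f}, {"l", 0.1f}, {"o", 0.1f},
            {"he", 1.0f}, {"ll", 2.0f}, {"llo", 3.0f}, {"hello", 5.0f},
        };

        // start from single characters and repeatedly merge the best-scoring
        // adjacent pair that exists in the vocab, like the work queue above
        std::vector<std::string> symbols = {"h", "e", "l", "l", "o"};
        while (true) {
            int best = -1; float best_score = -1e9f;
            for (size_t i = 0; i + 1 < symbols.size(); ++i) {
                auto it = vocab.find(symbols[i] + symbols[i + 1]);
                if (it != vocab.end() && it->second > best_score) {
                    best = (int) i; best_score = it->second;
                }
            }
            if (best < 0) break; // no adjacent pair forms a vocab entry
            symbols[best] += symbols[best + 1];
            symbols.erase(symbols.begin() + best + 1);
        }

        for (const auto & s : symbols) printf("'%s' ", s.c_str()); // prints: 'hello'
        printf("\n");
    }

With these scores the merges happen in the order "ll", "llo", "he", "hello", ending in a single token.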

//
// sampling
//

static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
    // find the top k tokens
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
        return a.first > b.first;
    });

    logits_id.resize(top_k);
}

static llama_vocab::id llama_sample_top_p_top_k(
        llama_context & lctx,
        const std::vector<llama_vocab::id> & last_n_tokens,
        int top_k,
        double top_p,
        double temp,
        double repeat_penalty) {
    auto & rng = lctx.rng;

    const int n_logits = lctx.model.hparams.n_vocab;

    const auto & logits = lctx.logits;
    const auto * plogits = logits.data() + logits.size() - n_logits;

    std::vector<std::pair<double, llama_vocab::id>> logits_id;
    logits_id.reserve(n_logits);

    {
        const double scale = 1.0/temp;
        for (int i = 0; i < n_logits; ++i) {
            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                // if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
                if (plogits[i] < 0.0) {
                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                } else {
                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                }
            } else {
                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
            }
        }
    }

    sample_top_k(logits_id, top_k);

    double maxl = -std::numeric_limits<double>::infinity();
    for (const auto & kv : logits_id) {
        maxl = std::max(maxl, kv.first);
    }

    // compute probs for the top k tokens
    std::vector<double> probs;
    probs.reserve(logits_id.size());

    double sum = 0.0;
    for (const auto & kv : logits_id) {
        double p = exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }

    // normalize the probs
    for (auto & p : probs) {
        p /= sum;
    }

    if (top_p < 1.0f) {
        double cumsum = 0.0f;
        for (int i = 0; i < (int) probs.size(); i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                probs.resize(i + 1);
                logits_id.resize(i + 1);
                break;
            }
        }

        cumsum = 1.0/cumsum;
        for (int i = 0; i < (int) probs.size(); i++) {
            probs[i] *= cumsum;
        }
    }

    //printf("\n");
    //for (int i = 0; i < (int) 10; i++) {
    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
    //}
    //printf("\n\n");
    //exit(0);

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    return logits_id[idx].second;
}
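In equation form, the routine above first scales and penalizes the raw logits $\ell_i$ (temperature $T$, penalty $r$, recent-token window $\mathrm{last\_n}$):

    $$ z_i = \frac{\ell_i}{T} \cdot c_i, \qquad
       c_i = \begin{cases} r, & i \in \mathrm{last\_n},\ \ell_i < 0 \\ 1/r, & i \in \mathrm{last\_n},\ \ell_i \ge 0 \\ 1, & \text{otherwise} \end{cases} $$

It then keeps the $\mathrm{top\_k}$ largest $z_i$, forms $p_i \propto \exp(z_i - \max_j z_j)$, truncates to the smallest prefix (in descending $p_i$) whose cumulative mass reaches $\mathrm{top\_p}$, renormalizes, and draws the next token from the resulting discrete distribution.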

//
// quantization
//

// TODO: reuse code from the llama_model_load() somehow
bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype, int qk) {
    ggml_type type = GGML_TYPE_Q4_1;

    switch (itype) {
        case 2: type = GGML_TYPE_Q4_0; break;
        case 3: type = GGML_TYPE_Q4_1; break;
        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return false;
    };

    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
        return false;
    }

    llama_vocab vocab;

    printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());

    auto finp = std::ifstream(fname_inp, std::ios::binary);
    if (!finp) {
        fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str());
        return false;
    }

    auto fout = std::ofstream(fname_out, std::ios::binary);
    if (!fout) {
        fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str());
        return false;
    }

    // verify magic
    {
        uint32_t magic;
        finp.read((char *) &magic, sizeof(magic));
        if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
            fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n",
                    __func__, fname_inp.c_str());
            return false;
        }
        if (magic != LLAMA_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
            return false;
        }

        fout.write((char *) &magic, sizeof(magic));

        uint32_t format_version;
        finp.read((char *) &format_version, sizeof(format_version));

        if (format_version != LLAMA_FILE_VERSION) {
            fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
                    __func__, fname_inp.c_str(), format_version, LLAMA_FILE_VERSION);
            return false;
        }

        fout.write((char *) &format_version, sizeof(format_version));
    }

    llama_hparams hparams;

    // load hparams
    {
        finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
        //finp.read((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
        finp.read((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
        finp.read((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
        finp.read((char *) &hparams.n_head,  sizeof(hparams.n_head));
        finp.read((char *) &hparams.n_layer, sizeof(hparams.n_layer));
        finp.read((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
        finp.read((char *) &hparams.f16,     sizeof(hparams.f16));

        printf("%s: n_vocab = %d\n", __func__, hparams.n_vocab);
        printf("%s: n_ctx   = %d\n", __func__, hparams.n_ctx);
        printf("%s: n_embd  = %d\n", __func__, hparams.n_embd);
        printf("%s: n_mult  = %d\n", __func__, hparams.n_mult);
        printf("%s: n_head  = %d\n", __func__, hparams.n_head);
        printf("%s: n_layer = %d\n", __func__, hparams.n_layer);
        printf("%s: f16     = %d\n", __func__, hparams.f16);

        fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab));
        //fout.write((char *) &hparams.n_ctx, sizeof(hparams.n_ctx));
        fout.write((char *) &hparams.n_embd,  sizeof(hparams.n_embd));
        fout.write((char *) &hparams.n_mult,  sizeof(hparams.n_mult));
        fout.write((char *) &hparams.n_head,  sizeof(hparams.n_head));
        fout.write((char *) &hparams.n_layer, sizeof(hparams.n_layer));
        fout.write((char *) &hparams.n_rot,   sizeof(hparams.n_rot));
        fout.write((char *) &itype,           sizeof(hparams.f16));
    }

    // load vocab
    {
        const int32_t n_vocab = hparams.n_vocab;

        if (n_vocab != hparams.n_vocab) {
            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
                    __func__, fname_inp.c_str(), n_vocab, hparams.n_vocab);
            return false;
        }

        std::string word;
        vocab.id_to_token.resize(n_vocab);
        for (int i = 0; i < n_vocab; i++) {
            uint32_t len;
            finp.read ((char *) &len, sizeof(len));
            fout.write((char *) &len, sizeof(len));

            word.resize(len);
            finp.read ((char *) word.data(), len);
            fout.write((char *) word.data(), len);

            float score;
            finp.read ((char *) &score, sizeof(score));
            fout.write((char *) &score, sizeof(score));

            vocab.token_to_id[word] = i;

            auto &tok_score = vocab.id_to_token[i];
            tok_score.tok = word;
            tok_score.score = score;
        }
    }

    // load weights
    {
        size_t total_size_org = 0;
        size_t total_size_new = 0;

        std::vector<float> work;

        std::vector<uint8_t>     data_u8;
        std::vector<ggml_fp16_t> data_f16;
        std::vector<float>       data_f32;

        std::vector<int64_t> hist_all(1 << 4, 0);

        while (true) {
            int32_t n_dims;
            int32_t length;
            int32_t ftype;

            finp.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
            finp.read(reinterpret_cast<char *>(&length), sizeof(length));
            finp.read(reinterpret_cast<char *>(&ftype),  sizeof(ftype));

            if (finp.eof()) {
                break;
            }

            int32_t nelements = 1;
            int32_t ne[2] = { 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                finp.read (reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
                nelements *= ne[i];
            }

            std::string name(length, 0);
            finp.read (&name[0], length);

            {
                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
                printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
            }

            // regexes of tensor names to be quantized
            const std::vector<std::string> k_names = {
                ".*weight",
            };

            bool quantize = false;
            for (const auto & s : k_names) {
                if (std::regex_match(name, std::regex(s))) {
                    quantize = true;
                    break;
                }
            }

            // quantize only 2D tensors
            quantize &= (n_dims == 2);

            if (quantize) {
                if (ftype != 0 && ftype != 1) {
                    fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
                    return false;
                }

                if (ftype == 1) {
                    data_f16.resize(nelements);
                    finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
                    data_f32.resize(nelements);
                    for (int i = 0; i < nelements; ++i) {
                        data_f32[i] = ggml_fp16_to_fp32(data_f16[i]);
                    }
                } else {
                    data_f32.resize(nelements);
                    finp.read(reinterpret_cast<char *>(data_f32.data()), nelements * sizeof(float));
                }

                ftype = itype;
            } else {
                const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);

                data_u8.resize(nelements*bpe);
                finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
            }

            fout.write(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
            fout.write(reinterpret_cast<char *>(&length), sizeof(length));
            fout.write(reinterpret_cast<char *>(&ftype),  sizeof(ftype));
            for (int i = 0; i < n_dims; ++i) {
                fout.write(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
            }
            fout.write(&name[0], length);

            if (quantize) {
                printf("quantizing .. ");
                work.resize(nelements); // for quantization

                size_t cur_size = 0;
                std::vector<int64_t> hist_cur(1 << 4, 0);

                switch (type) {
                    case GGML_TYPE_Q4_0:
                        {
                            cur_size = ggml_quantize_q4_0(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
                        } break;
                    case GGML_TYPE_Q4_1:
                        {
                            cur_size = ggml_quantize_q4_1(data_f32.data(), work.data(), nelements, ne[0], qk, hist_cur.data());
                        } break;
                    default:
                        {
                            fprintf(stderr, "%s: unsupported quantization type %d\n", __func__, type);
                            return false;
                        }
                }

                fout.write(reinterpret_cast<char *>(work.data()), cur_size);
                total_size_new += cur_size;

                printf("size = %8.2f MB -> %8.2f MB | hist: ", nelements * sizeof(float)/1024.0/1024.0, cur_size/1024.0/1024.0);
                for (int i = 0; i < (int) hist_cur.size(); ++i) {
                    hist_all[i] += hist_cur[i];
                }

                for (int i = 0; i < (int) hist_cur.size(); ++i) {
                    printf("%5.3f ", hist_cur[i] / (float)nelements);
                }
                printf("\n");
            } else {
                printf("size = %8.3f MB\n", data_u8.size()/1024.0/1024.0);
                fout.write(reinterpret_cast<char *>(data_u8.data()), data_u8.size());
                total_size_new += data_u8.size();
            }

            total_size_org += nelements * sizeof(float);
        }

        printf("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
        printf("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);

        {
            int64_t sum_all = 0;
            for (int i = 0; i < (int) hist_all.size(); ++i) {
                sum_all += hist_all[i];
            }

            printf("%s: hist: ", __func__);
            for (int i = 0; i < (int) hist_all.size(); ++i) {
                printf("%5.3f ", hist_all[i] / (float)sum_all);
            }
            printf("\n");
        }
    }

    finp.close();
    fout.close();

    return true;
}
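For orientation, each tensor record that the loop above copies (and optionally quantizes) has the following on-disk layout. This is only a summary of the reads and writes above, not a new format definition:

    // per-tensor record in a ggmf model file, as read/written above:
    //   int32 n_dims          // 1 or 2
    //   int32 length          // length of the tensor name, in bytes
    //   int32 ftype           // 0 = f32, 1 = f16, 2 = q4_0, 3 = q4_1
    //   int32 ne[n_dims]      // tensor dimensions
    //   char  name[length]    // tensor name, not zero-terminated
    //   byte  data[...]       // tensor data; 2D ".*weight" tensors are
    //                         // rewritten quantized, with ftype set to itype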

//
// interface implementation
//

struct llama_context * llama_init_from_file(
        const char * path_model,
        struct llama_context_params params) {
    ggml_time_init();

    llama_context * ctx = new llama_context;

    if (params.seed <= 0) {
        params.seed = time(NULL);
    }

    ctx->rng = std::mt19937(params.seed);
    ctx->logits_all = params.logits_all;

    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
                          params.vocab_only, params.progress_callback,
                          params.progress_callback_user_data)) {
        fprintf(stderr, "%s: failed to load model\n", __func__);
        llama_free(ctx);
        return nullptr;
    }

    if (params.use_mlock) {
        char *err;
        if (!ggml_mlock(ctx->model.ctx, &err)) {
            fprintf(stderr, "%s\n", err);
            free(err);
            llama_free(ctx);
            return nullptr;
        }
    }

    // reserve memory for context buffers
    {
        if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
            llama_free(ctx);
            return nullptr;
        }

        {
            const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
            fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
        }

        const auto & hparams = ctx->model.hparams;

        // resized during inference
        if (params.logits_all) {
            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
        } else {
            ctx->logits.reserve(hparams.n_ctx);
        }

        if (params.embedding){
            ctx->embedding.resize(hparams.n_embd);
        }

        ctx->buf_compute.resize(MEM_REQ_EVAL.at(ctx->model.type));

        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
    }

    return ctx;
}

void llama_free(struct llama_context * ctx) {
    kv_cache_free(ctx->model.kv_self);

    if (ctx->model.ctx) {
        ggml_free(ctx->model.ctx);
    }

    delete ctx;
}

int llama_model_quantize(
        const char * fname_inp,
        const char * fname_out,
        int itype,
        int qk) {
    if (!llama_model_quantize_internal(fname_inp, fname_out, itype, qk)) {
        fprintf(stderr, "%s: failed to quantize\n", __func__);
        return 1;
    }

    return 0;
}

int llama_eval(
        struct llama_context * ctx,
        const llama_token * tokens,
        int n_tokens,
        int n_past,
        int n_threads) {
    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
        fprintf(stderr, "%s: failed to eval\n", __func__);
        return 1;
    }

    return 0;
}

int llama_tokenize(
        struct llama_context * ctx,
        const char * text,
        llama_token * tokens,
        int n_max_tokens,
        bool add_bos) {
    auto res = llama_tokenize(ctx->vocab, text, add_bos);

    if (n_max_tokens < (int) res.size()) {
        fprintf(stderr, "%s: too many tokens\n", __func__);
        return -((int) res.size());
    }

    for (size_t i = 0; i < res.size(); i++) {
        tokens[i] = res[i];
    }

    return res.size();
}

int llama_n_vocab(struct llama_context * ctx) {
    return ctx->vocab.id_to_token.size();
}

int llama_n_ctx(struct llama_context * ctx) {
    return ctx->model.hparams.n_ctx;
}

int llama_n_embd(struct llama_context * ctx) {
    return ctx->model.hparams.n_embd;
}

float * llama_get_logits(struct llama_context * ctx) {
    return ctx->logits.data();
}

float * llama_get_embeddings(struct llama_context * ctx) {
    return ctx->embedding.data();
}

const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
    if (token >= llama_n_vocab(ctx)) {
        return nullptr;
    }

    return ctx->vocab.id_to_token[token].tok.c_str();
}

llama_token llama_token_bos() {
    return 1;
}

llama_token llama_token_eos() {
    return 2;
}

llama_token llama_sample_top_p_top_k(
        llama_context * ctx,
        const llama_token * last_n_tokens_data,
        int last_n_tokens_size,
        int top_k,
        double top_p,
        double temp,
        double repeat_penalty) {
    const int64_t t_start_sample_us = ggml_time_us();

    llama_token result = 0;

    // TODO: avoid this ...
    const auto last_n_tokens = std::vector<llama_token>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);

    result = llama_sample_top_p_top_k(
            *ctx,
            last_n_tokens,
            top_k,
            top_p,
            temp,
            repeat_penalty);

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
    ctx->n_sample++;

    return result;
}


void llama_print_timings(struct llama_context * ctx) {
    const int64_t t_end_us = ggml_time_us();

    const int32_t n_sample = std::max(1, ctx->n_sample);
    const int32_t n_eval   = std::max(1, ctx->n_eval);
    const int32_t n_p_eval = std::max(1, ctx->n_p_eval);

    fprintf(stderr, "\n");
    fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
    fprintf(stderr, "%s:      sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample);
    fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval);
    fprintf(stderr, "%s:        eval time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->t_eval_us, n_eval, 1e-3f * ctx->t_eval_us / n_eval);
    fprintf(stderr, "%s:       total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
}

void llama_reset_timings(struct llama_context * ctx) {
    ctx->t_start_us = ggml_time_us();

    ctx->t_sample_us = ctx->n_sample = 0;
    ctx->t_eval_us   = ctx->n_eval   = 0;
    ctx->t_p_eval_us = ctx->n_p_eval = 0;
}

const char * llama_print_system_info(void) {
    static std::string s;

    s  = "";
    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
    s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
    s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
    s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
    s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";

    return s.c_str();
}
examples/talk-llama/llama.h
ADDED
@@ -0,0 +1,153 @@
#ifndef LLAMA_H
#define LLAMA_H

#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>

#ifdef LLAMA_SHARED
#    ifdef _WIN32
#        ifdef LLAMA_BUILD
#            define LLAMA_API __declspec(dllexport)
#        else
#            define LLAMA_API __declspec(dllimport)
#        endif
#    else
#        define LLAMA_API __attribute__ ((visibility ("default")))
#    endif
#else
#    define LLAMA_API
#endif

#define LLAMA_FILE_VERSION 1
#define LLAMA_FILE_MAGIC 0x67676d66 // 'ggmf' in hex
#define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files

#ifdef __cplusplus
extern "C" {
#endif

    //
    // C interface
    //
    // TODO: show sample usage
    //

    struct llama_context;

    typedef int llama_token;

    typedef struct llama_token_data {
        llama_token id;  // token id

        float p;     // probability of the token
        float plog;  // log probability of the token

    } llama_token_data;

    typedef void (*llama_progress_callback)(double progress, void *ctx);

    struct llama_context_params {
        int n_ctx;   // text context
        int n_parts; // -1 for default
        int seed;    // RNG seed, 0 for random

        bool f16_kv;     // use fp16 for KV cache
        bool logits_all; // the llama_eval() call computes all logits, not just the last one
        bool vocab_only; // only load the vocabulary, no weights
        bool use_mlock;  // force system to keep model in RAM
        bool embedding;  // embedding mode only

        // called with a progress value between 0 and 1, pass NULL to disable
        llama_progress_callback progress_callback;
        // context pointer passed to the progress callback
        void * progress_callback_user_data;
    };

    LLAMA_API struct llama_context_params llama_context_default_params();

    // Various functions for loading a ggml llama model.
    // Allocate (almost) all memory needed for the model.
    // Return NULL on failure
    LLAMA_API struct llama_context * llama_init_from_file(
            const char * path_model,
            struct llama_context_params params);

    // Frees all allocated memory
    LLAMA_API void llama_free(struct llama_context * ctx);

    // TODO: not great API - very likely to change
    // Returns 0 on success
    LLAMA_API int llama_model_quantize(
            const char * fname_inp,
            const char * fname_out,
            int itype,
            int qk);

    // Run the llama inference to obtain the logits and probabilities for the next token.
    // tokens + n_tokens is the provided batch of new tokens to process
    // n_past is the number of tokens to use from previous eval calls
    // Returns 0 on success
    LLAMA_API int llama_eval(
            struct llama_context * ctx,
            const llama_token * tokens,
            int n_tokens,
            int n_past,
            int n_threads);

    // Convert the provided text into tokens.
    // The tokens pointer must be large enough to hold the resulting tokens.
    // Returns the number of tokens on success, no more than n_max_tokens
    // Returns a negative number on failure - the number of tokens that would have been returned
    // TODO: not sure if correct
    LLAMA_API int llama_tokenize(
            struct llama_context * ctx,
            const char * text,
            llama_token * tokens,
            int n_max_tokens,
            bool add_bos);

    LLAMA_API int llama_n_vocab(struct llama_context * ctx);
    LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
    LLAMA_API int llama_n_embd (struct llama_context * ctx);

    // Token logits obtained from the last call to llama_eval()
    // The logits for the last token are stored in the last row
    // Can be mutated in order to change the probabilities of the next token
    // Rows: n_tokens
    // Cols: n_vocab
    LLAMA_API float * llama_get_logits(struct llama_context * ctx);

    // Get the embeddings for the input
    // shape: [n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

    // Token Id -> String. Uses the vocabulary in the provided context
    LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);

    // Special tokens
    LLAMA_API llama_token llama_token_bos();
    LLAMA_API llama_token llama_token_eos();

    // TODO: improve the last_n_tokens interface ?
    LLAMA_API llama_token llama_sample_top_p_top_k(
            struct llama_context * ctx,
            const llama_token * last_n_tokens_data,
            int last_n_tokens_size,
            int top_k,
            double top_p,
            double temp,
            double repeat_penalty);

    // Performance information
    LLAMA_API void llama_print_timings(struct llama_context * ctx);
    LLAMA_API void llama_reset_timings(struct llama_context * ctx);

    // Print system information
    LLAMA_API const char * llama_print_system_info(void);

#ifdef __cplusplus
}
#endif

#endif
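Since the header's sample usage is still a TODO, here is a minimal sketch of the intended call sequence. The model path and the sampling parameters (top_k 40, top_p 0.95, temp 0.80, repeat_penalty 1.30) are illustrative assumptions, and error handling is abbreviated:

    // usage_sketch.cpp - minimal llama.h call sequence (illustrative)
    #include "llama.h"
    #include <cstdio>
    #include <vector>

    int main() {
        llama_context_params lparams = llama_context_default_params();
        lparams.n_ctx = 512;

        // placeholder path - use your own converted ggml model
        llama_context * ctx = llama_init_from_file("models/ggml-llama-7B.bin", lparams);
        if (!ctx) return 1;

        // tokenize a prompt (with BOS), evaluate it, then sample one token
        std::vector<llama_token> tokens(64);
        const int n = llama_tokenize(ctx, " Hello", tokens.data(), tokens.size(), true);
        if (n < 0) return 1;
        tokens.resize(n);

        if (llama_eval(ctx, tokens.data(), tokens.size(), 0, 4) != 0) return 1;

        const llama_token id = llama_sample_top_p_top_k(ctx, tokens.data(), tokens.size(),
                                                        40, 0.95, 0.80, 1.30);
        printf("%s\n", llama_token_to_str(ctx, id));

        llama_print_timings(ctx);
        llama_free(ctx);
        return 0;
    }

In a generation loop one would append the sampled token to the context, call llama_eval again with the incremented n_past, and stop on llama_token_eos().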
examples/talk-llama/speak.sh
ADDED
@@ -0,0 +1,20 @@
#!/bin/bash

# Usage:
#  speak.sh <voice_id> <text-to-speak>

# espeak
# Mac OS: brew install espeak
# Linux: apt-get install espeak
#
#espeak -v en-us+m$1 -s 225 -p 50 -a 200 -g 5 -k 5 "$2"

# for Mac
say "$2"

# Eleven Labs
#
#wd=$(dirname $0)
#script=$wd/eleven-labs.py
#python3 $script $1 "$2" >/dev/null 2>&1
#ffplay -autoexit -nodisp -loglevel quiet -hide_banner -i ./audio.mp3 >/dev/null 2>&1
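Per the Usage comment, the script takes a voice id and the text to speak. A C++ caller can invoke it along these lines (a sketch; the voice id and text below are examples):

    // speak_sketch.cpp - invoking the TTS helper script (illustrative)
    #include <cstdlib>
    #include <string>

    int main() {
        const std::string speak = "./examples/talk-llama/speak.sh";
        const std::string text_to_speak = "Hello there";
        // speak.sh <voice_id> <text-to-speak>
        system((speak + " 2 \"" + text_to_speak + "\"").c_str());
        return 0;
    }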
examples/talk-llama/talk-llama.cpp
ADDED
@@ -0,0 +1,529 @@
// Talk with AI
//

#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "llama.h"

#include <cassert>
#include <cstdio>
#include <fstream>
#include <regex>
#include <string>
#include <thread>
#include <vector>

std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
    // initialize to prompt number of chars, since n_tokens <= n_prompt_chars
    std::vector<llama_token> res(text.size() + (int)add_bos);
    int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos);
    assert(n >= 0);
    res.resize(n);

    return res;
}

// command-line parameters
struct whisper_params {
    int32_t n_threads  = std::min(4, (int32_t) std::thread::hardware_concurrency());
    int32_t voice_ms   = 10000;
    int32_t capture_id = -1;
    int32_t max_tokens = 32;
    int32_t audio_ctx  = 0;

    float vad_thold  = 0.6f;
    float freq_thold = 100.0f;

    bool speed_up      = false;
    bool translate     = false;
    bool print_special = false;
    bool print_energy  = false;
    bool no_timestamps = true;

    std::string person      = "Georgi";
    std::string language    = "en";
    std::string model_wsp   = "models/ggml-base.en.bin";
    std::string model_llama = "models/ggml-llama-7B.bin";
    std::string speak       = "./examples/talk/speak.sh";
    std::string fname_out;
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);

bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];

        if (arg == "-h" || arg == "--help") {
            whisper_print_usage(argc, argv, params);
            exit(0);
        }
        else if (arg == "-t"   || arg == "--threads")       { params.n_threads     = std::stoi(argv[++i]); }
        else if (arg == "-vms" || arg == "--voice-ms")      { params.voice_ms      = std::stoi(argv[++i]); }
        else if (arg == "-c"   || arg == "--capture")       { params.capture_id    = std::stoi(argv[++i]); }
        else if (arg == "-mt"  || arg == "--max-tokens")    { params.max_tokens    = std::stoi(argv[++i]); }
        else if (arg == "-ac"  || arg == "--audio-ctx")     { params.audio_ctx     = std::stoi(argv[++i]); }
        else if (arg == "-vth" || arg == "--vad-thold")     { params.vad_thold     = std::stof(argv[++i]); }
        else if (arg == "-fth" || arg == "--freq-thold")    { params.freq_thold    = std::stof(argv[++i]); }
        else if (arg == "-su"  || arg == "--speed-up")      { params.speed_up      = true; }
        else if (arg == "-tr"  || arg == "--translate")     { params.translate     = true; }
        else if (arg == "-ps"  || arg == "--print-special") { params.print_special = true; }
        else if (arg == "-pe"  || arg == "--print-energy")  { params.print_energy  = true; }
        else if (arg == "-p"   || arg == "--person")        { params.person        = argv[++i]; }
        else if (arg == "-l"   || arg == "--language")      { params.language      = argv[++i]; }
        else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp     = argv[++i]; }
        else if (arg == "-ml"  || arg == "--model-llama")   { params.model_llama   = argv[++i]; }
        else if (arg == "-s"   || arg == "--speak")         { params.speak         = argv[++i]; }
        else if (arg == "-f"   || arg == "--file")          { params.fname_out     = argv[++i]; }
        else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            whisper_print_usage(argc, argv, params);
            exit(0);
        }
    }

    return true;
}

void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
    fprintf(stderr, "\n");
    fprintf(stderr, "usage: %s [options]\n", argv[0]);
    fprintf(stderr, "\n");
    fprintf(stderr, "options:\n");
    fprintf(stderr, "  -h,       --help          [default] show this help message and exit\n");
    fprintf(stderr, "  -t N,     --threads N     [%-7d] number of threads to use during computation\n", params.n_threads);
    fprintf(stderr, "  -vms N,   --voice-ms N    [%-7d] voice duration in milliseconds\n",              params.voice_ms);
    fprintf(stderr, "  -c ID,    --capture ID    [%-7d] capture device ID\n",                           params.capture_id);
    fprintf(stderr, "  -mt N,    --max-tokens N  [%-7d] maximum number of tokens per audio chunk\n",    params.max_tokens);
    fprintf(stderr, "  -ac N,    --audio-ctx N   [%-7d] audio context size (0 - all)\n",                params.audio_ctx);
    fprintf(stderr, "  -vth N,   --vad-thold N   [%-7.2f] voice activity detection threshold\n",        params.vad_thold);
    fprintf(stderr, "  -fth N,   --freq-thold N  [%-7.2f] high-pass frequency cutoff\n",                params.freq_thold);
    fprintf(stderr, "  -su,      --speed-up      [%-7s] speed up audio by x2 (reduced accuracy)\n",     params.speed_up ? "true" : "false");
    fprintf(stderr, "  -tr,      --translate     [%-7s] translate from source language to english\n",   params.translate ? "true" : "false");
    fprintf(stderr, "  -ps,      --print-special [%-7s] print special tokens\n",                        params.print_special ? "true" : "false");
    fprintf(stderr, "  -pe,      --print-energy  [%-7s] print sound energy (for debugging)\n",          params.print_energy ? "true" : "false");
    fprintf(stderr, "  -p NAME,  --person NAME   [%-7s] person name (for prompt selection)\n",          params.person.c_str());
    fprintf(stderr, "  -l LANG,  --language LANG [%-7s] spoken language\n",                             params.language.c_str());
    fprintf(stderr, "  -mw FILE, --model-whisper [%-7s] whisper model file\n",                          params.model_wsp.c_str());
    fprintf(stderr, "  -ml FILE, --model-llama   [%-7s] llama model file\n",                            params.model_llama.c_str());
    fprintf(stderr, "  -s FILE,  --speak TEXT    [%-7s] command for TTS\n",                             params.speak.c_str());
    fprintf(stderr, "  -f FNAME, --file FNAME    [%-7s] text output file name\n",                       params.fname_out.c_str());
    fprintf(stderr, "\n");
}

std::string transcribe(
        whisper_context * ctx,
        const whisper_params & params,
        const std::vector<float> & pcmf32,
        const std::string prompt_text,
        float & prob,
        int64_t & t_ms) {
    const auto t_start = std::chrono::high_resolution_clock::now();

    prob = 0.0f;
    t_ms = 0;

    std::vector<whisper_token> prompt_tokens;

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    prompt_tokens.resize(1024);
    prompt_tokens.resize(whisper_tokenize(ctx, prompt_text.c_str(), prompt_tokens.data(), prompt_tokens.size()));

    wparams.print_progress   = false;
    wparams.print_special    = params.print_special;
    wparams.print_realtime   = false;
    wparams.print_timestamps = !params.no_timestamps;
    wparams.translate        = params.translate;
    wparams.no_context       = true;
    wparams.single_segment   = true;
    wparams.max_tokens       = params.max_tokens;
    wparams.language         = params.language.c_str();
    wparams.n_threads        = params.n_threads;

    wparams.prompt_tokens    = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
    wparams.prompt_n_tokens  = prompt_tokens.empty() ? 0       : prompt_tokens.size();

    wparams.audio_ctx        = params.audio_ctx;
    wparams.speed_up         = params.speed_up;

    if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
        return "";
    }

    int prob_n = 0;
    std::string result;

    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
        const char * text = whisper_full_get_segment_text(ctx, i);

        result += text;

        const int n_tokens = whisper_full_n_tokens(ctx, i);
        for (int j = 0; j < n_tokens; ++j) {
            const auto token = whisper_full_get_token_data(ctx, i, j);

            prob += token.p;
            ++prob_n;
        }
    }

    if (prob_n > 0) {
        prob /= prob_n;
    }

    const auto t_end = std::chrono::high_resolution_clock::now();
    t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();

    return result;
}

const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";

// need to have leading ' '
const std::string k_prompt_llama = R"( Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
{1} is helpful, kind, honest, friendly, good at writing and never fails to answer {0}’s requests immediately and with details and precision.
There are no annotations like (30 seconds passed...) or (to himself), just what {0} and {1} say aloud to each other.
The transcript only includes text, it does not include markup like HTML and Markdown.
{1} responds with short and concise answers.

{0}{4} Hello, {1}!
{1}{4} Hello {0}! How may I help you today?
{0}{4} What time is it?
{1}{4} It is {2} o'clock.
{0}{4} What year is it?
{1}{4} We are in {3}.
{0}{4} What is a cat?
{1}{4} A cat is a domestic species of small carnivorous mammal. It is the only domesticated species in the family Felidae.
{0}{4} Name a color.
{1}{4} Blue
{0}{4})";
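With the defaults ({0} = "Georgi" from whisper_params, and {1} = "LLaMA", {4} = ":" as set in main() below; {2} and {3} are the current time and year filled in at runtime), the rendered template begins roughly like this:

     Text transcript of a never ending dialog, where Georgi interacts with an AI assistant named LLaMA.
    LLaMA is helpful, kind, honest, friendly, good at writing and never fails to answer Georgi’s requests immediately and with details and precision.
    ...
    Georgi: Hello, LLaMA!
    LLaMA: Hello Georgi! How may I help you today?

The few-shot dialog lines prime the model for the turn-taking format, and the trailing "Georgi:" leaves the prompt open for the transcribed speech.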
| 205 |
+
int main(int argc, char ** argv) {
|
| 206 |
+
whisper_params params;
|
| 207 |
+
|
| 208 |
+
if (whisper_params_parse(argc, argv, params) == false) {
|
| 209 |
+
return 1;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
if (whisper_lang_id(params.language.c_str()) == -1) {
|
| 213 |
+
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
| 214 |
+
whisper_print_usage(argc, argv, params);
|
| 215 |
+
exit(0);
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
// whisper init
|
| 219 |
+
|
| 220 |
+
struct whisper_context * ctx_wsp = whisper_init_from_file(params.model_wsp.c_str());
|
| 221 |
+
|
| 222 |
+
// llama init
|
| 223 |
+
|
| 224 |
+
auto lparams = llama_context_default_params();
|
| 225 |
+
|
| 226 |
+
// tune these to your liking
|
| 227 |
+
lparams.n_ctx = 512;
|
| 228 |
+
lparams.seed = 1;
|
| 229 |
+
lparams.f16_kv = true;
|
| 230 |
+
|
| 231 |
+
struct llama_context * ctx_llama = llama_init_from_file(params.model_llama.c_str(), lparams);
|
| 232 |
+
|
| 233 |
+
// print some info about the processing
|
| 234 |
+
{
|
| 235 |
+
fprintf(stderr, "\n");
|
| 236 |
+
|
| 237 |
+
if (!whisper_is_multilingual(ctx_wsp)) {
|
| 238 |
+
if (params.language != "en" || params.translate) {
|
| 239 |
+
params.language = "en";
|
| 240 |
+
params.translate = false;
|
| 241 |
+
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
| 245 |
+
__func__,
|
| 246 |
+
params.n_threads,
|
| 247 |
+
params.language.c_str(),
|
| 248 |
+
params.translate ? "translate" : "transcribe",
|
| 249 |
+
params.no_timestamps ? 0 : 1);
|
| 250 |
+
|
| 251 |
+
fprintf(stderr, "\n");
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
// init audio
|
| 256 |
+
|
| 257 |
+
audio_async audio(30*1000);
|
| 258 |
+
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
| 259 |
+
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
| 260 |
+
return 1;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
audio.resume();
|
| 264 |
+
|
| 265 |
+
int n_iter = 0;
|
| 266 |
+
|
| 267 |
+
bool is_running = true;
|
| 268 |
+
bool force_speak = false;
|
| 269 |
+
|
| 270 |
+
float prob0 = 0.0f;
|
| 271 |
+
|
| 272 |
+
const std::string chat_symb = ":";
|
| 273 |
+
const std::string bot_name = "LLaMA";
|
| 274 |
+
|
| 275 |
+
std::vector<float> pcmf32_cur;
|
| 276 |
+
std::vector<float> pcmf32_prompt;
|
| 277 |
+
|
| 278 |
+
const std::string prompt_whisper = ::replace(k_prompt_whisper, "{1}", bot_name);
|
| 279 |
+
|
| 280 |
+
// construct the initial prompt for LLaMA inference
|
| 281 |
+
std::string prompt_llama = k_prompt_llama;
|
| 282 |
+
|
| 283 |
+
prompt_llama = ::replace(prompt_llama, "{0}", params.person);
|
| 284 |
+
prompt_llama = ::replace(prompt_llama, "{1}", bot_name);
|
| 285 |
+
|
| 286 |
+
{
|
| 287 |
+
// get time string
|
| 288 |
+
std::string time_str;
|
| 289 |
+
{
|
| 290 |
+
time_t t = time(0);
|
| 291 |
+
struct tm * now = localtime(&t);
|
| 292 |
+
char buf[128];
|
| 293 |
+
strftime(buf, sizeof(buf), "%H:%M", now);
|
| 294 |
+
time_str = buf;
|
| 295 |
+
}
|
| 296 |
+
prompt_llama = ::replace(prompt_llama, "{2}", time_str);
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
{
|
| 300 |
+
// get year string
|
| 301 |
+
std::string year_str;
|
| 302 |
+
{
|
| 303 |
+
time_t t = time(0);
|
| 304 |
+
struct tm * now = localtime(&t);
|
| 305 |
+
char buf[128];
|
| 306 |
+
strftime(buf, sizeof(buf), "%Y", now);
|
| 307 |
+
year_str = buf;
|
| 308 |
+
}
|
| 309 |
+
prompt_llama = ::replace(prompt_llama, "{3}", year_str);
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
prompt_llama = ::replace(prompt_llama, "{4}", chat_symb);

    // evaluate the initial prompt

    auto embd_inp = ::llama_tokenize(ctx_llama, prompt_llama, true);

    printf("\n");
    printf("%s : initializing - please wait ...\n", __func__);

    if (llama_eval(ctx_llama, embd_inp.data(), embd_inp.size(), 0, params.n_threads)) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return 1;
    }

    //fprintf(stdout, "\n");
    //fprintf(stdout, "%s", prompt_llama.c_str());
    //fflush(stdout);

    printf("%s : done! start speaking in the microphone\n", __func__);
    printf("\n");
    printf("%s%s", params.person.c_str(), chat_symb.c_str());
    fflush(stdout);

    // clear audio buffer
    audio.clear();

    // text inference variables
    const int voice_id = 2;
    const int n_keep   = embd_inp.size();
    const int n_ctx    = llama_n_ctx(ctx_llama);

    int n_past = n_keep;
    int n_prev = 64; // TODO arg

    std::vector<llama_token> embd;

    // reverse prompts for detecting when it's time to stop speaking
    std::vector<std::string> antiprompts = {
        params.person + chat_symb,
    };
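
    // annotation: the antiprompt is the user's turn marker (e.g. "Georgi:");
    // when the model generates it, the response is considered complete and
    // control returns to the microphone (see the check at the end of the
    // inference loop below)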

    // main loop
    while (is_running) {
        // handle Ctrl + C
        is_running = sdl_poll_events();

        if (!is_running) {
            break;
        }

        // delay
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        int64_t t_ms = 0;

        {
            audio.get(2000, pcmf32_cur);

            if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
                //fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);

                audio.get(params.voice_ms, pcmf32_cur);

                std::string text_heard;

                if (!force_speak) {
                    text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
                }

                // remove text between square brackets using regex
                {
                    std::regex re("\\[.*?\\]");
                    text_heard = std::regex_replace(text_heard, re, "");
                }

                // remove text between round brackets using regex
                {
                    std::regex re("\\(.*?\\)");
                    text_heard = std::regex_replace(text_heard, re, "");
                }

                // remove all characters, except for letters, numbers, punctuation and ':', '\'', '-', ' '
                text_heard = std::regex_replace(text_heard, std::regex("[^a-zA-Z0-9\\.,\\?!\\s\\:\\'\\-]"), "");

                // take first line
                text_heard = text_heard.substr(0, text_heard.find_first_of('\n'));

                // remove leading and trailing whitespace
                text_heard = std::regex_replace(text_heard, std::regex("^\\s+"), "");
                text_heard = std::regex_replace(text_heard, std::regex("\\s+$"), "");

                const std::vector<llama_token> tokens = llama_tokenize(ctx_llama, text_heard.c_str(), false);

                if (text_heard.empty() || tokens.empty() || force_speak) {
                    //fprintf(stdout, "%s: Heard nothing, skipping ...\n", __func__);
                    audio.clear();

                    continue;
                }

                force_speak = false;

                text_heard.insert(0, 1, ' ');
                text_heard += "\n" + bot_name + chat_symb;
                fprintf(stdout, "%s%s%s", "\033[1m", text_heard.c_str(), "\033[0m");
                fflush(stdout);

                embd = ::llama_tokenize(ctx_llama, text_heard, false);

                // text inference
                bool done = false;
                std::string text_to_speak;
                while (true) {
                    // predict
                    if (embd.size() > 0) {
                        if (n_past + (int) embd.size() > n_ctx) {
                            n_past = n_keep;

                            // context overflow: re-feed the last n_prev tokens of the
                            // transcript at the start of embd and continue from the
                            // initial prompt (n_keep tokens)
                            embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());

                            //printf("\n---\n");
                            //printf("resetting: '");
                            //for (int i = 0; i < (int) embd.size(); i++) {
                            //    printf("%s", llama_token_to_str(ctx_llama, embd[i]));
                            //}
                            //printf("'\n");
                            //printf("\n---\n");
                        }
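
                        // annotation: with the defaults above (n_ctx = 512, n_prev = 64,
                        // n_keep = length of the initial prompt) an overflow keeps the
                        // persona prompt intact and carries over only the last 64 tokens
                        // of the conversation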

                        if (llama_eval(ctx_llama, embd.data(), embd.size(), n_past, params.n_threads)) {
                            fprintf(stderr, "%s : failed to eval\n", __func__);
                            return 1;
                        }
                    }

                    //printf("n_iter = %d, n_past = %d, n_ctx = %d, n_keep = %d, n_prev = %d, embd.size() = %d\n", n_iter, n_past, n_ctx, n_keep, n_prev, (int) embd.size());

                    embd_inp.insert(embd_inp.end(), embd.begin(), embd.end());
                    n_past += embd.size();
                    embd.clear();

                    if (done) break;

                    {
                        // out of user input, sample next token
                        const float top_k = 5;
                        const float top_p = 0.80f;
                        const float temp  = 0.30f;
                        const float repeat_penalty = 1.1764f;

                        const int repeat_last_n = 256;

                        llama_token id = 0;

                        {
                            auto logits = llama_get_logits(ctx_llama);
                            logits[llama_token_eos()] = 0;

                            id = llama_sample_top_p_top_k(ctx_llama,
                                    embd_inp.data() + std::max(0, n_past - repeat_last_n),
                                    repeat_last_n, top_k, top_p, temp, repeat_penalty);
                        }
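
                        // annotation: zeroing the EOS logit above suppresses the
                        // end-of-text token so the "never ending" dialog is not
                        // terminated by the model (the "disable EOS token" change
                        // in this commit); the conservative sampling settings
                        // (temp 0.30, top_k 5, top_p 0.80) keep replies short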

                        if (id != llama_token_eos()) {
                            // add it to the context
                            embd.push_back(id);

                            text_to_speak += llama_token_to_str(ctx_llama, id);

                            printf("%s", llama_token_to_str(ctx_llama, id));
                        }
                    }

                    {
                        std::string last_output;
                        for (int i = embd_inp.size() - 16; i < (int) embd_inp.size(); i++) {
                            last_output += llama_token_to_str(ctx_llama, embd_inp[i]);
                        }
                        last_output += llama_token_to_str(ctx_llama, embd[0]);

                        for (std::string & antiprompt : antiprompts) {
                            if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
                                done = true;
                                text_to_speak = ::replace(text_to_speak, antiprompt, "");
                                fflush(stdout);
                                break;
                            }
                        }
                    }

                    is_running = sdl_poll_events();

                    if (!is_running) {
                        break;
                    }
                }

                text_to_speak = ::replace(text_to_speak, "\"", "");
                system((params.speak + " " + std::to_string(voice_id) + " \"" + text_to_speak + "\"").c_str());

                audio.clear();

                ++n_iter;
            }
        }
    }

    audio.pause();

    whisper_print_timings(ctx_wsp);
    whisper_free(ctx_wsp);

    llama_print_timings(ctx_llama);
    llama_free(ctx_llama);

    return 0;
}
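
The ::replace() helper used throughout main() is defined in the earlier part of
talk-llama.cpp, which is not rendered on this page. A minimal sketch of the
behavior the call sites rely on - replace every occurrence of a substring - for
illustration only, not necessarily the exact implementation from the file:

// sketch: assumes only what the call sites above require
static std::string replace(const std::string & s, const std::string & from, const std::string & to) {
    std::string result = s;
    size_t pos = 0;
    while ((pos = result.find(from, pos)) != std::string::npos) {
        result.replace(pos, from.length(), to);
        pos += to.length();
    }
    return result;
}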
examples/talk/speak.sh
CHANGED

@@ -7,7 +7,10 @@
 # Mac OS: brew install espeak
 # Linux: apt-get install espeak
 #
-espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
+#espeak -v en-us+m$1 -s 175 -p 50 -a 200 -g 5 -k 5 "$2"
+
+# Mac OS "say" command
+say "$2"
 
 # Eleven Labs
 #
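
Note: talk-llama invokes this script through system() as `speak.sh <voice_id> "<text>"`
(see the end of the main loop above), so $1 carries the voice id - used only by the
commented-out espeak line - and $2 carries the text; the macOS `say` command ignores
the voice id.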
ggml.c
CHANGED

The diff for this file is too large to render. See raw diff.
ggml.h
CHANGED

@@ -198,6 +198,8 @@ struct ggml_object;
 struct ggml_context;
 
 enum ggml_type {
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
     GGML_TYPE_I8,
     GGML_TYPE_I16,
     GGML_TYPE_I32,

@@ -226,7 +228,9 @@ enum ggml_op {
     GGML_OP_STEP,
     GGML_OP_RELU,
     GGML_OP_GELU,
+    GGML_OP_SILU,
     GGML_OP_NORM, // normalize
+    GGML_OP_RMS_NORM,
 
     GGML_OP_MUL_MAT,
 

@@ -326,7 +330,10 @@ void ggml_print_objects(const struct ggml_context * ctx);
 int    ggml_nelements(const struct ggml_tensor * tensor);
 size_t ggml_nbytes   (const struct ggml_tensor * tensor);
 
-
+int    ggml_blck_size (enum ggml_type type);
+size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
+float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
+
 size_t ggml_element_size(const struct ggml_tensor * tensor);
 
 struct ggml_context * ggml_init(struct ggml_init_params params);

@@ -336,6 +343,9 @@ size_t ggml_used_mem(const struct ggml_context * ctx);
 
 size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);
 
+bool ggml_mlock_supported(void);
+bool ggml_mlock(struct ggml_context * ctx, char ** err_p);
+
 struct ggml_tensor * ggml_new_tensor(
         struct ggml_context * ctx,
         enum   ggml_type type,

@@ -466,12 +476,20 @@ struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor * ggml_silu(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
 // normalize along rows
 // TODO: eps is hardcoded to 1e-5 for now
 struct ggml_tensor * ggml_norm(
         struct ggml_context * ctx,
         struct ggml_tensor * a);
 
+struct ggml_tensor * ggml_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a);
+
 // A: m rows, n columns
 // B: p rows, n columns (i.e. we transpose it internally)
 // result is m columns, p rows

@@ -726,6 +744,13 @@ enum ggml_opt_result ggml_opt(
         struct ggml_opt_params params,
         struct ggml_tensor * f);
 
+//
+// quantization
+//
+
+size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
+size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int qk, int64_t * hist);
+
 //
 // system info
 //
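
The two new quantization entry points convert f32 weights to the 4-bit block types
added to enum ggml_type above. A minimal usage sketch, assuming n is the total
element count, k the row length, qk the quantization block size (32 here), hist a
16-bin histogram the function fills in, and the return value the number of bytes
written - these parameter semantics are inferred, not documented in this header:

#include <stdint.h>
#include <vector>

// sketch (assumptions as stated in the text above)
size_t quantize_q4_0_buffer(const float * src, int n, int k, std::vector<uint8_t> & dst) {
    dst.resize(n * sizeof(float));  // upper bound: quantized data is smaller than f32
    int64_t hist[16] = {0};         // counts of the produced 4-bit values
    const int qk = 32;              // assumed block size
    const size_t n_bytes = ggml_quantize_q4_0(src, dst.data(), n, k, qk, hist);
    dst.resize(n_bytes);            // shrink to the actual quantized size
    return n_bytes;
}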
whisper.cpp
CHANGED

@@ -636,6 +636,8 @@ struct whisper_context {
     whisper_model model;
     whisper_vocab vocab;
     whisper_state * state = nullptr;
+
+    std::string path_model; // populated by whisper_init_from_file()
 };
 
 template<typename T>

@@ -1597,7 +1599,7 @@ static bool whisper_encode_internal(
                     ggml_repeat(ctx0, layer.mlp_ln_w, cur),
                     cur),
                 ggml_repeat(ctx0, layer.mlp_ln_b, cur));
-
+            }
 
 #ifdef WHISPER_USE_FLASH_FF
         wstate.use_buf(ctx0, 0);

@@ -1637,7 +1639,7 @@ static bool whisper_encode_internal(
             ggml_repeat(ctx0, layer.mlp_1_b, cur),
             cur);
 #endif
-            }
+        }
 
         wstate.use_buf(ctx0, 3);
 

@@ -1841,8 +1843,6 @@ static bool whisper_decode_internal(
 
         // self-attention
         {
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.attn_q_w,
                     cur);

@@ -1904,8 +1904,6 @@ static bool whisper_decode_internal(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            wstate.use_buf(ctx0, 0);
-
             //struct ggml_tensor * KQ_scaled =
             //    ggml_scale(ctx0,
             //            KQ,

@@ -1914,20 +1912,16 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past);
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * V_trans =
-                (7 lines: previous V_trans construction - not rendered on this page)
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0,
+                            ggml_reshape_3d(ctx0,
+                                ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state),
+                                n_state/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_state/n_head, n_head));
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);

@@ -1964,8 +1958,6 @@ static bool whisper_decode_internal(
 
         cur = ggml_norm(ctx0, inpCA); // note: we use inpCA here
 
-        wstate.use_buf(ctx0, 1);
-
         // cur = ln_0_w*cur + ln_0_b
         cur = ggml_add(ctx0,
                 ggml_mul(ctx0,

@@ -1976,8 +1968,6 @@ static bool whisper_decode_internal(
 
         // cross-attention
        {
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0,
                     layer.cross_attn_q_w,
                     cur);

@@ -2001,12 +1991,13 @@ static bool whisper_decode_internal(
                     ggml_view_1d(ctx0, wstate.kv_cross.v, M*n_state, il*M*ggml_element_size(wstate.kv_cross.v)*n_state),
                     n_state/n_head, n_head, M);
 
-            struct ggml_tensor * V_trans = (rest of line not rendered on this page)
+            struct ggml_tensor * V_trans =
+                ggml_cpy(ctx0,
+                        ggml_permute(ctx0, Vcross, 1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctx0, Vcross->type, M, n_state/n_head, n_head));
 
             // ------
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
                     ggml_cpy(ctx0,

@@ -2016,8 +2007,6 @@ static bool whisper_decode_internal(
 
             struct ggml_tensor * K = ggml_permute(ctx0, Kcross, 0, 2, 1, 3);
 
-            wstate.use_buf(ctx0, 0);
-
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 

@@ -2030,16 +2019,10 @@ static bool whisper_decode_internal(
             // no masking for cross-attention
             //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
 
-            wstate.use_buf(ctx0, 0);
-
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            wstate.use_buf(ctx0, 1);
-
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
             // cur = KQV_merged.contiguous().view(n_state, N)

@@ -2482,7 +2465,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     const size_t scale = ctx->model.hparams.f16 ? 1 : 2;
 
-
     if (!kv_cache_init(ctx->model.hparams, scale * MEM_REQ_KV_SELF.at(ctx->model.type), state->decoders[0].kv_self, ctx->wtype, ctx->model.hparams.n_text_ctx)) {
         fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
         return nullptr;

@@ -2503,7 +2485,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
     state->logits_id.reserve(ctx->model.hparams.n_vocab);

@@ -2554,7 +2535,13 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model
         fin->close();
     };
 
-    (1 line removed - not rendered on this page)
+    auto ctx = whisper_init_no_state(&loader);
+
+    if (ctx) {
+        ctx->path_model = path_model;
+    }
+
+    return ctx;
 }
 
 struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) {
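
Taken together, the whisper.cpp changes are mechanical: the decoder hunks drop the
intermediate wstate.use_buf() scratch-buffer switches (presumably the "ggml : fix
build in debug" part of this commit), both attention paths now materialize V_trans
as a contiguous tensor via ggml_cpy() before the ggml_mul_mat() with KQ_soft_max,
and whisper_init_from_file_no_state() records the model path in the new
whisper_context::path_model field.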