Source code for bminf.models.cpm1

from typing import List, Optional, Tuple
from ..arch.gpt import GPTConfiguration, GPT2Model
from ..core.allocators.cuda import CUDAAllocator
from ..core.allocators.sizelimited import SizeLimitedAllocator
from ..core import Context, Device
from ..utils.sampler import GenerateSampler
from cpm_kernels.library import cudart
import cpm_kernels.kernels as ck
import numpy as np

class CPM1Configuration(GPTConfiguration):
    ## Structure
    DIM_MODEL = 2560
    DIM_FF = 10240
    DIM_HEAD = 80
    NUM_HEADS = 32
    NUM_LAYERS = 32
    VOCAB_SIZE = 30000
    MAX_LENGTH = 1024
    EPS = 1e-5


SUPPORTED_VERSION = ["cpm1-new"]
LATEST_VERSION = SUPPORTED_VERSION[-1]

class CPM1:
    def __init__(self,
            device_idx : Optional[int] = None,
            dynamic_memory : int = 512 * 1024 * 1024,
            memory_limit : Optional[int] = None,
            version : Optional[str] = None
        ) -> None:
        if version is None:
            version = LATEST_VERSION
        if version not in SUPPORTED_VERSION:
            raise RuntimeError("CPM1 version %s is not supported (requires %s)" % (version, SUPPORTED_VERSION))

        config = CPM1Configuration()
        config.MODEL_NAME = version

        if device_idx is None:
            device_idx = cudart.cudaGetDevice()
        config.DEVICE = device_idx
        config.MEMORY_LIMIT = memory_limit

        self.device = Device(config.DEVICE)
        self._cudaAlloc = CUDAAllocator(config.DEVICE)
        self._ctx = Context([config.DEVICE], [
            SizeLimitedAllocator(self._cudaAlloc.allocate(dynamic_memory))
        ])
        self._model = GPT2Model(config)
        self._config = config
        self._chunk_size = 64

    def _pre_processing(self, input_sentence : str):
        idx = self._model.tokenizer.encode(input_sentence)
        input_length = len(idx)
        # pad the prompt so the buffer length is a multiple of the chunk size
        while len(idx) % self._chunk_size != 0:
            idx.append(0)
        return idx, input_length

    def _gen_iter(self,
            idx : List[int],
            input_length : int,
            max_length : Optional[int],
            top_n : Optional[int] = None,
            top_p : Optional[float] = None,
            temperature : float = 0.9,
            frequency_penalty : float = 0,
            presence_penalty : float = 0,
            no_penalty_tokens : List[int] = [8],
            filter_tokens : List[int] = []
        ):
        self.free()
        buffer_len = len(idx)
        with self.device:
            # self-attention key/value decode buffers, allocated in chunk-sized steps
            buffer_k_self = self._model.allocate_decode_buffer(self._ctx, 1, buffer_len)
            buffer_v_self = self._model.allocate_decode_buffer(self._ctx, 1, buffer_len)

            sampler = GenerateSampler(
                self._ctx,
                idx,
                self._model.tokenizer.vocab_size,
                top_n,
                top_p,
                temperature,
                frequency_penalty,
                presence_penalty,
                no_penalty_tokens,
                filter_tokens
            )

            # encode the (padded) prompt in one pass and sample the first new token
            hidden_enc = self._ctx.allocate((1, self._config.DIM_MODEL, len(idx)), dtype=np.float16)
            self._model.embedding(
                self._ctx,
                np.array([idx], dtype=np.int32),
                np.arange(len(idx), dtype=np.int32)[np.newaxis, :],
                hidden_enc
            )
            mask_enc = (np.arange(len(idx)) < input_length)[np.newaxis, :]
            logits = self._ctx.allocate((1, self._config.VOCAB_SIZE), dtype=np.float16)
            self._model.encode(
                self._ctx,
                hidden_enc,
                mask_enc,
                buffer_k_self,
                buffer_v_self
            )
            self._model.projection(
                self._ctx,
                hidden_enc,
                logits,
                output_one=input_length - 1
            )
            self._ctx.free(hidden_enc)
            logits.reshape((self._config.VOCAB_SIZE,))
            last_ipt = sampler.sample(logits)
            self._ctx.free(logits)
        yield last_ipt

        dec_pos = input_length
        if max_length is None:
            max_length = self._config.MAX_LENGTH
        else:
            max_length = min(max_length + input_length, self._config.MAX_LENGTH)

        while dec_pos < max_length:
            with self.device:
                if dec_pos >= buffer_len:
                    # grow the key/value buffers by one chunk and copy the old contents over
                    nw_buffer_len = buffer_len + self._chunk_size
                    nw_buffer_k_self = self._model.allocate_decode_buffer(self._ctx, 1, nw_buffer_len)
                    for old, nw in zip(buffer_k_self, nw_buffer_k_self):
                        ck.utils.copy_extend_buffer(
                            nw.shape[1],
                            old.shape[2] * old.shape[3],
                            nw.shape[2] * nw.shape[3],
                            old.ptr,
                            nw.ptr,
                            self._ctx.current_stream
                        )
                        self._ctx.free(old)
                    buffer_k_self = nw_buffer_k_self

                    nw_buffer_v_self = self._model.allocate_decode_buffer(self._ctx, 1, nw_buffer_len)
                    for old, nw in zip(buffer_v_self, nw_buffer_v_self):
                        ck.utils.copy_extend_buffer(
                            nw.shape[1],
                            old.shape[2] * old.shape[3],
                            nw.shape[2] * nw.shape[3],
                            old.ptr,
                            nw.ptr,
                            self._ctx.current_stream
                        )
                        self._ctx.free(old)
                    buffer_v_self = nw_buffer_v_self
                    buffer_len = nw_buffer_len

                # single decode step: embed the last sampled token, run one transformer step, project to logits
                hidden_dec = self._ctx.allocate((1, self._config.DIM_MODEL), np.float16)
                self._model.embedding_step(self._ctx,
                    np.array([last_ipt], dtype=np.int32),
                    np.array([dec_pos], dtype=np.int32),
                    hidden_dec
                )
                logits = self._ctx.allocate((1, self._config.VOCAB_SIZE), np.float16)
                self._model.step(
                    self._ctx,
                    hidden_dec,
                    buffer_k_self,
                    buffer_v_self,
                    dec_pos
                )
                self._model.projection_step(self._ctx, hidden_dec, logits)
                self._ctx.free(hidden_dec)
                dec_pos += 1
                logits.reshape((self._config.VOCAB_SIZE,))
                last_ipt = sampler.sample(logits)
                self._ctx.free(logits)
            yield last_ipt

    def generate(self,
            input_sentence : str,
            max_tokens : int = 128,
            top_n : Optional[int] = None,
            top_p : Optional[float] = None,
            temperature : float = 0.9,
            frequency_penalty : float = 0,
            presence_penalty : float = 0,
            stop_tokens : Optional[List[str]] = None,
        ):
        """Generate some words from the model.

        Args:
            input_sentence: Your input.
            max_tokens: Maximum number of tokens to generate.
            top_n: Only sample from the top_n most likely tokens.
            top_p: Only sample from the smallest set of tokens whose cumulative probability reaches top_p.
            temperature: Sampling temperature. Higher values give more diverse results.
            frequency_penalty: A penalty that discourages the model from repeating the same content.
            presence_penalty: A penalty that discourages the model from dwelling on the same topic.
            stop_tokens: A list of tokens that stop the generation.

        Returns:
            The generated text and a boolean indicating whether one of stop_tokens was generated.
        """
        if stop_tokens is None:
            stop_tokens = []
        else:
            stop_tokens = [
                self._model.tokenizer.encoder.get(word, self._model.tokenizer.unk_token)
                    for word in stop_tokens
            ]
        # always stop at the end-of-document token
        if self._model.tokenizer.eod_id not in stop_tokens:
            stop_tokens.append(self._model.tokenizer.eod_id)

        idx, input_length = self._pre_processing(input_sentence)
        res = self._gen_iter(
            idx,
            input_length,
            max_tokens,
            top_n,
            top_p,
            temperature,
            frequency_penalty,
            presence_penalty,
        )
        blanks = []
        stopped = False
        for token in res:
            if token in stop_tokens:
                stopped = True
                break
            blanks.append(token)
            if len(blanks) >= max_tokens:
                break
        self.free()
        return self._model.tokenizer.decode(blanks), stopped

    def free(self):
        self._ctx.free_all()
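
A minimal usage sketch, not part of the module above: it assumes BMInf is installed with a CUDA-capable GPU visible and that the cpm1-new checkpoint can be loaded. The prompt, sampling parameters, and stop token below are purely illustrative.

# Minimal usage sketch (assumes a CUDA GPU and an available "cpm1-new" checkpoint).
from bminf.models.cpm1 import CPM1

model = CPM1(version="cpm1-new")   # defaults: current CUDA device, 512 MiB of dynamic memory
text, stopped = model.generate(
    "天空是蔚蓝色，窗外有",           # CPM1 is a Chinese model; any prompt string works here
    max_tokens=64,
    top_p=0.9,                       # nucleus sampling; top_n is left unset
    temperature=0.85,
    stop_tokens=["。"],              # illustrative: stop once a full-stop token is generated
)
print(text, "stopped early:", stopped)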