pull/4175/head
yann300 1 year ago
parent 74a654b9b3
commit 403e7a3f68
  1. 2
      apps/remix-ide/src/app/plugins/copilot/suggestion-service/copilot-suggestion.ts
  2. 70
      apps/remix-ide/src/app/plugins/copilot/suggestion-service/suggestion-service.ts
  3. 128
      apps/remix-ide/src/app/plugins/copilot/suggestion-service/worker.js

@@ -41,7 +41,7 @@ export class CopilotSuggestion extends Plugin {
} }
async init() { async init() {
return this.service.init() return this.service.init()
} }
async uninstall() { async uninstall() {

@@ -8,7 +8,7 @@ export class SuggestionService {
// NOTE(review): this is a side-by-side diff paste; each line below shows the
// old and new revision concatenated. The constructor's closing brace is past
// this hunk, so the code is annotated in place rather than reconstructed.
events: EventEmitter events: EventEmitter
// Spawn the model worker as an ES-module Web Worker, resolved relative to this file.
constructor() { constructor() {
this.worker = new Worker(new URL('./worker.js', import.meta.url), { this.worker = new Worker(new URL('./worker.js', import.meta.url), {
type: 'module' type: 'module'
}); });
// Event bus re-emitting worker status messages to plugin consumers.
this.events = new EventEmitter() this.events = new EventEmitter()
// Pending per-request callbacks, indexed by request id (see the 'complete' handler).
this.responses = [] this.responses = []
@ -17,46 +17,46 @@ export class SuggestionService {
async init() { async init() {
// Route worker status messages: progress-style messages are re-emitted on the
// event bus; 'complete' is delivered to the stored per-request callback instead.
const onMessageReceived = (e) => {
  const msg = e.data
  switch (msg.status) {
    case 'initiate':
      // Model file load begins: logged first, then emitted (original order kept).
      console.log(msg)
      this.events.emit(msg.status, msg)
      break;
    case 'progress': // model file download progress
    case 'done':     // one model file finished loading
    case 'ready':    // pipeline ready to accept requests
    case 'update':   // partial generation output
      this.events.emit(msg.status, msg)
      console.log(msg)
      break;
    case 'complete': {
      console.log(msg)
      // Generation finished: resolve the pending request registered under
      // this id, if any. Note 'complete' is not emitted on the event bus.
      const pending = this.responses[msg.id]
      if (pending) {
        pending(null, msg)
      } else {
        console.log('no callback for', msg)
      }
      break;
    }
  }
};

@@ -8,82 +8,82 @@ const instance = null
/**
 * This class uses the Singleton pattern to ensure that only one instance of the pipeline is loaded.
 */
class CodeCompletionPipeline {
  static task = 'text-generation';
  static model = null;
  static instance = null;

  /**
   * Lazily create the pipeline on first use; every later call reuses the
   * cached instance. `progress_callback` is forwarded to the pipeline loader
   * only on that first creation.
   */
  static async getInstance(progress_callback = null) {
    this.instance ??= pipeline(this.task, this.model, { progress_callback });
    return this.instance;
  }
}
// Listen for messages from the main thread
self.addEventListener('message', async (event) => {
  const {
    id, model, text, max_new_tokens, cmd,
    // Generation parameters
    temperature,
    top_k,
    do_sample,
  } = event.data;

  if (cmd === 'init') {
    // Retrieve the code-completion pipeline. When called for the first time,
    // this will load the pipeline and save it for future use.
    CodeCompletionPipeline.model = model
    await CodeCompletionPipeline.getInstance(x => {
      // We also add a progress callback to the pipeline so that we can
      // track model loading.
      self.postMessage(x);
    });
    return
  }

  if (!CodeCompletionPipeline.instance) {
    // Model not loaded yet: report the error back to the main thread.
    self.postMessage({
      id,
      status: 'error',
      message: 'model not yet loaded'
    });
    // FIX: bail out here. Previously execution fell through into the
    // 'suggest' branch below and invoked the pipeline with no model loaded.
    return
  }

  if (cmd === 'suggest') {
    // The pipeline instance is guaranteed non-null here (guard above), so
    // this resolves immediately with the cached generator; the progress
    // callback is kept for parity with the 'init' path.
    const generator = await CodeCompletionPipeline.getInstance(x => {
      self.postMessage(x);
    });

    // Actually perform the code-completion
    const output = await generator(text, {
      max_new_tokens,
      temperature,
      top_k,
      do_sample,
      // Allows for partial output: stream each decoding step back to the
      // main thread as an 'update' message so the UI can render it live.
      callback_function: x => {
        self.postMessage({
          id,
          status: 'update',
          output: generator.tokenizer.decode(x[0].output_token_ids, { skip_special_tokens: true })
        });
      }
    });

    // Send the final output back to the main thread
    self.postMessage({
      id,
      status: 'complete',
      output: output,
    });
  }
});
Loading…
Cancel
Save