Inferência Multi-GPU do Tensorflow Java

Eu tenho um servidor com várias GPUs e quero fazer uso total delas durante a inferência do modelo dentro de um aplicativo java. Por padrão o tensorflow aproveita todas as GPUs disponíveis, mas usa apenas a primeira.

Eu posso pensar em três opções para superar esse problema:

  1. Restringir a visibilidade do dispositivo no nível do processo, usando a variável de ambiente CUDA_VISIBLE_DEVICES .

    Isso exigiria que eu executasse várias instâncias do aplicativo java e distribuísse o tráfego entre elas. Não essa ideia tentadora.

  2. Inicie várias sessões dentro de um único aplicativo e tente atribuir um dispositivo a cada um deles via ConfigProto :

     public class DistributedPredictor { private Predictor[] nested; private int[] counters; // ... public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) { nested = new Predictor[numDevices]; counters = new int[numDevices]; for (int i = 0; i < nested.length; i++) { nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice); } } public Prediction predict(Data data) { int i = acquirePredictorIndex(); Prediction result = nested[i].predict(data); releasePredictorIndex(i); return result; } private synchronized int acquirePredictorIndex() { int i = argmin(counters); counters[i] += 1; return i; } private synchronized void releasePredictorIndex(int i) { counters[i] -= 1; } } public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { GPUOptions gpuOptions = GPUOptions.newBuilder() .setVisibleDeviceList("" + deviceIdx) .setAllowGrowth(true) .build(); ConfigProto config = ConfigProto.newBuilder() .setGpuOptions(gpuOptions) .setInterOpParallelismThreads(numDevices * numThreadsPerDevice) .build(); byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); Graph graph = new Graph(); graph.importGraphDef(graphDef); this.session = new Session(graph, config.toByteArray()); } public Prediction predict(Data data) { // ... } } 

    Essa abordagem parece funcionar bem de relance. No entanto, as sessões ocasionalmente ignoram a opção setVisibleDeviceList e todas vão para o primeiro dispositivo que causa a falha Out-Of-Memory.

  3. Construa o modelo de forma multi-torre em python usando a especificação tf.device() . No lado do java, dê diferentes torres diferentes do Predictor dentro de uma session compartilhada.

    Parece complicado e idiomicamente errado para mim.

ATUALIZAÇÃO: Como o @ash proposto, há ainda uma outra opção:

  1. Atribuir um dispositivo apropriado para cada operação do gráfico existente, modificando sua definição ( graphDef ).

    Para fazê-lo, pode-se adaptar o código do Método 2:

     public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); graphDef = setGraphDefDevice(graphDef, deviceIdx) Graph graph = new Graph(); graph.importGraphDef(graphDef); ConfigProto config = ConfigProto.newBuilder() .setAllowSoftPlacement(true) .build(); this.session = new Session(graph, config.toByteArray()); } private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException { String deviceString = String.format("/gpu:%d", deviceIdx); GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder(); for (int i = 0; i < builder.getNodeCount(); i++) { builder.getNodeBuilder(i).setDevice(deviceString); } return builder.build().toByteArray(); } public Prediction predict(Data data) { // ... } } 

    Assim como outras abordagens mencionadas, esta não me liberta da distribuição manual de dados entre dispositivos. Mas pelo menos funciona de forma estável e é comparativamente fácil de implementar. No geral, isso parece uma técnica (quase) normal.

Existe uma maneira elegante de fazer uma coisa tão básica com API java tensorflow? Qualquer idéia seria apreciada.

Resumindo: existe uma solução alternativa, onde você acaba com uma session por GPU.

Detalhes:

O stream geral é que o tempo de execução do TensorFlow respeita os dispositivos especificados para operações no gráfico. Se nenhum dispositivo for especificado para uma operação, ele “coloca” com base em algumas heurísticas. Essas heurísticas atualmente resultam em “operação local na GPU: 0 se as GPUs estiverem disponíveis e houver um kernel da GPU para a operação” ( Placer::Run , caso você esteja interessado).

O que você pede é uma solicitação de recurso razoável para o TensorFlow – a capacidade de tratar dispositivos no gráfico serializado como “virtuais” para serem mapeados para um conjunto de dispositivos “phyiscais” em tempo de execução ou, alternativamente, configurando o “dispositivo padrão” “. Este recurso não existe atualmente. Adicionar essa opção ao ConfigProto é algo que você pode querer arquivar uma solicitação de recurso para.

Eu posso sugerir uma solução alternativa no ínterim. Primeiro, algum comentário sobre suas soluções propostas.

  1. Sua primeira ideia certamente funcionará, mas, como você apontou, é incômoda.

  2. A configuração usando visible_device_list no ConfigProto não funciona bem, já que na verdade é uma configuração por processo e é ignorada depois que a primeira session é criada no processo. Isso certamente não está documentado tão bem quanto deveria ser (e de alguma forma lamentável que isso apareça na configuração por session). No entanto, isso explica por que sua sugestão aqui não funciona e por que você ainda vê uma única GPU sendo usada.

  3. Isso poderia funcionar.

Outra opção é acabar com charts diferentes (com operações explicitamente colocadas em diferentes GPUs), resultando em uma session por GPU. Algo como isso pode ser usado para editar o gráfico e atribuir explicitamente um dispositivo a cada operação:

 public static byte[] modifyGraphDef(byte[] graphDef, String device) throws Exception { GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder(); for (int i = 0; i < builder.getNodeCount(); ++i) { builder.getNodeBuilder(i).setDevice(device); } return builder.build().toByteArray(); } 

Depois que você poderia criar um Graph e Session por GPU usando algo como:

 final int NUM_GPUS = 8; // setAllowSoftPlacement: Just in case our device modifications were too aggressive // (eg, setting a GPU device on an operation that only has CPU kernels) // setLogDevicePlacment: So we can see what happens. byte[] config = ConfigProto.newBuilder() .setLogDevicePlacement(true) .setAllowSoftPlacement(true) .build() .toByteArray(); Graph graphs[] = new Graph[NUM_GPUS]; Session sessions[] = new Session[NUM_GPUS]; for (int i = 0; i < NUM_GPUS; ++i) { graphs[i] = new Graph(); graphs[i].importGraphDef(modifyGraphDef(graphDef, String.format("/gpu:%d", i))); sessions[i] = new Session(graphs[i], config); } 

Em seguida, use as sessions[i] para executar o gráfico na GPU #i.

Espero que ajude.