用 Metal Performance Shader 计算矩阵乘法

Metal 有自己的 Shader Language,但是编译 Shader 必需要 Xcode,只有 CommandLineTools 不行。Metal Performance Shader 则是 Apple 针对一些常用的操作,主要是卷积神经网络,实现了一些针对果子家不同硬件手动优化过的 Shader,可以直接通过 Objective-C 或 Swift 调用。我写代码用的是 Intel Mac,不过 M2 当然也可以用,而且性能提升会更高,因为 Intel 的 AVX 已经很快而且只有 Iris Pro 集成集卡可以用。

因为我写这个是要编译成 C shared library 给 J 调用,而且我没安装 Xcode,就用 Objective C 写。用 Swfit 有自动类型推导写复杂的应用如 CNN 会方便很多,我只用过 PyTorch 所以看到 Apple 给的 Objective C 实现 CNN 实在是头大,代码太啰唆了。

想看 Swift 版本的,可以看 Matrix Multiplication with Metal Performance Shaders,不过它少了比如 commit command 以后要等 GPU 算完才能有结果这样的细节。

example.m
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>

int main() {
  id<MTLDevice> device;
  if(!(device = MTLCreateSystemDefaultDevice()))
    {
      NSLog(@"Failed to get the system's default Metal device.");
      return -1;
    }

  char arrayA[] = {1, 0, 0, 0,
                   0, 1, 0, 1,
                   0, 0, 1, 0,
                   0, 0, 0, 1};
  id<MTLBuffer> bufferA =
    [device newBufferWithBytes:arrayA
                        length:16 * sizeof(char)
                       options:MTLResourceStorageModeShared];
  MPSMatrixDescriptor *descA = [MPSMatrixDescriptor
                                   matrixDescriptorWithRows:4 columns:4
                                                   rowBytes:4 * sizeof(char)
                                                   dataType:MPSDataTypeInt8];

  MPSMatrix *matA = [[MPSMatrix alloc] initWithBuffer:bufferA descriptor:descA];

  char arrayB[] = {2, 2, 3, 3,
                   3, 8, 43, 3,
                   7, 3, 2, 2,
                   5, 2, 24, 2};
  id<MTLBuffer> bufferB =
    [device newBufferWithBytes:arrayB
                        length:16 * sizeof(char)
                       options:MTLResourceStorageModeShared];
  MPSMatrixDescriptor *descB = [MPSMatrixDescriptor
                                   matrixDescriptorWithRows:4 columns:4
                                                   rowBytes:4 * sizeof(char)
                                                   dataType:MPSDataTypeInt8];

  MPSMatrix *matB = [[MPSMatrix alloc] initWithBuffer:bufferB descriptor:descB];

  id<MTLBuffer> bufferC = [device newBufferWithLength:16 * sizeof(_Float16)
                                              options:MTLResourceStorageModeShared];
  MPSMatrixDescriptor *descC = [MPSMatrixDescriptor
                                   matrixDescriptorWithRows:4 columns:4
                                                   rowBytes:4 * sizeof(_Float16)
                                                   dataType:MPSDataTypeFloat16];

  MPSMatrix *matC = [[MPSMatrix alloc] initWithBuffer:bufferC descriptor:descC];

  MPSMatrixMultiplication *mul = [[MPSMatrixMultiplication alloc]
                                     initWithDevice:device transposeLeft:false transposeRight:false
                                         resultRows:4 resultColumns:4 interiorColumns:4 alpha:1 beta:0];

  id<MTLCommandBuffer> cb = [[device newCommandQueue] commandBuffer];
  [mul encodeToCommandBuffer:cb leftMatrix: matA rightMatrix:matB resultMatrix:matC];
  [cb commit];
  [cb waitUntilCompleted];

  _Float16 *ret = [[matC data] contents];

  for (int i=0;i<16;++i) {
    NSLog(@"%f", (double)ret[i]);
  }
}

编译命令是 clang -framework CoreGraphics -framework Metal -framework MetalPerformanceShaders -fobjc-arc example.m,在 macOS 上链接 CoreGraphics 后才能正常取到 GPU 设备。

运行输出
2024-06-10 13:45:37.215 mps[6490:810597] 2.000000
2024-06-10 13:45:37.215 mps[6490:810597] 2.000000
2024-06-10 13:45:37.215 mps[6490:810597] 3.000000
2024-06-10 13:45:37.215 mps[6490:810597] 3.000000
2024-06-10 13:45:37.215 mps[6490:810597] 8.000000
2024-06-10 13:45:37.215 mps[6490:810597] 10.000000
2024-06-10 13:45:37.215 mps[6490:810597] 67.000000
2024-06-10 13:45:37.215 mps[6490:810597] 5.000000
2024-06-10 13:45:37.215 mps[6490:810597] 7.000000
2024-06-10 13:45:37.215 mps[6490:810597] 3.000000
2024-06-10 13:45:37.215 mps[6490:810597] 2.000000
2024-06-10 13:45:37.215 mps[6490:810597] 2.000000
2024-06-10 13:45:37.215 mps[6490:810597] 5.000000
2024-06-10 13:45:37.215 mps[6490:810597] 2.000000
2024-06-10 13:45:37.215 mps[6490:810597] 24.000000
2024-06-10 13:45:37.215 mps[6490:810597] 2.000000

main 函数要不要 @autoreleasepool,网上很多来源,包括 Apple 自己的老文档都说要,Using Autorelease Pool Blocks 其实 LLVM 很早以前 (2015年) 就给每个 thread 自动分配 autoreleasepool 了,而且这份代码也没用到 autorelease 的 object,所以不用也可以。

支持的输入类型有 MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8, MPSDataTypeInt16,输出类型只有 MPSDataTypeFloat32, MPSDataTypeFloat16

混合输入只支持 when A.dataType == C.dataType == MPSDataTypeFloat32 and B.dataType == MPSDataTypeFloat16

当输入类型是 float32 时,输出必须也是要 float32

实际使用时 devicecommandQueue 是建议复用的。

从 J 调用的例子可以看 https://groups.google.com/a/jsoftware.com/g/forum/c/hAU__DXDD5w/m/ai--Hn0DAQAJ

但是偷偷告诉你,2.2GHz Core i7 Intel Mac 把 J 的线程数设成 4 自带的 double 精度矩阵乘法就比用 MPS 算单精度快了

2 个赞