仓颉、Java、Golang性能测试——矩阵计算

版本信息

  • 仓颉版本 0.53.18
  • Golang版本 1.22.8
  • Java版本 corretto-1.8.0_452

源码

仓颉

package cangjie_test

import std.time.MonoTime
import std.sync.SyncCounter

// Benchmark entry point: runs 10 rounds (`1..11` is a half-open range) of
// (a + b) * c, where a and b are 200x200 and c is 200x100, then prints the
// elapsed MonoTime duration in milliseconds.
main() {
    let start = MonoTime.now()
    for (i in 1..11) {
        let a = Matrix(200, 200).fill(3)
        let b = Matrix(200, 200).fill(4)
        let c = Matrix(200, 100).fill(5)
        (a + b) * (c) // result is discarded; only the elapsed time is reported
    }
    println("耗时: ${(MonoTime.now()-start).toMilliseconds()} ms")
}

// Dense Int32 matrix stored as `row` arrays of `col` elements each;
// data[y][x] is the element in row y, column x. The primary constructor
// also declares the `row` and `col` member variables used by every method.
struct Matrix {
    let data: Array<Array<Int32>>

    // Creates a row x col matrix with every element set to 0.
    Matrix(let row: Int64, let col: Int64) {
        this.data = Array<Array<Int32>>(
            row,
            {
                _ => Array<Int32>(col, item: 0)
            }
        )
    }

    // Sets every element to `val` by copying one prepared row into each row
    // of `data`, then returns this matrix so calls can be chained.
    // Note: the loop variable `row` shadows the `row` member inside the loop.
    // NOTE(review): `this` is a struct value, so the returned copy presumably
    // shares `data` with the receiver via Array's reference semantics — confirm.
    func fill(val: Int32): Matrix {
        let buf = Array<Int32>(col, item: val)
        for (row in data) {
            buf.copyTo(row, 0, 0, col)
        }
        this
    }

    // Element-wise sum this + that, returned as a new matrix.
    // Throws Exception("Fail") unless both operands have the same shape.
    public operator func +(that: Matrix): Matrix {
        if (row != that.row || col != that.col) { // A and B must have the same number of rows and columns
            throw Exception("Fail")
        }
        let result = Matrix(row, col)
        for (y in 0..row) {
            for (x in 0..col) {
                result.data[y][x] = data[y][x] + that.data[y][x]
            }
        }
        result
    }

    // Matrix product this * that, returned as a new row x that.col matrix,
    // computed with a y/x/z triple loop. Throws Exception("Fail") unless
    // this.col == that.row.
    public operator func *(that: Matrix) { // multiplication
        if (col != that.row) { // A's column count must equal B's row count
            throw Exception("Fail")
        }
        let result = Matrix(row, that.col)
        for (y in 0..row) {
            for (x in 0..that.col) {
                for (z in 0..col) {
                    result.data[y][x] += data[y][z] * that.data[z][x]
                }
            }
        }
        result
    }

    // Despite the name, this is the same matrix product as `*`, but each
    // result row y is computed in its own spawned task. Each task writes only
    // row y of `result`. SyncCounter(row) is decremented once per task, and
    // waitUntilZero() blocks until every row is done.
    // Throws Exception("Fail") unless this.col == that.row.
    public func plusAsync(that: Matrix) { // asynchronous multiplication
        if (col != that.row) { // A's column count must equal B's row count
            throw Exception("Fail")
        }
        let result = Matrix(row, that.col)
        let wg = SyncCounter(row)
        for (y in 0..row) {
            spawn {
                for (x in 0..that.col) {
                    for (z in 0..col) {
                        result.data[y][x] += data[y][z] * that.data[z][x]
                    }
                }
                wg.dec()
            }
        }
        wg.waitUntilZero()
        result
    }
}

Java

package java_test;

import java.util.Arrays;
import java.util.concurrent.CountDownLatch;

public class Main {
    public static void main(String[] args) throws Exception {
        final long start = System.currentTimeMillis();
        for (int i = 1; i < 11; i++) {
            Matrix a = new Matrix(200, 200).fill(3);
            Matrix b = new Matrix(200, 200).fill(4);
            Matrix c = new Matrix(200, 100).fill(5);
            (a.add(b)).plus(c);
        }
        System.out.println(String.format("耗时: %d ms", System.currentTimeMillis() - start));
    }

    static class Matrix {
        int[][] data; // 矩阵数据
        int row; // 矩阵行数
        int col; // 矩阵列数

        Matrix(int row, int col) {
            this.row = row;
            this.col = col;
            this.data = new int[row][col];
        }

        Matrix fill(int val) { // 填充数据
            for (int[] row : this.data) {
                Arrays.fill(row, val);
            }
            return this;
        }

        Matrix add(Matrix that) throws Exception { // 加法
            if (row != that.row || col != that.col) { // A、B行数与列数相等
                throw new Exception("Fail");
            }
            Matrix result = new Matrix(row, col);
            for (int y = 0; y < row; y++) {
                for (int x = 0; x < col; x++) {
                    result.data[y][x] = data[y][x] + that.data[y][x];
                }
            }
            return result;
        }

        Matrix plus(Matrix that) throws Exception { // 乘法
            if (col != that.row) { // A列数与B列数相等
                throw new Exception("Fail");
            }
            Matrix result = new Matrix(row, that.col);
            for (int y = 0; y < row; y++) {
                for (int x = 0; x < that.col; x++) {
                    for (int z = 0; z < col; z++) {
                        result.data[y][x] += data[y][z] * that.data[z][x];
                    }
                }
            }
            return result;
        }

        Matrix plusAsync(Matrix that) throws Exception {
            if (col != that.row) { // A列数与B列数相等
                throw new Exception("Fail");
            }
            Matrix result = new Matrix(row, that.col);
            CountDownLatch wg = new CountDownLatch(row);
            for (int y = 0; y < row; y++) {
                new Thread(new Job(wg, result, this, that, y)).start();
            }
            wg.await();
            return result;
        }
    }

    static class Job implements Runnable {
        int y;
        CountDownLatch wg;
        Matrix result;
        Matrix self;
        Matrix that;

        Job(CountDownLatch wg, Matrix result, Matrix self, Matrix that, int y) {
            this.y = y;
            this.wg = wg;
            this.result = result;
            this.self = self;
            this.that = that;
        }

        @Override
        public void run() {
            for (int x = 0; x < that.col; x++) {
                for (int z = 0; z < self.col; z++) {
                    result.data[y][x] += self.data[y][z] * that.data[z][x];
                }
            }
            wg.countDown();
        }
    }
}

Golang

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	start := time.Now()
	for i := 0; i < 11; i++ {
		a := newMatrix(200, 200).fill(3)
		b := newMatrix(200, 200).fill(4)
		c := newMatrix(200, 100).fill(5)
		(a.add(b)).plus(c)
	}
	fmt.Printf("耗时: %d ms\n", time.Now().Sub(start).Milliseconds())
}

// Matrix is a dense int32 matrix stored row by row; data[y][x] is the
// element in row y, column x.
type Matrix struct {
	row, col int       // number of rows and columns
	data     [][]int32 // one slice of length col per row
}

// newMatrix allocates a row×col matrix with every element set to zero.
// Each row gets its own backing slice, so rows never alias one another.
func newMatrix(row int, col int) *Matrix {
	m := &Matrix{row: row, col: col, data: make([][]int32, row)}
	for y := range m.data {
		m.data[y] = make([]int32, col)
	}
	return m
}

// fill sets every element of m to val and returns m so that calls can be
// chained, e.g. newMatrix(2, 2).fill(3).
func (m *Matrix) fill(val int32) *Matrix {
	// Prepare one template row holding val, then copy it into every row.
	// A freshly made slice is all zeros, so buf must be initialised first.
	buf := make([]int32, m.col)
	for i := range buf {
		buf[i] = val
	}
	for _, row := range m.data {
		copy(row, buf)
	}
	return m
}

// add returns a new matrix holding the element-wise sum m + that.
// It panics with "Fail" when the two operands do not have the same shape.
func (m *Matrix) add(that *Matrix) *Matrix {
	sameShape := m.row == that.row && m.col == that.col
	if !sameShape {
		panic("Fail")
	}
	sum := newMatrix(m.row, m.col)
	for i := 0; i < m.row; i++ {
		dst, lhs, rhs := sum.data[i], m.data[i], that.data[i]
		for j := 0; j < m.col; j++ {
			dst[j] = lhs[j] + rhs[j]
		}
	}
	return sum
}

// plus returns the matrix product m × that as a new m.row × that.col matrix.
// Despite its name it multiplies. It panics with "Fail" unless
// m.col == that.row.
func (m *Matrix) plus(that *Matrix) *Matrix {
	if m.col != that.row {
		panic("Fail")
	}
	product := newMatrix(m.row, that.col)
	for i := 0; i < m.row; i++ {
		lhs, out := m.data[i], product.data[i]
		for j := 0; j < that.col; j++ {
			for k := 0; k < m.col; k++ {
				out[j] += lhs[k] * that.data[k][j]
			}
		}
	}
	return product
}

// plusAsync computes the same product as plus (m × that), but fills each
// result row in its own goroutine and waits for all of them before
// returning. Every goroutine writes only its own row, so writes never
// overlap. It panics with "Fail" unless m.col == that.row.
func (m *Matrix) plusAsync(that *Matrix) *Matrix {
	if m.col != that.row {
		panic("Fail")
	}
	product := newMatrix(m.row, that.col)
	var wg sync.WaitGroup
	wg.Add(m.row)
	for i := 0; i < m.row; i++ {
		go func(i int) {
			defer wg.Done()
			out := product.data[i]
			for j := 0; j < that.col; j++ {
				for k := 0; k < m.col; k++ {
					out[j] += m.data[i][k] * that.data[k][j]
				}
			}
		}(i)
	}
	wg.Wait()
	return product
}

结果(注:上文各 main 函数只调用同步版本的矩阵乘法;“异步结果”对应将乘法改为调用 plusAsync 后测得的耗时)

  • Java 非异步结果 39 ms,异步结果 162 ms
  • Golang 非异步结果 86 ms,异步结果 23 ms
  • 仓颉 非异步非优化结果 3540 ms,异步结果 973 ms
  • 仓颉 非异步O1优化结果 2585 ms,异步结果 752 ms
  • 仓颉 非异步O2优化结果 49 ms,异步结果 17 ms
  • 仓颉 非异步Oz优化结果 77 ms,异步结果 750 ms

总结

整体而言,Golang 在默认编译配置下即可取得较好的性能,对开发者来说负担最小;仓颉的性能高度依赖编译优化级别(非异步场景下,未优化为 3540 ms,O2 优化为 49 ms),未开启 O2 时差距明显,仍需努力。

原文链接:,转发请注明来源!